diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 48154284..87963eda 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=2.0.999 + run: make update-package-versions VERSION=2.1.999 - name: Setup environment run: python3 -m venv env diff --git a/README.md b/README.md index 61a03cc7..d1db4ffd 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,17 @@ # The context backend for reliable AI +trustgraph-ai%2Ftrustgraph | Trendshift + +# The context backend for AI agents + -LLMs alone hallucinate and diverge from ground truth. [TrustGraph](https://trustgraph.ai) is a context system that stores, enriches, and delivers context to LLMs to enable reliable AI agents. Think like [Supabase](https://github.com/supabase/supabase) but AI-native and powered by context graphs. +Durable agent memory you can trust. Build, version, and retrieve grounded context from a context graph. + +- Give agents **memory** that persists across sessions and deployments. +- Reduce hallucinations with **grounded context retrieval** +- Ship reusable, portable [Context Cores](#context-cores) (packaged context you can move between projects/environments). The context backend: - [x] Multi-model and multimodal database system @@ -45,21 +53,6 @@ The context backend: - [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html) - [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python) - [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/) - -## No API Keys Required - -How many times have you cloned a repo and opened the `.env.example` to see the dozens of API keys for 3rd party dependencies needed to make the services work? There are only 3 things in TrustGraph that might need an API key: - -- 3rd party LLM services like Anthropic, Cohere, Gemini, Mistral, OpenAI, etc. -- 3rd party OCR like Mistral OCR -- The API key *you set* for the TrustGraph API gateway - -Everything else is included. -- [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html) -- [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant) -- [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible) -- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar) -- [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile) ## Quickstart @@ -76,8 +69,6 @@ TrustGraph downloads as Docker containers and can be run locally with Docker, Po width="80%" controls>

-For a browser based quickstart, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/). -
Table of Contents
@@ -181,24 +172,28 @@ TrustGraph provides component flexibility to optimize agent workflows.
-Multi-model storage +Graph Storage +
+ +- Apache Cassandra (default)
+- Neo4j
+- Memgraph
+- FalkorDB
+ +
+
+VectorDBs
- Apache Cassandra
-
-
-VectorDB -
- -- Qdrant
-
File and Object Storage
-- Garage
+- Garage (default)
+- MinIO
diff --git a/docs/api-gateway-changes-v1.8-to-v2.1.md b/docs/api-gateway-changes-v1.8-to-v2.1.md new file mode 100644 index 00000000..099dadb0 --- /dev/null +++ b/docs/api-gateway-changes-v1.8-to-v2.1.md @@ -0,0 +1,108 @@ +# API Gateway Changes: v1.8 to v2.1 + +## Summary + +The API gateway gained new WebSocket service dispatchers for embeddings +queries, a new REST streaming endpoint for document content, and underwent +a significant wire format change from `Value` to `Term`. The "objects" +service was renamed to "rows". + +--- + +## New WebSocket Service Dispatchers + +These are new request/response services available through the WebSocket +multiplexer at `/api/v1/socket` (flow-scoped): + +| Service Key | Description | +|-------------|-------------| +| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. | +| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. | + +These join the existing `graph-embeddings` dispatcher (which was already +present in v1.8 but may have been updated). + +### Full list of WebSocket flow service dispatchers (v2.1) + +Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or +WebSocket mux): + +- `agent`, `text-completion`, `prompt`, `mcp-tool` +- `graph-rag`, `document-rag` +- `embeddings`, `graph-embeddings`, `document-embeddings` +- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag` +- `row-embeddings` + +--- + +## New REST Endpoint + +| Method | Path | Description | +|--------|------|-------------| +| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. | + +--- + +## Renamed Service: "objects" to "rows" + +| v1.8 | v2.1 | Notes | +|------|------|-------| +| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. | +| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. | + +The WebSocket service key changed from `"objects"` to `"rows"`, and the +import dispatcher key similarly changed from `"objects"` to `"rows"`. + +--- + +## Wire Format Change: Value to Term + +The serialization layer (`serialize.py`) was rewritten to use the new `Term` +type instead of the old `Value` type. + +### Old format (v1.8 — `Value`) + +```json +{"v": "http://example.org/entity", "e": true} +``` + +- `v`: the value (string) +- `e`: boolean flag indicating whether the value is a URI + +### New format (v2.1 — `Term`) + +IRIs: +```json +{"t": "i", "i": "http://example.org/entity"} +``` + +Literals: +```json +{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"} +``` + +Quoted triples (RDF-star): +```json +{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}} +``` + +- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node) +- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives` + +### Other serialization changes + +| Field | v1.8 | v2.1 | +|-------|------|------| +| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) | +| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) | +| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) | + +--- + +## Breaking Changes + +- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format. +- **`objects` to `rows` rename**: WebSocket service key and import key changed. +- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value). +- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text. +- **New `/api/v1/document-stream` endpoint**: Additive, not breaking. diff --git a/docs/api.html b/docs/api.html index 201771ec..7cbddd32 100644 --- a/docs/api.html +++ b/docs/api.html @@ -12,413 +12,417 @@ margin: 0; } - -

TrustGraph API Gateway (1.8)

Download OpenAPI specification:

REST API for TrustGraph - an AI-powered knowledge graph and RAG system.

-

Overview

TrustGraph API Gateway (2.1)

Download OpenAPI specification:

REST API for TrustGraph - an AI-powered knowledge graph and RAG system.

+

Overview

Import/Export: Bulk data operations for triples, embeddings, entity contexts
  • WebSocket: Multiplexed interface for all services
  • -

    Service Types

    Service Types

  • AI services: agent, text-completion, prompt, RAG (document/graph)
  • Embeddings: embeddings, graph-embeddings, document-embeddings
  • -
  • Query: triples, objects, nlp-query, structured-query
  • +
  • Query: triples, rows, nlp-query, structured-query, row-embeddings
  • Data loading: text-load, document-load
  • Utilities: mcp-tool, structured-diag
  • -

    Authentication

    Authentication

    : Bearer <token>

    If GATEWAY_SECRET is not set, API runs without authentication (development mode).

    -

    Field Naming

    Field Naming

    All JSON fields use kebab-case: flow-id, blueprint-name, doc-limit, etc.

    -

    Error Responses

    Error Responses

    } } -

    Config

    Config

    Configuration management (global service)

    -

    Configuration service

    Configuration service

    Manage TrustGraph configuration including flows, prompts, token costs, parameter types, and more.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Manage TrustGraph configuration including flows, prompts, token costs, parameter types, and more.

    Operations

    config

    Get the complete system configuration including all flows, prompts, token costs, etc.

    @@ -576,7 +580,7 @@ The flow service (/api/v1/flow) manages runnin
  • Use config service to store/retrieve flow definitions
  • Use flow service to start/stop/manage running flows
  • -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "config" "list" "get" "put" "delete"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "config" "list" "get" "put" "delete"

    Operation to perform:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Operation to perform:

    • config: Get complete configuration
    • list: List all items of a specific type
    • @@ -592,25 +596,25 @@ The flow service (/api/v1/flow) manages runnin
    • put: Set/update configuration values
    • delete: Delete configuration items
    -
    type
    string
    type
    string

    Configuration type (required for list, get, put, delete operations). +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Configuration type (required for list, get, put, delete operations). Common types: flow, prompt, token-cost, parameter-type, interface-description

    -
    Array of objects

    Keys to retrieve (for get operation) or delete (for delete operation)

    -
    Array of objects

    Values to set/update (for put operation)

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "config"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "version": 42,
    • "config": {
      }
    }

    Flow

    Array of objects

    Keys to retrieve (for get operation) or delete (for delete operation)

    +
    Array of objects

    Values to set/update (for put operation)

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "config"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "version": 42,
    • "config": {
      }
    }

    Flow

    Flow lifecycle and blueprint management (global service)

    -

    Flow lifecycle and blueprint management

    Flow lifecycle and blueprint management

    Manage flow instances and blueprints.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Manage flow instances and blueprints.

    Important Distinction

    The flow service manages running flow instances. The config service (/api/v1/config) manages stored configuration.

    @@ -688,7 +692,7 @@ The config service (/api/v1/config) manages st

    delete-blueprint

    Delete a custom blueprint definition. Built-in blueprints cannot be deleted.

    -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "start-flow" "stop-flow" "list-flows" "get-flow" "list-blueprints" "get-blueprint" "put-blueprint" "delete-blueprint"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "start-flow" "stop-flow" "list-flows" "get-flow" "list-blueprints" "get-blueprint" "put-blueprint" "delete-blueprint"

    Flow operation:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Flow operation:

    • start-flow: Start a new flow instance from a blueprint
    • stop-flow: Stop a running flow instance
    • @@ -710,29 +714,29 @@ The config service (/api/v1/config) manages st
    • put-blueprint: Create/update blueprint definition
    • delete-blueprint: Delete blueprint definition
    -
    flow-id
    string

    Flow instance ID (required for start-flow, stop-flow, get-flow)

    -
    blueprint-name
    string

    Flow blueprint name (required for start-flow, get-blueprint, put-blueprint, delete-blueprint)

    -
    object

    Flow blueprint definition (required for put-blueprint)

    -
    description
    string

    Flow description (optional for start-flow)

    -
    object
    flow-id
    string

    Flow instance ID (required for start-flow, stop-flow, get-flow)

    +
    blueprint-name
    string

    Flow blueprint name (required for start-flow, get-blueprint, put-blueprint, delete-blueprint)

    +
    object

    Flow blueprint definition (required for put-blueprint)

    +
    description
    string

    Flow description (optional for start-flow)

    +
    object

    Flow parameters (for start-flow). +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Flow parameters (for start-flow). All values are stored as strings, regardless of input type.

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "start-flow",
    • "flow-id": "my-flow",
    • "blueprint-name": "document-rag",
    • "description": "My document processing flow",
    • "parameters": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "flow-id": "my-flow"
    }

    Librarian

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "start-flow",
    • "flow-id": "my-flow",
    • "blueprint-name": "document-rag",
    • "description": "My document processing flow",
    • "parameters": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "flow-id": "my-flow"
    }

    Librarian

    Document library management (global service)

    -

    Document library management

    Document library management

    Manage document library: add, remove, list documents, and control processing.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Manage document library: add, remove, list documents, and control processing.

    Document Library

    The librarian service manages a persistent library of documents that can be:

      @@ -780,7 +784,7 @@ for processing and handled asynchronously.

      Stop ongoing library document processing.

      list-processing

      List current processing tasks and their status.

      -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "add-document" "remove-document" "list-documents" "start-processing" "stop-processing" "list-processing"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "add-document" "remove-document" "list-documents" "start-processing" "stop-processing" "list-processing"

    Library operation:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Library operation:

    • add-document: Add document to library
    • remove-document: Remove document from library
    • @@ -798,35 +802,35 @@ for processing and handled asynchronously.

    • stop-processing: Stop library processing
    • list-processing: List processing status
    -
    flow
    string

    Flow ID

    -
    collection
    string
    Default: "default"

    Collection identifier

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    document-id
    string

    Document identifier

    -
    processing-id
    string

    Processing task identifier

    -
    object (DocumentMetadata)

    Document metadata for library management

    -
    object (ProcessingMetadata)

    Processing metadata for library document processing

    -
    content
    string

    Document content (for add-document with inline content)

    -
    Array of objects

    Search criteria for filtering documents

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "add-document",
    • "flow": "my-flow",
    • "collection": "default",
    • "document-metadata": {}
    }

    Response samples

    Content type
    application/json
    Example
    {}

    Knowledge

    flow
    string

    Flow ID

    +
    collection
    string
    Default: "default"

    Collection identifier

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    document-id
    string

    Document identifier

    +
    processing-id
    string

    Processing task identifier

    +
    object (DocumentMetadata)

    Document metadata for library management

    +
    object (ProcessingMetadata)

    Processing metadata for library document processing

    +
    content
    string

    Document content (for add-document with inline content)

    +
    Array of objects

    Search criteria for filtering documents

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "add-document",
    • "flow": "my-flow",
    • "collection": "default",
    • "document-metadata": {}
    }

    Response samples

    Content type
    application/json
    Example
    {}

    Knowledge

    Knowledge graph core management (global service)

    -

    Knowledge graph core management

    Knowledge graph core management

    Manage knowledge graph cores - persistent storage of triples and embeddings.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Manage knowledge graph cores - persistent storage of triples and embeddings.

    Knowledge Cores

    Knowledge cores are the foundational storage units for:

      @@ -890,7 +894,7 @@ Removes data from flow instance but doesn't delete the core.

    • Multiple messages with triples or graph-embeddings
    • Final message with eos: true to signal completion
    • -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "list-kg-cores" "get-kg-core" "put-kg-core" "delete-kg-core" "load-kg-core" "unload-kg-core"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "list-kg-cores" "get-kg-core" "put-kg-core" "delete-kg-core" "load-kg-core" "unload-kg-core"

    Knowledge core operation:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Knowledge core operation:

    • list-kg-cores: List knowledge cores for user
    • get-kg-core: Get knowledge core by ID
    • @@ -908,29 +912,29 @@ Removes data from flow instance but doesn't delete the core.

    • load-kg-core: Load knowledge core into flow
    • unload-kg-core: Unload knowledge core from flow
    -
    user
    string
    Default: "trustgraph"

    User identifier (for list-kg-cores, put-kg-core, delete-kg-core)

    -
    id
    string

    Knowledge core ID (for get, put, delete, load, unload)

    -
    flow
    string

    Flow ID (for load-kg-core)

    -
    collection
    string
    Default: "default"

    Collection identifier (for load-kg-core)

    -
    object

    Triples to store (for put-kg-core)

    -
    object

    Graph embeddings to store (for put-kg-core)

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "list-kg-cores",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "ids": [
      ]
    }

    Collection

    user
    string
    Default: "trustgraph"

    User identifier (for list-kg-cores, put-kg-core, delete-kg-core)

    +
    id
    string

    Knowledge core ID (for get, put, delete, load, unload)

    +
    flow
    string

    Flow ID (for load-kg-core)

    +
    collection
    string
    Default: "default"

    Collection identifier (for load-kg-core)

    +
    object

    Triples to store (for put-kg-core)

    +
    object

    Graph embeddings to store (for put-kg-core)

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "list-kg-cores",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "ids": [
      ]
    }

    Collection

    Collection metadata management (global service)

    -

    Collection metadata management

    Collection metadata management

    Manage collection metadata for organizing documents and knowledge.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Manage collection metadata for organizing documents and knowledge.

    Collections

    Collections are organizational units for grouping:

      @@ -982,45 +986,45 @@ If it exists, metadata is updated. Allows setting name, description, and tags.delete-collection

      Delete a collection by user and collection ID. This removes the metadata but typically does not delete the associated data (documents, knowledge cores).

      -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "list-collections" "update-collection" "delete-collection"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "list-collections" "update-collection" "delete-collection"

    Collection operation:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Collection operation:

    • list-collections: List collections for user
    • update-collection: Create or update collection metadata
    • delete-collection: Delete collection
    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string

    Collection identifier (for update, delete)

    -
    timestamp
    string <date-time>

    ISO timestamp

    -
    name
    string

    Human-readable collection name (for update)

    -
    description
    string

    Collection description (for update)

    -
    tags
    Array of strings

    Collection tags for organization (for update)

    -
    tag-filter
    Array of strings

    Filter collections by tags (for list)

    -
    limit
    integer
    Default: 0

    Maximum number of results (for list)

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "list-collections",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "timestamp": "2024-01-15T10:30:00Z",
    • "collections": [
      ]
    }

    Flow Services

    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string

    Collection identifier (for update, delete)

    +
    timestamp
    string <date-time>

    ISO timestamp

    +
    name
    string

    Human-readable collection name (for update)

    +
    description
    string

    Collection description (for update)

    +
    tags
    Array of strings

    Collection tags for organization (for update)

    +
    tag-filter
    Array of strings

    Filter collections by tags (for list)

    +
    limit
    integer
    Default: 0

    Maximum number of results (for list)

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "list-collections",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "timestamp": "2024-01-15T10:30:00Z",
    • "collections": [
      ]
    }

    Flow Services

    Services hosted within flow instances

    -

    Agent service - conversational AI with reasoning

    Agent service - conversational AI with reasoning

    AI agent that can understand questions, reason about them, and take actions.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    AI agent that can understand questions, reason about them, and take actions.

    Agent Overview

    The agent service provides a conversational AI that:

      @@ -1104,29 +1108,29 @@ Each step has: thought, action, arguments, observation.</p>

      Multi-turn Conversations

      Include history array with previous steps to maintain context. Each step has: thought, action, arguments, observation.

      -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    question
    required
    string

    User question or prompt for the agent

    -
    state
    string

    Agent state for continuation (optional, for multi-turn)

    -
    group
    Array of strings

    Group identifiers for collaborative agents (optional)

    -
    Array of objects

    Conversation history (optional, list of previous agent steps)

    -
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    -
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "What is the capital of France?",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "chunk-type": "thought",
    • "content": "I need to search for information about quantum computing",
    • "end-of-message": false,
    • "end-of-dialog": false
    }

    Document RAG - retrieve and generate from documents

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    question
    required
    string

    User question or prompt for the agent

    +
    state
    string

    Agent state for continuation (optional, for multi-turn)

    +
    group
    Array of strings

    Group identifiers for collaborative agents (optional)

    +
    Array of objects

    Conversation history (optional, list of previous agent steps)

    +
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "What is the capital of France?",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "chunk-type": "thought",
    • "content": "I need to search for information about quantum computing",
    • "end-of-message": false,
    • "end-of-dialog": false
    }

    Document RAG - retrieve and generate from documents

    Retrieval-Augmented Generation over document embeddings.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Retrieval-Augmented Generation over document embeddings.

    Document RAG Overview

    Document RAG combines:

      @@ -1192,27 +1196,27 @@ Each step has: thought, action, arguments, observation.

    1. collection: Target specific document collection
    2. user: Multi-tenant isolation
    3. -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    query
    required
    string

    User query or question

    -
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    -
    collection
    string
    Default: "default"

    Collection to search within

    -
    doc-limit
    integer [ 1 .. 100 ]
    Default: 20

    Maximum number of documents to retrieve

    -
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What are the key findings in the research papers?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "The research papers present three key findings:\n1. Quantum entanglement exhibits non-local correlations\n2. Bell's inequality is violated in experimental tests\n3. Applications in quantum cryptography are promising\n",
    • "end-of-stream": false
    }

    Graph RAG - retrieve and generate from knowledge graph

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    query
    required
    string

    User query or question

    +
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    +
    collection
    string
    Default: "default"

    Collection to search within

    +
    doc-limit
    integer [ 1 .. 100 ]
    Default: 20

    Maximum number of documents to retrieve

    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What are the key findings in the research papers?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "The research papers present three key findings:\n1. Quantum entanglement exhibits non-local correlations\n2. Bell's inequality is violated in experimental tests\n3. Applications in quantum cryptography are promising\n",
    • "end-of-stream": false
    }

    Graph RAG - retrieve and generate from knowledge graph

    Retrieval-Augmented Generation over knowledge graph.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Retrieval-Augmented Generation over knowledge graph.

    Graph RAG Overview

    Graph RAG combines:

      @@ -1302,33 +1306,33 @@ Each step has: thought, action, arguments, observation.

    1. Multi-hop reasoning ("What's the path from A to B?")
    2. Structural analysis ("What are the main entities related to X?")
    3. -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    query
    required
    string

    User query or question

    -
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    -
    collection
    string
    Default: "default"

    Collection to search within

    -
    entity-limit
    integer [ 1 .. 200 ]
    Default: 50

    Maximum number of entities to retrieve

    -
    triple-limit
    integer [ 1 .. 100 ]
    Default: 30

    Maximum number of triples to retrieve per entity

    -
    max-subgraph-size
    integer [ 10 .. 5000 ]
    Default: 1000

    Maximum total subgraph size (triples)

    -
    max-path-length
    integer [ 1 .. 5 ]
    Default: 2

    Maximum path length for graph traversal

    -
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What connections exist between quantum physics and computer science?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "Quantum physics and computer science intersect primarily through quantum computing.\nThe knowledge graph shows connections through:\n- Quantum algorithms (Shor's algorithm, Grover's algorithm)\n- Quantum information theory\n- Computational complexity theory\n",
    • "end-of-stream": false
    }

    Text completion - direct LLM generation

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    query
    required
    string

    User query or question

    +
    user
    string
    Default: "trustgraph"

    User identifier for multi-tenancy

    +
    collection
    string
    Default: "default"

    Collection to search within

    +
    entity-limit
    integer [ 1 .. 200 ]
    Default: 50

    Maximum number of entities to retrieve

    +
    triple-limit
    integer [ 1 .. 100 ]
    Default: 30

    Maximum number of triples to retrieve per entity

    +
    max-subgraph-size
    integer [ 10 .. 5000 ]
    Default: 1000

    Maximum total subgraph size (triples)

    +
    max-path-length
    integer [ 1 .. 5 ]
    Default: 2

    Maximum path length for graph traversal

    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What connections exist between quantum physics and computer science?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "Quantum physics and computer science intersect primarily through quantum computing.\nThe knowledge graph shows connections through:\n- Quantum algorithms (Shor's algorithm, Grover's algorithm)\n- Quantum information theory\n- Computational complexity theory\n",
    • "end-of-stream": false
    }

    Text completion - direct LLM generation

    Direct text completion using LLM without retrieval augmentation.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Direct text completion using LLM without retrieval augmentation.

    Text Completion Overview

    Pure LLM generation for:

      @@ -1422,23 +1426,23 @@ Each step has: thought, action, arguments, observation.

    • Want to leverage knowledge graph relationships
    • Require citations or provenance
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    system
    required
    string

    System prompt that sets behavior and context for the LLM

    -
    prompt
    required
    string

    User prompt or question

    -
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "system": "You are a helpful assistant that provides concise answers.",
    • "prompt": "Explain the concept of recursion in programming."
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "Recursion is a programming technique where a function calls itself\nto solve a problem by breaking it down into smaller, similar subproblems.\nEach recursive call works on a simpler version until reaching a base case.\n",
    • "in-token": 45,
    • "out-token": 128,
    • "model": "gpt-4",
    • "end-of-stream": false
    }

    Prompt service - template-based generation

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    system
    required
    string

    System prompt that sets behavior and context for the LLM

    +
    prompt
    required
    string

    User prompt or question

    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "system": "You are a helpful assistant that provides concise answers.",
    • "prompt": "Explain the concept of recursion in programming."
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "Recursion is a programming technique where a function calls itself\nto solve a problem by breaking it down into smaller, similar subproblems.\nEach recursive call works on a simpler version until reaching a base case.\n",
    • "in-token": 45,
    • "out-token": 128,
    • "model": "gpt-4",
    • "end-of-stream": false
    }

    Prompt service - template-based generation

    Execute stored prompt templates with variable substitution.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Execute stored prompt templates with variable substitution.

    Prompt Service Overview

    The prompt service enables:

      @@ -1556,25 +1560,25 @@ Each step has: thought, action, arguments, observation.

    • Data transformation
    • Any repeatable LLM task with consistent prompting
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    id
    required
    string

    Prompt template ID (stored in config)

    -
    object

    Template variables as key-value pairs (values are JSON strings)

    -
    object

    Alternative to terms - variables as native JSON values (auto-converted)

    -
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "id": "summarize-document",
    • "terms": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "text": "This document provides an overview of quantum computing fundamentals and cryptographic applications.",
    • "end-of-stream": false
    }

    Embeddings - text to vector conversion

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    id
    required
    string

    Prompt template ID (stored in config)

    +
    object

    Template variables as key-value pairs (values are JSON strings)

    +
    object

    Alternative to terms - variables as native JSON values (auto-converted)

    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "id": "summarize-document",
    • "terms": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "text": "This document provides an overview of quantum computing fundamentals and cryptographic applications.",
    • "end-of-stream": false
    }

    Embeddings - text to vector conversion

    Convert text to embedding vectors for semantic similarity search.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Convert text to embedding vectors for semantic similarity search.

    Embeddings Overview

    Embeddings transform text into dense vector representations that:

      @@ -1630,19 +1634,19 @@ For bulk operations, use document-load or text-load services.</p>

      Single Request

      Unlike batch embedding APIs, this endpoint processes one text at a time. For bulk operations, use document-load or text-load services.

      -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    text
    required
    string

    Text to convert to embedding vector

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "Machine learning"
    }

    Response samples

    Content type
    application/json
    {
    • "vectors": [
      ]
    }

    MCP Tool - execute Model Context Protocol tools

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    text
    required
    string

    Text to convert to embedding vector

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "Machine learning"
    }

    Response samples

    Content type
    application/json
    {
    • "vectors": [
      ]
    }

    MCP Tool - execute Model Context Protocol tools

    Execute MCP (Model Context Protocol) tools for agent capabilities.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Execute MCP (Model Context Protocol) tools for agent capabilities.

    MCP Tool Overview

    MCP tools provide agent capabilities through standardized protocol:

      @@ -1734,21 +1738,21 @@ For bulk operations, use document-load or text-load services.

    • File operations: Read/write files
    • Code execution: Run scripts
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    name
    required
    string

    Tool name to execute

    -
    object

    Tool parameters (JSON object, auto-converted to string internally)

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "name": "search",
    • "parameters": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "text": "The result is 309"
    }

    Triples query - pattern-based graph queries

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    name
    required
    string

    Tool name to execute

    +
    object

    Tool parameters (JSON object, auto-converted to string internally)

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "name": "search",
    • "parameters": {
      }
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "text": "The result is 309"
    }

    Triples query - pattern-based graph queries

    Query knowledge graph using subject-predicate-object patterns.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Query knowledge graph using subject-predicate-object patterns.

    Triples Query Overview

    Query RDF triples with flexible pattern matching:

      @@ -1840,38 +1844,54 @@ For bulk operations, use document-load or text-load services.

    • More specific patterns = faster queries
    • Consider limit for large result sets
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    object (RdfValue)

    Subject filter (optional)

    -
    object (RdfValue)

    Predicate filter (optional)

    -
    object (RdfValue)

    Object filter (optional)

    -
    limit
    integer [ 1 .. 100000 ]
    Default: 10000

    Maximum number of triples to return

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection to query

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {}

    Response samples

    Content type
    application/json
    {}

    Objects query - GraphQL over knowledge graph

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    object (RdfValue)

    Subject filter (optional)

    +
    object (RdfValue)

    Predicate filter (optional)

    +
    object (RdfValue)

    Object filter (optional)

    +
    limit
    integer [ 1 .. 100000 ]
    Default: 10000

    Maximum number of triples to return

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to query

    +
    g
    string

    Named graph filter (optional).

    +
      +
    • Omitted/null: all graphs
    • +
    • Empty string: default graph only
    • +
    • URI string: specific named graph (e.g., urn:graph:source, urn:graph:retrieval)
    • +
    +
    streaming
    boolean
    Default: false

    Enable streaming response delivery

    +
    batch-size
    integer [ 1 .. 1000 ]
    Default: 20

    Number of triples per streaming batch

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {}

    Response samples

    Content type
    application/json
    {}

    Rows query - GraphQL over structured data

    Query knowledge graph using GraphQL for object-oriented data access.

    -

    Objects Query Overview

    -

    GraphQL interface to knowledge graph:

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Query structured data using GraphQL for row-oriented data access.

    +

    Rows Query Overview

    +

    GraphQL interface to structured data:

    • Schema-driven: Predefined types and relationships
    • Flexible queries: Request exactly what you need
    • Nested data: Traverse relationships in single query
    • Type-safe: Strong typing with introspection
    -

    Abstracts RDF triples into familiar object model.

    +

    Abstracts structured rows into familiar object model.

    GraphQL Benefits

    Compared to triples query:

      @@ -1956,27 +1976,27 @@ Use introspection query to discover schema.</p>

      Schema Definition

      Schema defines available types via config service. Use introspection query to discover schema.

      -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    query
    required
    string

    GraphQL query string

    -
    object

    GraphQL query variables

    -
    operation-name
    string

    Operation name (for multi-operation documents)

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection to query

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "{\n person(id: \"https://example.com/person/alice\") {\n name\n email\n }\n}\n",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "data": {
      },
    • "extensions": {
      }
    }

    NLP Query - natural language to structured query

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    query
    required
    string

    GraphQL query string

    +
    object

    GraphQL query variables

    +
    operation-name
    string

    Operation name (for multi-operation documents)

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to query

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "{\n person(id: \"https://example.com/person/alice\") {\n name\n email\n }\n}\n",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "data": {
      },
    • "extensions": {
      }
    }

    NLP Query - natural language to structured query

    Convert natural language questions to structured GraphQL queries.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Convert natural language questions to structured GraphQL queries.

    NLP Query Overview

    Transforms user questions into executable GraphQL:

      @@ -2056,7 +2076,7 @@ Use introspection query to discover schema.

      Example workflow:

      1. User asks: "Who does Alice know?"
       2. NLP Query generates GraphQL
      -3. Execute via /api/v1/flow/{flow}/service/objects
      +3. Execute via /api/v1/flow/{flow}/service/rows
       4. Return results to user
       

      Schema Detection

      @@ -2080,26 +2100,26 @@ Use introspection query to discover schema.

    • Missing schema coverage
    • Complex query structure
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    question
    required
    string

    Natural language question

    -
    max-results
    integer [ 1 .. 10000 ]
    Default: 100

    Maximum results to return when query is executed

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "Who does Alice know?",
    • "max-results": 50
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "graphql-query": "query GetConnections($person: ID!) {\n person(id: $person) {\n knows { name email }\n }\n}\n",
    • "variables": {},
    • "detected-schemas": [
      ],
    • "confidence": 0.92
    }

    Structured Query - question to results (all-in-one)

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    question
    required
    string

    Natural language question

    +
    max-results
    integer [ 1 .. 10000 ]
    Default: 100

    Maximum results to return when query is executed

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "Who does Alice know?",
    • "max-results": 50
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "graphql-query": "query GetConnections($person: ID!) {\n person(id: $person) {\n knows { name email }\n }\n}\n",
    • "variables": {},
    • "detected-schemas": [
      ],
    • "confidence": 0.92
    }

    Structured Query - question to results (all-in-one)

    Ask natural language questions and get results directly.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Ask natural language questions and get results directly.

    Structured Query Overview

    Combines two operations in one call:

    1. NLP Query: Generate GraphQL from question
    2. -
    3. Objects Query: Execute generated query
    4. +
    5. Rows Query: Execute generated query
    6. Return Results: Direct answer data

    Simplest way to query knowledge graph with natural language.

    @@ -2170,7 +2190,7 @@ Use introspection query to discover schema.

  • Output: Query results (data)
  • Use when: Want simple, direct answers
  • -

    NLP Query + Objects Query (separate calls)

    +

    NLP Query + Rows Query (separate calls)

    • Step 1: Convert question → GraphQL
    • Step 2: Execute GraphQL → results
    • @@ -2214,23 +2234,23 @@ Use introspection query to discover schema.

    • Less control: Can't inspect/modify generated query
    • Simpler code: No need to handle intermediate steps
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    question
    required
    string

    Natural language question

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection to query

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "Who does Alice know?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "data": {
      },
    • "errors": [ ]
    }

    Structured Diag - analyze structured data formats

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    question
    required
    string

    Natural language question

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to query

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "Who does Alice know?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "data": {
      },
    • "errors": [ ]
    }

    Structured Diag - analyze structured data formats

    Analyze and understand structured data (CSV, JSON, XML).

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Analyze and understand structured data (CSV, JSON, XML).

    Structured Diag Overview

    Helps process unknown structured data:

      @@ -2352,39 +2372,39 @@ Use introspection query to discover schema.

    • Use descriptor to process full dataset
    • Load data via document-load or text-load
    • -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "detect-type" "generate-descriptor" "diagnose" "schema-selection"
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "detect-type" "generate-descriptor" "diagnose" "schema-selection"

    Diagnosis operation:

    +" class="sc-iKGpAq sc-cCYyou dXXcln cFvDiF">

    Diagnosis operation:

    • detect-type: Identify data format (CSV, JSON, XML)
    • generate-descriptor: Create schema descriptor for data
    • diagnose: Full analysis (detect + generate descriptor)
    • schema-selection: Find matching schemas for data
    -
    sample
    required
    string

    Data sample to analyze (text content)

    -
    type
    string
    Enum: "csv" "json" "xml"

    Data type (required for generate-descriptor)

    -
    schema-name
    string

    Target schema name for descriptor generation (optional)

    -
    object

    Format-specific options (e.g., CSV delimiter)

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "detect-type",
    • "sample": "name,age,email\nAlice,30,alice@example.com\nBob,25,bob@example.com\n"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "operation": "detect-type",
    • "detected-type": "csv",
    • "confidence": 0.95
    }

    Graph Embeddings Query - find similar entities

    sample
    required
    string

    Data sample to analyze (text content)

    +
    type
    string
    Enum: "csv" "json" "xml"

    Data type (required for generate-descriptor)

    +
    schema-name
    string

    Target schema name for descriptor generation (optional)

    +
    object

    Format-specific options (e.g., CSV delimiter)

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "operation": "detect-type",
    • "sample": "name,age,email\nAlice,30,alice@example.com\nBob,25,bob@example.com\n"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "operation": "detect-type",
    • "detected-type": "csv",
    • "confidence": 0.95
    }

    Graph Embeddings Query - find similar entities

    Query graph embeddings to find similar entities by vector similarity.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Query graph embeddings to find similar entities by vector similarity.

    Graph Embeddings Query Overview

    Find entities semantically similar to a query vector:

      @@ -2460,25 +2480,25 @@ Use introspection query to discover schema.

    • These are references to knowledge graph entities
    • Use with triples query to get entity details
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    vectors
    required
    Array of numbers

    Query embedding vector

    -
    limit
    integer [ 1 .. 1000 ]
    Default: 10

    Maximum number of entities to return

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection to search

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "vectors": [
      ],
    • "limit": 10,
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json

    Document Embeddings Query - find similar text chunks

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    vectors
    required
    Array of numbers

    Query embedding vector

    +
    limit
    integer [ 1 .. 1000 ]
    Default: 10

    Maximum number of entities to return

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to search

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "vectors": [
      ],
    • "limit": 10,
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json

    Document Embeddings Query - find similar text chunks

    Query document embeddings to find similar text chunks by vector similarity.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Query document embeddings to find similar text chunks by vector similarity.

    Document Embeddings Query Overview

    Find document chunks semantically similar to a query vector:

      @@ -2570,25 +2590,113 @@ Use introspection query to discover schema.

    • No metadata (source, position, etc.)
    • Use for LLM context directly
    -
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    vectors
    required
    Array of numbers

    Query embedding vector

    -
    limit
    integer [ 1 .. 1000 ]
    Default: 10

    Maximum number of document chunks to return

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection to search

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "vectors": [
      ],
    • "limit": 10,
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    {
    • "chunks": [
      ]
    }

    Text Load - load text documents

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    vectors
    required
    Array of numbers

    Query embedding vector

    +
    limit
    integer [ 1 .. 1000 ]
    Default: 10

    Maximum number of document chunks to return

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to search

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "vectors": [
      ],
    • "limit": 10,
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    {
    • "chunks": [
      ]
    }

    Row Embeddings Query - semantic search on structured data

    Query row embeddings to find similar rows by vector similarity on indexed fields. +Enables fuzzy/semantic matching on structured data.

    +

    Row Embeddings Query Overview

    +

    Find rows whose indexed field values are semantically similar to a query:

    +
      +
    • Input: Query embedding vector, schema name, optional index filter
    • +
    • Search: Compare against stored row index embeddings
    • +
    • Output: Matching rows with index values and similarity scores
    • +
    +

    Core component of semantic search on structured data.

    +

    Use Cases

    +
      +
    • Fuzzy name matching: Find customers by approximate name
    • +
    • Semantic field search: Find products by description similarity
    • +
    • Data deduplication: Identify potential duplicate records
    • +
    • Entity resolution: Match records across datasets
    • +
    +

    Process

    +
      +
    1. Obtain query embedding (via embeddings service)
    2. +
    3. Query stored row index embeddings for the specified schema
    4. +
    5. Calculate cosine similarity
    6. +
    7. Return top N most similar index entries
    8. +
    9. Use index values to retrieve full rows via GraphQL
    10. +
    +

    Response Format

    +

    Each match includes:

    +
      +
    • index_name: The indexed field(s) that matched
    • +
    • index_value: The actual values for those fields
    • +
    • text: The text that was embedded
    • +
    • score: Similarity score (higher = more similar)
    • +
    +
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    vectors
    required
    Array of numbers

    Query embedding vector

    +
    schema_name
    required
    string

    Schema name to search within

    +
    index_name
    string

    Optional index name to filter search to specific index

    +
    limit
    integer [ 1 .. 1000 ]
    Default: 10

    Maximum number of matches to return

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection to search

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "vectors": [
      ],
    • "schema_name": "customers",
    • "limit": 10,
    • "user": "alice",
    • "collection": "sales"
    }

    Response samples

    Content type
    application/json
    {
    • "matches": [
      ]
    }

    Text Load - load text documents

    Load text documents into processing pipeline for indexing and embedding.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Load text documents into processing pipeline for indexing and embedding.

    Text Load Overview

    Fire-and-forget document loading:

      @@ -2682,29 +2790,29 @@ encoded = base64
      Authorizations:
      bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    text
    required
    string <byte>

    Text content (base64 encoded)

    -
    id
    string

    Document identifier

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection for document

    -
    charset
    string
    Default: "utf-8"

    Text character encoding

    -
    Array of objects (Triple)

    Document metadata as RDF triples

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==",
    • "id": "doc-123",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Document Load - load binary documents (PDF, etc.)

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    text
    required
    string <byte>

    Text content (base64 encoded)

    +
    id
    string

    Document identifier

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection for document

    +
    charset
    string
    Default: "utf-8"

    Text character encoding

    +
    Array of objects (Triple)

    Document metadata as RDF triples

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==",
    • "id": "doc-123",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Document Load - load binary documents (PDF, etc.)

    Load binary documents (PDF, Word, etc.) into processing pipeline.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Load binary documents (PDF, Word, etc.) into processing pipeline.

    Document Load Overview

    Fire-and-forget binary document loading:

      @@ -2814,29 +2922,61 @@ encoded = base64
      Authorizations:
      bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    data
    required
    string <byte>

    Document data (base64 encoded)

    -
    id
    string

    Document identifier

    -
    user
    string
    Default: "trustgraph"

    User identifier

    -
    collection
    string
    Default: "default"

    Collection for document

    -
    Array of objects (Triple)

    Document metadata as RDF triples

    -

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "data": "JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg==",
    • "id": "doc-789",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Import/Export

    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    data
    required
    string <byte>

    Document data (base64 encoded)

    +
    id
    string

    Document identifier

    +
    user
    string
    Default: "trustgraph"

    User identifier

    +
    collection
    string
    Default: "default"

    Collection for document

    +
    Array of objects (Triple)

    Document metadata as RDF triples

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "data": "JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg==",
    • "id": "doc-789",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Import/Export

    Bulk data import and export

    -

    Import Core - bulk import triples and embeddings

    Stream document content from library

    Streams the raw content of a document stored in the library. +Returns the document content in chunked transfer encoding.

    +

    Parameters

    +
      +
    • user: User identifier (required)
    • +
    • document-id: Document IRI to retrieve (required)
    • +
    • chunk-size: Size of each response chunk in bytes (optional, default: 1MB)
    • +
    +
    Authorizations:
    bearerAuth
    query Parameters
    user
    required
    string
    Example: user=trustgraph

    User identifier

    +
    document-id
    required
    string
    Example: document-id=urn:trustgraph:doc:abc123

    Document IRI to retrieve

    +
    chunk-size
    integer
    Default: 1048576

    Chunk size in bytes (default 1MB)

    +

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    Import Core - bulk import triples and embeddings

    Import knowledge cores in bulk using streaming MessagePack format.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Import knowledge cores in bulk using streaming MessagePack format.

    Import Core Overview

    Bulk data import for knowledge graph:

      @@ -2944,21 +3084,21 @@ No response body - returns 202 Accepted.

    • Bulk loading: Initial knowledge base population
    • Replication: Copy knowledge cores
    -
    Authorizations:
    bearerAuth
    query Parameters
    id
    required
    string
    Example: id=core-123

    Knowledge core ID to import

    -
    user
    required
    string
    Example: user=alice

    User identifier

    -
    Request Body schema: application/msgpack
    required
    string <binary>

    MessagePack stream of knowledge data

    -

    Responses

    Response samples

    Content type
    application/json
    { }

    Export Core - bulk export triples and embeddings

    Authorizations:
    bearerAuth
    query Parameters
    id
    required
    string
    Example: id=core-123

    Knowledge core ID to import

    +
    user
    required
    string
    Example: user=alice

    User identifier

    +
    Request Body schema: application/msgpack
    required
    string <binary>

    MessagePack stream of knowledge data

    +

    Responses

    Response samples

    Content type
    application/json
    { }

    Export Core - bulk export triples and embeddings

    Export knowledge cores in bulk using streaming MessagePack format.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Export knowledge cores in bulk using streaming MessagePack format.

    Export Core Overview

    Bulk data export for knowledge graph:

      @@ -3082,21 +3222,21 @@ No response body - returns 202 Accepted.

    • Replication: Copy knowledge cores
    • Analysis: External processing
    -
    Authorizations:
    bearerAuth
    query Parameters
    id
    required
    string
    Example: id=core-123

    Knowledge core ID to export

    -
    user
    required
    string
    Example: user=alice

    User identifier

    -

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    WebSocket

    Authorizations:
    bearerAuth
    query Parameters
    id
    required
    string
    Example: id=core-123

    Knowledge core ID to export

    +
    user
    required
    string
    Example: user=alice

    User identifier

    +

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    WebSocket

    WebSocket interfaces

    -

    WebSocket - multiplexed service interface

    WebSocket - multiplexed service interface

    WebSocket interface providing multiplexed access to all TrustGraph services over a single persistent connection.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    WebSocket interface providing multiplexed access to all TrustGraph services over a single persistent connection.

    Overview

    The WebSocket API provides access to the same services as the REST API but with:

      @@ -3382,21 +3522,21 @@ See individual service documentation for detailed request/response formats.

    • True streaming: Bidirectional real-time communication
    • Efficient multiplexing: Concurrent operations without connection pooling
    -
    Authorizations:
    bearerAuth
    header Parameters
    Upgrade
    required
    string
    Value: "websocket"

    WebSocket upgrade header

    -
    Connection
    required
    string
    Value: "Upgrade"

    Connection upgrade header

    -

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    Metrics

    Authorizations:
    bearerAuth
    header Parameters
    Upgrade
    required
    string
    Value: "websocket"

    WebSocket upgrade header

    +
    Connection
    required
    string
    Value: "Upgrade"

    Connection upgrade header

    +

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    Metrics

    System metrics and monitoring

    -

    Metrics - Prometheus metrics endpoint

    Metrics - Prometheus metrics endpoint

    Proxy to Prometheus metrics for system monitoring.

    +" class="sc-iKGpAq sc-cCYyou dXXcln dHaogz">

    Proxy to Prometheus metrics for system monitoring.

    Metrics Overview

    Exposes system metrics via Prometheus format:

      @@ -3464,29 +3604,29 @@ metric_name{labelPath Parameter

      The {path} parameter allows querying specific Prometheus endpoints or metrics if the backend Prometheus supports it.

      -
    Authorizations:
    bearerAuth

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    Metrics - Prometheus metrics with path

    Proxy to Prometheus metrics with optional path parameter.

    -
    Authorizations:
    bearerAuth
    path Parameters
    path
    required
    string
    Example: query

    Path to specific metrics endpoint

    -

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }
    +
    Authorizations:
    bearerAuth

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }

    Metrics - Prometheus metrics with path

    Proxy to Prometheus metrics with optional path parameter.

    +
    Authorizations:
    bearerAuth
    path Parameters
    path
    required
    string
    Example: query

    Path to specific metrics endpoint

    +

    Responses

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }
    + const config = {"show":{"sidebar":true},"sidebar":{"showOperations":"byDefault"}}; + const appRoot = document.getElementById('root'); + AsyncApiStandalone.render( + { schema, config, }, appRoot + ); + \ No newline at end of file diff --git a/specs/api/components/common/RdfValue.yaml b/specs/api/components/common/RdfValue.yaml index 5ed7c992..033e6789 100644 --- a/specs/api/components/common/RdfValue.yaml +++ b/specs/api/components/common/RdfValue.yaml @@ -1,14 +1,60 @@ type: object -description: RDF value - can be entity/URI or literal -required: - - v - - e +description: | + RDF Term - typed representation of a value in the knowledge graph. + + Term types (discriminated by `t` field): + - `i`: IRI (URI reference) + - `l`: Literal (string value, optionally with datatype or language tag) + - `r`: Quoted triple (RDF-star reification) + - `b`: Blank node properties: + t: + type: string + description: Term type discriminator + enum: [i, l, r, b] + example: i + i: + type: string + description: IRI value (when t=i) + example: http://example.com/Person1 v: type: string - description: Value (URI or literal text) - example: https://example.com/entity1 - e: - type: boolean - description: True if entity/URI, false if literal - example: true + description: Literal value (when t=l) + example: John Doe + d: + type: string + description: Datatype IRI for literal (when t=l, optional) + example: http://www.w3.org/2001/XMLSchema#integer + l: + type: string + description: Language tag for literal (when t=l, optional) + example: en + r: + type: object + description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure + properties: + s: + type: object + description: Subject term + p: + type: object + description: Predicate term + o: + type: object + description: Object term +required: + - t +examples: + - description: IRI term + value: + t: i + i: http://schema.org/name + - description: Literal term + value: + t: l + v: John Doe + - description: Literal with language tag + value: + t: l + v: Bonjour + l: fr diff --git a/specs/api/components/common/Triple.yaml b/specs/api/components/common/Triple.yaml index 142be0e9..5e15436b 100644 --- a/specs/api/components/common/Triple.yaml +++ b/specs/api/components/common/Triple.yaml @@ -1,5 +1,6 @@ type: object -description: RDF triple (subject-predicate-object) +description: | + RDF triple (subject-predicate-object), optionally scoped to a named graph. required: - s - p @@ -14,3 +15,7 @@ properties: o: $ref: './RdfValue.yaml' description: Object + g: + type: string + description: Named graph URI (optional) + example: urn:graph:source diff --git a/specs/api/components/schemas/agent/AgentResponse.yaml b/specs/api/components/schemas/agent/AgentResponse.yaml index 86d636b5..6fe6abbd 100644 --- a/specs/api/components/schemas/agent/AgentResponse.yaml +++ b/specs/api/components/schemas/agent/AgentResponse.yaml @@ -9,12 +9,26 @@ properties: - action - observation - answer + - final-answer - error example: answer content: type: string description: Chunk content (streaming mode only) example: Paris is the capital of France. + message_type: + type: string + description: Message type - "chunk" for agent chunks, "explain" for explainability events + enum: [chunk, explain] + example: chunk + explain_id: + type: string + description: Explainability node URI (for explain messages) + example: urn:trustgraph:agent:abc123 + explain_graph: + type: string + description: Named graph containing the explainability data + example: urn:graph:retrieval end-of-message: type: boolean description: Current chunk type is complete (streaming mode) diff --git a/specs/api/components/schemas/common/RdfValue.yaml b/specs/api/components/schemas/common/RdfValue.yaml index ce8b4c08..033e6789 100644 --- a/specs/api/components/schemas/common/RdfValue.yaml +++ b/specs/api/components/schemas/common/RdfValue.yaml @@ -1,21 +1,60 @@ type: object description: | - RDF value - represents either a URI/entity or a literal value. + RDF Term - typed representation of a value in the knowledge graph. - When `e` is true, `v` must be a full URI (e.g., http://schema.org/name). - When `e` is false, `v` is a literal value (string, number, etc.). + Term types (discriminated by `t` field): + - `i`: IRI (URI reference) + - `l`: Literal (string value, optionally with datatype or language tag) + - `r`: Quoted triple (RDF-star reification) + - `b`: Blank node properties: + t: + type: string + description: Term type discriminator + enum: [i, l, r, b] + example: i + i: + type: string + description: IRI value (when t=i) + example: http://example.com/Person1 v: type: string - description: The value - full URI when e=true, literal when e=false - example: http://example.com/Person1 - e: - type: boolean - description: True if entity/URI, false if literal value - example: true + description: Literal value (when t=l) + example: John Doe + d: + type: string + description: Datatype IRI for literal (when t=l, optional) + example: http://www.w3.org/2001/XMLSchema#integer + l: + type: string + description: Language tag for literal (when t=l, optional) + example: en + r: + type: object + description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure + properties: + s: + type: object + description: Subject term + p: + type: object + description: Predicate term + o: + type: object + description: Object term required: - - v - - e -example: - v: http://schema.org/name - e: true + - t +examples: + - description: IRI term + value: + t: i + i: http://schema.org/name + - description: Literal term + value: + t: l + v: John Doe + - description: Literal with language tag + value: + t: l + v: Bonjour + l: fr diff --git a/specs/api/components/schemas/common/Triple.yaml b/specs/api/components/schemas/common/Triple.yaml index 1f72b89a..0c36a91c 100644 --- a/specs/api/components/schemas/common/Triple.yaml +++ b/specs/api/components/schemas/common/Triple.yaml @@ -1,6 +1,7 @@ type: object description: | - RDF triple representing a subject-predicate-object statement in the knowledge graph. + RDF triple representing a subject-predicate-object statement in the knowledge graph, + optionally scoped to a named graph. Example: (Person1) -[has name]-> ("John Doe") properties: @@ -13,17 +14,26 @@ properties: o: $ref: './RdfValue.yaml' description: Object - the value or target entity + g: + type: string + description: | + Named graph URI (optional). When absent, the triple is in the default graph. + Well-known graphs: + - (empty/absent): Core knowledge facts + - urn:graph:source: Extraction provenance + - urn:graph:retrieval: Query-time explainability + example: urn:graph:source required: - s - p - o example: s: - v: http://example.com/Person1 - e: true + t: i + i: http://example.com/Person1 p: - v: http://schema.org/name - e: true + t: i + i: http://schema.org/name o: + t: l v: John Doe - e: false diff --git a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryResponse.yaml b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryResponse.yaml index 6b1d811d..bb72792e 100644 --- a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryResponse.yaml +++ b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryResponse.yaml @@ -1,12 +1,22 @@ type: object -description: Document embeddings query response +description: Document embeddings query response with matching chunks and similarity scores properties: chunks: type: array - description: Similar document chunks (text strings) + description: Matching document chunks with similarity scores items: - type: string + type: object + properties: + chunk_id: + type: string + description: Chunk identifier URI + example: "urn:trustgraph:chunk:abc123" + score: + type: number + description: Similarity score (higher is more similar) + example: 0.89 example: - - "Quantum computing uses quantum mechanics principles for computation..." - - "Neural networks are computing systems inspired by biological neurons..." - - "Machine learning algorithms learn patterns from data..." + - chunk_id: "urn:trustgraph:chunk:abc123" + score: 0.95 + - chunk_id: "urn:trustgraph:chunk:def456" + score: 0.82 diff --git a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryResponse.yaml b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryResponse.yaml index 80692a12..94bce275 100644 --- a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryResponse.yaml +++ b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryResponse.yaml @@ -1,12 +1,21 @@ type: object -description: Graph embeddings query response +description: Graph embeddings query response with matching entities and similarity scores properties: entities: type: array - description: Similar entities (RDF values) + description: Matching graph entities with similarity scores items: - $ref: '../../common/RdfValue.yaml' + type: object + properties: + entity: + $ref: '../../common/RdfValue.yaml' + description: Matching graph entity + score: + type: number + description: Similarity score (higher is more similar) + example: 0.92 example: - - {v: "https://example.com/person/alice", e: true} - - {v: "https://example.com/person/bob", e: true} - - {v: "https://example.com/concept/quantum", e: true} + - entity: {t: i, i: "https://example.com/person/alice"} + score: 0.95 + - entity: {t: i, i: "https://example.com/concept/quantum"} + score: 0.82 diff --git a/specs/api/components/schemas/query/TriplesQueryRequest.yaml b/specs/api/components/schemas/query/TriplesQueryRequest.yaml index 88b0a1eb..d49e0300 100644 --- a/specs/api/components/schemas/query/TriplesQueryRequest.yaml +++ b/specs/api/components/schemas/query/TriplesQueryRequest.yaml @@ -28,3 +28,23 @@ properties: description: Collection to query default: default example: research + g: + type: string + description: | + Named graph filter (optional). + - Omitted/null: all graphs + - Empty string: default graph only + - URI string: specific named graph (e.g., urn:graph:source, urn:graph:retrieval) + example: urn:graph:source + streaming: + type: boolean + description: Enable streaming response delivery + default: false + example: true + batch-size: + type: integer + description: Number of triples per streaming batch + default: 20 + minimum: 1 + maximum: 1000 + example: 50 diff --git a/specs/api/components/schemas/rag/DocumentRagResponse.yaml b/specs/api/components/schemas/rag/DocumentRagResponse.yaml index 6a0166e7..1b275c3a 100644 --- a/specs/api/components/schemas/rag/DocumentRagResponse.yaml +++ b/specs/api/components/schemas/rag/DocumentRagResponse.yaml @@ -1,13 +1,31 @@ type: object -description: Document RAG response +description: Document RAG response message properties: + message_type: + type: string + description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events + enum: [chunk, explain] + example: chunk response: type: string - description: Generated response based on retrieved documents - example: The research papers found three key findings... + description: Generated response text (for chunk messages) + example: Based on the policy documents, customers can return items within 30 days... + explain_id: + type: string + description: Explainability node URI (for explain messages) + example: urn:trustgraph:question:abc123 + explain_graph: + type: string + description: Named graph containing the explainability data + example: urn:graph:retrieval end-of-stream: type: boolean - description: Indicates streaming is complete (streaming mode) + description: Indicates LLM response stream is complete + default: false + example: true + end_of_session: + type: boolean + description: Indicates entire session is complete (all messages sent) default: false example: true error: diff --git a/specs/api/components/schemas/rag/GraphRagResponse.yaml b/specs/api/components/schemas/rag/GraphRagResponse.yaml index 75f4f059..47513fe1 100644 --- a/specs/api/components/schemas/rag/GraphRagResponse.yaml +++ b/specs/api/components/schemas/rag/GraphRagResponse.yaml @@ -1,13 +1,31 @@ type: object -description: Graph RAG response +description: Graph RAG response message properties: + message_type: + type: string + description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events + enum: [chunk, explain] + example: chunk response: type: string - description: Generated response based on retrieved knowledge graph + description: Generated response text (for chunk messages) example: Quantum physics and computer science intersect in quantum computing... - end-of-stream: + explain_id: + type: string + description: Explainability node URI (for explain messages) + example: urn:trustgraph:question:abc123 + explain_graph: + type: string + description: Named graph containing the explainability data + example: urn:graph:retrieval + end_of_stream: type: boolean - description: Indicates streaming is complete (streaming mode) + description: Indicates LLM response stream is complete + default: false + example: true + end_of_session: + type: boolean + description: Indicates entire session is complete (all messages sent) default: false example: true error: diff --git a/specs/api/openapi.yaml b/specs/api/openapi.yaml index 4196f9ec..982a7cc4 100644 --- a/specs/api/openapi.yaml +++ b/specs/api/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: TrustGraph API Gateway - version: "1.8" + version: "2.1" description: | REST API for TrustGraph - an AI-powered knowledge graph and RAG system. @@ -28,7 +28,7 @@ info: Require running flow instance, accessed via `/api/v1/flow/{flow}/service/{kind}`: - AI services: agent, text-completion, prompt, RAG (document/graph) - Embeddings: embeddings, graph-embeddings, document-embeddings - - Query: triples, objects, nlp-query, structured-query + - Query: triples, rows, nlp-query, structured-query, row-embeddings - Data loading: text-load, document-load - Utilities: mcp-tool, structured-diag @@ -140,6 +140,10 @@ paths: /api/v1/flow/{flow}/service/document-load: $ref: './paths/flow/document-load.yaml' + # Document streaming + /api/v1/document-stream: + $ref: './paths/document-stream.yaml' + # Import/Export endpoints /api/v1/import-core: $ref: './paths/import-core.yaml' diff --git a/specs/api/paths/document-stream.yaml b/specs/api/paths/document-stream.yaml new file mode 100644 index 00000000..5f6a11a7 --- /dev/null +++ b/specs/api/paths/document-stream.yaml @@ -0,0 +1,53 @@ +get: + tags: + - Import/Export + summary: Stream document content from library + description: | + Streams the raw content of a document stored in the library. + Returns the document content in chunked transfer encoding. + + ## Parameters + + - `user`: User identifier (required) + - `document-id`: Document IRI to retrieve (required) + - `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB) + + operationId: documentStream + security: + - bearerAuth: [] + parameters: + - name: user + in: query + required: true + schema: + type: string + description: User identifier + example: trustgraph + - name: document-id + in: query + required: true + schema: + type: string + description: Document IRI to retrieve + example: "urn:trustgraph:doc:abc123" + - name: chunk-size + in: query + required: false + schema: + type: integer + default: 1048576 + description: Chunk size in bytes (default 1MB) + responses: + '200': + description: Document content streamed as raw bytes + content: + application/octet-stream: + schema: + type: string + format: binary + '400': + description: Missing required parameters + '401': + $ref: '../components/responses/Unauthorized.yaml' + '500': + $ref: '../components/responses/Error.yaml' diff --git a/specs/build-docs.sh b/specs/build-docs.sh index 3425b339..5e156913 100755 --- a/specs/build-docs.sh +++ b/specs/build-docs.sh @@ -24,7 +24,7 @@ echo # Build WebSocket API documentation echo "Building WebSocket API documentation (AsyncAPI)..." cd ../websocket -npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template@3.0.0 --use-new-generator -o /tmp/asyncapi-build -p singleFile=true --force-write +npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template -o /tmp/asyncapi-build -p singleFile=true --force-write mv /tmp/asyncapi-build/index.html ../../docs/websocket.html rm -rf /tmp/asyncapi-build echo "✓ WebSocket API docs generated: docs/websocket.html" diff --git a/specs/websocket/asyncapi.yaml b/specs/websocket/asyncapi.yaml index 43204aa7..056cc055 100644 --- a/specs/websocket/asyncapi.yaml +++ b/specs/websocket/asyncapi.yaml @@ -2,7 +2,7 @@ asyncapi: 3.0.0 info: title: TrustGraph WebSocket API - version: "1.8" + version: "2.1" description: | WebSocket API for TrustGraph - providing multiplexed, asynchronous access to all services. @@ -31,7 +31,7 @@ info: **Flow-Hosted Services** (require `flow` parameter): - agent, text-completion, prompt, document-rag, graph-rag - embeddings, graph-embeddings, document-embeddings - - triples, objects, nlp-query, structured-query, structured-diag + - triples, rows, nlp-query, structured-query, structured-diag, row-embeddings - text-load, document-load, mcp-tool ## Schema Reuse diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index e82ccd98..c474af29 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -95,8 +95,7 @@ def sample_message_data(): "Metadata": { "id": "test-doc-123", "user": "test_user", - "collection": "test_collection", - "metadata": [] + "collection": "test_collection" }, "Term": { "type": IRI, diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index e0939aaa..c7d6369a 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services import pytest from unittest.mock import MagicMock -from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error from trustgraph.messaging.translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator @@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract: """Test that DocumentEmbeddingsRequest has expected fields""" # Create a request request = DocumentEmbeddingsRequest( - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], limit=10, user="test_user", collection="test_collection" ) - + # Verify all expected fields exist - assert hasattr(request, 'vectors') + assert hasattr(request, 'vector') assert hasattr(request, 'limit') assert hasattr(request, 'user') assert hasattr(request, 'collection') - + # Verify field values - assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert request.vector == [0.1, 0.2, 0.3] assert request.limit == 10 assert request.user == "test_user" assert request.collection == "test_collection" @@ -41,18 +41,18 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_to_pulsar(self): """Test request translator converts dict to Pulsar schema""" translator = DocumentEmbeddingsRequestTranslator() - + data = { - "vectors": [[0.1, 0.2], [0.3, 0.4]], + "vector": [0.1, 0.2, 0.3, 0.4], "limit": 5, "user": "custom_user", "collection": "custom_collection" } - + result = translator.to_pulsar(data) - + assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] + assert result.vector == [0.1, 0.2, 0.3, 0.4] assert result.limit == 5 assert result.user == "custom_user" assert result.collection == "custom_collection" @@ -60,16 +60,16 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_to_pulsar_with_defaults(self): """Test request translator uses correct defaults""" translator = DocumentEmbeddingsRequestTranslator() - + data = { - "vectors": [[0.1, 0.2]] + "vector": [0.1, 0.2] # No limit, user, or collection provided } - + result = translator.to_pulsar(data) - + assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2]] + assert result.vector == [0.1, 0.2] assert result.limit == 10 # Default assert result.user == "trustgraph" # Default assert result.collection == "default" # Default @@ -77,18 +77,18 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_from_pulsar(self): """Test request translator converts Pulsar schema to dict""" translator = DocumentEmbeddingsRequestTranslator() - + request = DocumentEmbeddingsRequest( - vectors=[[0.5, 0.6]], + vector=[0.5, 0.6], limit=20, user="test_user", collection="test_collection" ) - + result = translator.from_pulsar(request) - + assert isinstance(result, dict) - assert result["vectors"] == [[0.5, 0.6]] + assert result["vector"] == [0.5, 0.6] assert result["limit"] == 20 assert result["user"] == "test_user" assert result["collection"] == "test_collection" @@ -102,16 +102,22 @@ class TestDocumentEmbeddingsResponseContract: # Create a response with chunks response = DocumentEmbeddingsResponse( error=None, - chunks=["chunk1", "chunk2", "chunk3"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8), + ChunkMatch(chunk_id="chunk3", score=0.7) + ] ) - + # Verify all expected fields exist assert hasattr(response, 'error') assert hasattr(response, 'chunks') - + # Verify field values assert response.error is None - assert response.chunks == ["chunk1", "chunk2", "chunk3"] + assert len(response.chunks) == 3 + assert response.chunks[0].chunk_id == "chunk1" + assert response.chunks[0].score == 0.9 def test_response_schema_with_error(self): """Test response schema with error""" @@ -119,52 +125,47 @@ class TestDocumentEmbeddingsResponseContract: type="query_error", message="Database connection failed" ) - + response = DocumentEmbeddingsResponse( error=error, - chunks=None + chunks=[] ) - + assert response.error == error - assert response.chunks is None + assert response.chunks == [] def test_response_translator_from_pulsar_with_chunks(self): """Test response translator converts Pulsar schema with chunks to dict""" translator = DocumentEmbeddingsResponseTranslator() - + response = DocumentEmbeddingsResponse( error=None, - chunks=["doc1", "doc2", "doc3"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85), + ChunkMatch(chunk_id="doc3/c3", score=0.75) + ] ) - - result = translator.from_pulsar(response) - - assert isinstance(result, dict) - assert "chunks" in result - assert result["chunks"] == ["doc1", "doc2", "doc3"] - def test_response_translator_from_pulsar_with_bytes(self): - """Test response translator handles byte chunks correctly""" - translator = DocumentEmbeddingsResponseTranslator() - - response = MagicMock() - response.chunks = [b"byte_chunk1", b"byte_chunk2"] - result = translator.from_pulsar(response) - + assert isinstance(result, dict) assert "chunks" in result - assert result["chunks"] == ["byte_chunk1", "byte_chunk2"] + assert len(result["chunks"]) == 3 + assert result["chunks"][0]["chunk_id"] == "doc1/c1" + assert result["chunks"][0]["score"] == 0.95 def test_response_translator_from_pulsar_with_empty_chunks(self): """Test response translator handles empty chunks list""" translator = DocumentEmbeddingsResponseTranslator() - - response = MagicMock() - response.chunks = [] - + + response = DocumentEmbeddingsResponse( + error=None, + chunks=[] + ) + result = translator.from_pulsar(response) - + assert isinstance(result, dict) assert "chunks" in result assert result["chunks"] == [] @@ -172,37 +173,41 @@ class TestDocumentEmbeddingsResponseContract: def test_response_translator_from_pulsar_with_none_chunks(self): """Test response translator handles None chunks""" translator = DocumentEmbeddingsResponseTranslator() - + response = MagicMock() response.chunks = None - + result = translator.from_pulsar(response) - + assert isinstance(result, dict) assert "chunks" not in result or result.get("chunks") is None def test_response_translator_from_response_with_completion(self): """Test response translator with completion flag""" translator = DocumentEmbeddingsResponseTranslator() - + response = DocumentEmbeddingsResponse( error=None, - chunks=["chunk1", "chunk2"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8) + ] ) - + result, is_final = translator.from_response_with_completion(response) - + assert isinstance(result, dict) assert "chunks" in result - assert result["chunks"] == ["chunk1", "chunk2"] + assert len(result["chunks"]) == 2 + assert result["chunks"][0]["chunk_id"] == "chunk1" assert is_final is True # Document embeddings responses are always final def test_response_translator_to_pulsar_not_implemented(self): """Test that to_pulsar raises NotImplementedError for responses""" translator = DocumentEmbeddingsResponseTranslator() - + with pytest.raises(NotImplementedError): - translator.to_pulsar({"chunks": ["test"]}) + translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]}) class TestDocumentEmbeddingsMessageCompatibility: @@ -212,26 +217,29 @@ class TestDocumentEmbeddingsMessageCompatibility: """Test complete request-response flow maintains data integrity""" # Create request request_data = { - "vectors": [[0.1, 0.2, 0.3]], + "vector": [0.1, 0.2, 0.3], "limit": 5, "user": "test_user", "collection": "test_collection" } - + # Convert to Pulsar request req_translator = DocumentEmbeddingsRequestTranslator() pulsar_request = req_translator.to_pulsar(request_data) - + # Simulate service processing and creating response response = DocumentEmbeddingsResponse( error=None, - chunks=["relevant chunk 1", "relevant chunk 2"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85) + ] ) - + # Convert response back to dict resp_translator = DocumentEmbeddingsResponseTranslator() response_data = resp_translator.from_pulsar(response) - + # Verify data integrity assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(response_data, dict) @@ -245,17 +253,18 @@ class TestDocumentEmbeddingsMessageCompatibility: type="vector_db_error", message="Collection not found" ) - + response = DocumentEmbeddingsResponse( error=error, - chunks=None + chunks=[] ) - + # Convert response to dict translator = DocumentEmbeddingsResponseTranslator() response_data = translator.from_pulsar(response) - + # Verify error handling assert isinstance(response_data, dict) # The translator doesn't include error in the dict, only chunks - assert "chunks" not in response_data or response_data.get("chunks") is None \ No newline at end of file + assert "chunks" in response_data + assert response_data["chunks"] == [] diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 746ebaed..695fef14 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -401,25 +401,6 @@ class TestMetadataMessageContracts: assert metadata.id == "test-doc-123" assert metadata.user == "test_user" assert metadata.collection == "test_collection" - assert isinstance(metadata.metadata, list) - - def test_metadata_with_triples_contract(self, sample_message_data): - """Test Metadata with embedded triples contract""" - # Arrange - triple = Triple(**sample_message_data["Triple"]) - metadata_data = { - "id": "doc-with-triples", - "user": "test_user", - "collection": "test_collection", - "metadata": [triple] - } - - # Act & Assert - assert validate_schema_contract(Metadata, metadata_data) - - metadata = Metadata(**metadata_data) - assert len(metadata.metadata) == 1 - assert metadata.metadata[0].s.iri == "http://example.com/subject" def test_error_schema_contract(self): """Test Error schema contract""" diff --git a/tests/contract/test_rows_cassandra_contracts.py b/tests/contract/test_rows_cassandra_contracts.py index d1a8ba26..bf85b9fb 100644 --- a/tests/contract/test_rows_cassandra_contracts.py +++ b/tests/contract/test_rows_cassandra_contracts.py @@ -24,7 +24,6 @@ class TestRowsCassandraContracts: id="test-doc-001", user="test_user", collection="test_collection", - metadata=[] ) test_object = ExtractedObject( @@ -50,7 +49,6 @@ class TestRowsCassandraContracts: assert hasattr(test_object.metadata, 'id') assert hasattr(test_object.metadata, 'user') assert hasattr(test_object.metadata, 'collection') - assert hasattr(test_object.metadata, 'metadata') # Verify types assert isinstance(test_object.schema_name, str) @@ -154,7 +152,6 @@ class TestRowsCassandraContracts: id="serial-001", user="test_user", collection="test_coll", - metadata=[] ), schema_name="test_schema", values=[{"field1": "value1", "field2": "123"}], @@ -234,7 +231,6 @@ class TestRowsCassandraContracts: id="meta-001", user="user123", # -> keyspace collection="coll456", # -> partition key - metadata=[{"key": "value"}] ), schema_name="table789", # -> table name values=[{"field": "value"}], @@ -262,7 +258,6 @@ class TestRowsCassandraContractsBatch: id="batch-doc-001", user="test_user", collection="test_collection", - metadata=[] ) batch_object = ExtractedObject( @@ -308,10 +303,9 @@ class TestRowsCassandraContractsBatch: test_metadata = Metadata( id="empty-batch-001", user="test_user", - collection="test_collection", - metadata=[] + collection="test_collection", ) - + empty_batch_object = ExtractedObject( metadata=test_metadata, schema_name="empty_schema", @@ -332,9 +326,8 @@ class TestRowsCassandraContractsBatch: id="single-batch-001", user="test_user", collection="test_collection", - metadata=[] ) - + single_batch_object = ExtractedObject( metadata=test_metadata, schema_name="customer_records", @@ -362,12 +355,11 @@ class TestRowsCassandraContractsBatch: id="batch-serial-001", user="test_user", collection="test_coll", - metadata=[] ), schema_name="test_schema", values=[ {"field1": "value1", "field2": "123"}, - {"field1": "value2", "field2": "456"}, + {"field1": "value2", "field2": "456"}, {"field1": "value3", "field2": "789"} ], confidence=0.92, @@ -436,9 +428,8 @@ class TestRowsCassandraContractsBatch: id="partition-test-001", user="consistent_user", # Same keyspace collection="consistent_collection", # Same partition - metadata=[] ) - + batch_object = ExtractedObject( metadata=test_metadata, schema_name="partition_test", diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index 71ccd787..d8f4c5cb 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -95,9 +95,8 @@ class TestStructuredDataSchemaContracts: id="structured-data-001", user="test_user", collection="test_collection", - metadata=[] ) - + # Act submission = StructuredDataSubmission( metadata=metadata, @@ -121,9 +120,8 @@ class TestStructuredDataSchemaContracts: id="extracted-obj-001", user="test_user", collection="test_collection", - metadata=[] ) - + # Act obj = ExtractedObject( metadata=metadata, @@ -147,9 +145,8 @@ class TestStructuredDataSchemaContracts: id="extracted-batch-001", user="test_user", collection="test_collection", - metadata=[] ) - + # Act - create object with multiple values obj = ExtractedObject( metadata=metadata, @@ -180,11 +177,10 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-empty-001", - user="test_user", + user="test_user", collection="test_collection", - metadata=[] ) - + # Act - create object with empty values array obj = ExtractedObject( metadata=metadata, @@ -283,13 +279,12 @@ class TestStructuredEmbeddingsContracts: id="struct-embed-001", user="test_user", collection="test_collection", - metadata=[] ) - + # Act embedding = StructuredObjectEmbedding( metadata=metadata, - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], schema_name="customer_records", object_id="customer_123", field_embeddings={ @@ -301,7 +296,7 @@ class TestStructuredEmbeddingsContracts: # Assert assert embedding.schema_name == "customer_records" assert embedding.object_id == "customer_123" - assert len(embedding.vectors) == 2 + assert len(embedding.vector) == 3 assert len(embedding.field_embeddings) == 2 assert "name" in embedding.field_embeddings @@ -313,7 +308,7 @@ class TestStructuredDataSerializationContracts: def test_structured_data_submission_serialization(self): """Test StructuredDataSubmission serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + metadata = Metadata(id="test", user="user", collection="col") submission_data = { "metadata": metadata, "format": "json", @@ -328,7 +323,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_serialization(self): """Test ExtractedObject serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + metadata = Metadata(id="test", user="user", collection="col") object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -378,7 +373,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_batch_serialization(self): """Test ExtractedObject batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + metadata = Metadata(id="test", user="user", collection="col") batch_object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -397,7 +392,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_empty_batch_serialization(self): """Test ExtractedObject empty batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + metadata = Metadata(id="test", user="user", collection="col") empty_batch_data = { "metadata": metadata, "schema_name": "test_schema", diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index c01156ae..dc7d5748 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry class TestRAGTranslatorCompletionFlags: """Contract tests for RAG response translator completion flags""" - def test_graph_rag_translator_is_final_with_end_of_stream_true(self): + def test_graph_rag_translator_is_final_with_end_of_session_true(self): """ Test that GraphRagResponseTranslator returns is_final=True - when end_of_stream=True. + when end_of_session=True. """ # Arrange translator = TranslatorRegistry.get_response_translator("graph-rag") response = GraphRagResponse( response="A small domesticated mammal.", + message_type="chunk", end_of_stream=True, + end_of_session=True, error=None ) @@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is True, "is_final must be True when end_of_stream=True" + assert is_final is True, "is_final must be True when end_of_session=True" assert response_dict["response"] == "A small domesticated mammal." - assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is True + assert response_dict["message_type"] == "chunk" - def test_graph_rag_translator_is_final_with_end_of_stream_false(self): + def test_graph_rag_translator_is_final_with_end_of_session_false(self): """ Test that GraphRagResponseTranslator returns is_final=False - when end_of_stream=False. + when end_of_session=False (even if end_of_stream=True). """ # Arrange translator = TranslatorRegistry.get_response_translator("graph-rag") response = GraphRagResponse( response="Chunk 1", + message_type="chunk", end_of_stream=False, + end_of_session=False, error=None ) @@ -55,20 +60,67 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is False, "is_final must be False when end_of_stream=False" + assert is_final is False, "is_final must be False when end_of_session=False" assert response_dict["response"] == "Chunk 1" - assert response_dict["end_of_stream"] is False + assert response_dict["end_of_session"] is False - def test_document_rag_translator_is_final_with_end_of_stream_true(self): + def test_graph_rag_translator_provenance_message(self): + """ + Test that GraphRagResponseTranslator handles provenance messages. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="", + message_type="explain", + explain_id="urn:trustgraph:session:abc123", + end_of_stream=False, + end_of_session=False, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False + assert response_dict["message_type"] == "explain" + assert response_dict["explain_id"] == "urn:trustgraph:session:abc123" + + def test_graph_rag_translator_end_of_stream_not_final(self): + """ + Test that end_of_stream=True alone does NOT make is_final=True. + The session continues with provenance messages after LLM stream completes. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="Final chunk", + message_type="chunk", + end_of_stream=True, + end_of_session=False, # Session continues with provenance + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "end_of_stream=True should NOT make is_final=True" + assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is False + + def test_document_rag_translator_is_final_with_end_of_session_true(self): """ Test that DocumentRagResponseTranslator returns is_final=True - when end_of_stream=True. + when end_of_session=True. """ # Arrange translator = TranslatorRegistry.get_response_translator("document-rag") response = DocumentRagResponse( response="A document about cats.", end_of_stream=True, + end_of_session=True, error=None ) @@ -76,9 +128,31 @@ class TestRAGTranslatorCompletionFlags: response_dict, is_final = translator.from_response_with_completion(response) # Assert - assert is_final is True, "is_final must be True when end_of_stream=True" + assert is_final is True, "is_final must be True when end_of_session=True" assert response_dict["response"] == "A document about cats." + assert response_dict["end_of_session"] is True + + def test_document_rag_translator_end_of_stream_not_final(self): + """ + Test that end_of_stream=True alone does NOT make is_final=True. + The session continues with provenance messages after LLM stream completes. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("document-rag") + response = DocumentRagResponse( + response="Final chunk", + end_of_stream=True, + end_of_session=False, # Session continues with provenance + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "end_of_stream=True should NOT make is_final=True" assert response_dict["end_of_stream"] is True + assert response_dict["end_of_session"] is False def test_document_rag_translator_is_final_with_end_of_stream_false(self): """ diff --git a/tests/integration/test_agent_kg_extraction_integration.py b/tests/integration/test_agent_kg_extraction_integration.py index 849547c8..579498db 100644 --- a/tests/integration/test_agent_kg_extraction_integration.py +++ b/tests/integration/test_agent_kg_extraction_integration.py @@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse -from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF +from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL from trustgraph.template.prompt_manager import PromptManager @@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration: agent_client = AsyncMock() # Mock successful agent response in JSONL format - def mock_agent_response(recipient, question): + def mock_agent_response(question): # Simulate agent processing and return structured JSONL response mock_response = MagicMock() mock_response.error = None @@ -76,13 +76,6 @@ class TestAgentKgExtractionIntegration: chunk=text.encode('utf-8'), metadata=Metadata( id="doc123", - metadata=[ - Triple( - s=Term(type=IRI, iri="doc123"), - p=Term(type=IRI, iri="http://example.org/type"), - o=Term(type=LITERAL, value="document") - ) - ] ) ) @@ -131,16 +124,12 @@ class TestAgentKgExtractionIntegration: # Get agent response (the mock returns a string directly) agent_client = flow("agent-request") - agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt) + agent_response = agent_client.invoke(question=prompt) # Parse and process extraction_data = extractor.parse_jsonl(agent_response) - triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata) - - # Add metadata triples - for t in v.metadata.metadata: - triples.append(t) - + triples, entity_contexts, extracted_triples = extractor.process_extraction_data(extraction_data, v.metadata) + # Emit outputs if triples: await extractor.emit_triples(flow("triples"), v.metadata, triples) @@ -185,10 +174,6 @@ class TestAgentKgExtractionIntegration: label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL] assert len(label_triples) >= 2 # Should have labels for entities - # Check subject-of relationships - subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF] - assert len(subject_of_triples) >= 2 # Entities should be linked to document - # Verify entity contexts were emitted entity_contexts_publisher = mock_flow_context("entity-contexts") entity_contexts_publisher.send.assert_called_once() @@ -208,7 +193,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock agent error response agent_client = mock_flow_context("agent-request") - def mock_error_response(recipient, question): + def mock_error_response(question): # Simulate agent error by raising an exception raise RuntimeError("Agent processing failed") @@ -230,7 +215,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock invalid JSON response agent_client = mock_flow_context("agent-request") - def mock_invalid_json_response(recipient, question): + def mock_invalid_json_response(question): return "This is not valid JSON at all" agent_client.invoke = mock_invalid_json_response @@ -242,9 +227,9 @@ class TestAgentKgExtractionIntegration: # Act - JSONL parsing is lenient, invalid lines are skipped await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context) - # Assert - should emit triples (with just metadata) but no entity contexts + # Assert - with no valid extraction data, nothing is emitted triples_publisher = mock_flow_context("triples") - triples_publisher.send.assert_called_once() + triples_publisher.send.assert_not_called() entity_contexts_publisher = mock_flow_context("entity-contexts") entity_contexts_publisher.send.assert_not_called() @@ -255,7 +240,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock empty extraction response agent_client = mock_flow_context("agent-request") - def mock_empty_response(recipient, question): + def mock_empty_response(question): # Return empty JSONL (just empty/whitespace) return '' @@ -268,17 +253,12 @@ class TestAgentKgExtractionIntegration: # Act await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context) - # Assert - # Should still emit outputs (even if empty) to maintain flow consistency + # Assert - with empty extraction results, nothing is emitted triples_publisher = mock_flow_context("triples") entity_contexts_publisher = mock_flow_context("entity-contexts") - - # Triples should include metadata triples at minimum - triples_publisher.send.assert_called_once() - sent_triples = triples_publisher.send.call_args[0][0] - assert isinstance(sent_triples, Triples) - - # Entity contexts should not be sent if empty + + # No triples or entity contexts emitted for empty results + triples_publisher.send.assert_not_called() entity_contexts_publisher.send.assert_not_called() @pytest.mark.asyncio @@ -287,7 +267,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock malformed extraction response agent_client = mock_flow_context("agent-request") - def mock_malformed_response(recipient, question): + def mock_malformed_response(question): # JSONL with definition missing required field return '{"type": "definition", "entity": "Missing Definition"}' @@ -308,12 +288,12 @@ class TestAgentKgExtractionIntegration: test_text = "Test text for prompt rendering" chunk = Chunk( chunk=test_text.encode('utf-8'), - metadata=Metadata(id="test-doc", metadata=[]) + metadata=Metadata(id="test-doc") ) agent_client = mock_flow_context("agent-request") - def capture_prompt(recipient, question): + def capture_prompt(question): # Verify the prompt contains the test text assert test_text in question return '' # Empty JSONL response @@ -340,13 +320,13 @@ class TestAgentKgExtractionIntegration: text = f"Test document {i} content" chunks.append(Chunk( chunk=text.encode('utf-8'), - metadata=Metadata(id=f"doc{i}", metadata=[]) + metadata=Metadata(id=f"doc{i}") )) agent_client = mock_flow_context("agent-request") responses = [] - def mock_response(recipient, question): + def mock_response(question): response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}' responses.append(response) return response @@ -375,12 +355,12 @@ class TestAgentKgExtractionIntegration: unicode_text = "Machine Learning (学习机器) は人工知能の一分野です。" chunk = Chunk( chunk=unicode_text.encode('utf-8'), - metadata=Metadata(id="unicode-doc", metadata=[]) + metadata=Metadata(id="unicode-doc") ) agent_client = mock_flow_context("agent-request") - def mock_unicode_response(recipient, question): + def mock_unicode_response(question): # Verify unicode text was properly decoded and included assert "学习机器" in question assert "人工知能" in question @@ -411,12 +391,12 @@ class TestAgentKgExtractionIntegration: large_text = "Machine Learning is important. " * 1000 # Repeat to create large text chunk = Chunk( chunk=large_text.encode('utf-8'), - metadata=Metadata(id="large-doc", metadata=[]) + metadata=Metadata(id="large-doc") ) agent_client = mock_flow_context("agent-request") - def mock_large_text_response(recipient, question): + def mock_large_text_response(question): # Verify large text was included assert len(question) > 10000 return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}' diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index f4f59444..0fedd2b5 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration: pulsar_client=AsyncMock(), max_iterations=3 ) - + # Mock the client method for structured query proc.client = MagicMock() - + + # Mock librarian to avoid hanging on save operations + proc.save_answer_content = AsyncMock(return_value=None) + return proc @pytest.fixture diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 3db22c4d..e9df05cf 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -9,6 +9,15 @@ Following the TEST_STRATEGY.md approach for integration testing. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch + + +# Sample chunk content for testing - maps chunk_id to content +CHUNK_CONTENT = { + "doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.", + "doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.", + "doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.", +} @pytest.mark.integration @@ -19,23 +28,35 @@ class TestDocumentRagIntegration: def mock_embeddings_client(self): """Mock embeddings client that returns realistic vector embeddings""" client = AsyncMock() + # New batch format: [[[vectors_for_text1], ...]] + # One text input returns one vector set containing two vectors client.embed.return_value = [ - [0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding - [0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing + [ + [0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text + [0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text + ] ] return client @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns realistic document chunks""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() + # Returns ChunkMatch objects with chunk_id and score client.query.return_value = [ - "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.", - "Deep learning uses neural networks with multiple layers to model complex patterns in data.", - "Supervised learning algorithms learn from labeled training data to make predictions on new data." + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) ] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + @pytest.fixture def mock_prompt_client(self): """Mock prompt client that generates realistic responses""" @@ -48,17 +69,19 @@ class TestDocumentRagIntegration: return client @pytest.fixture - def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client): + def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, + mock_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with mocked dependencies""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @pytest.mark.asyncio - async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client, + async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client): """Test complete DocumentRAG pipeline from query to response""" # Arrange @@ -76,15 +99,16 @@ class TestDocumentRagIntegration: ) # Assert - Verify service coordination - mock_embeddings_client.embed.assert_called_once_with(query) - + mock_embeddings_client.embed.assert_called_once_with([query]) + mock_doc_embeddings_client.query.assert_called_once_with( - [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], + vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], limit=doc_limit, user=user, collection=collection ) - + + # Documents are fetched from librarian using chunk_ids mock_prompt_client.document_prompt.assert_called_once_with( query=query, documents=[ @@ -101,17 +125,19 @@ class TestDocumentRagIntegration: assert "artificial intelligence" in result.lower() @pytest.mark.asyncio - async def test_document_rag_with_no_documents_found(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_with_no_documents_found(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG behavior when no documents are retrieved""" # Arrange - mock_doc_embeddings_client.query.return_value = [] # No documents found + mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) @@ -125,92 +151,98 @@ class TestDocumentRagIntegration: query="very obscure query", documents=[] ) - + assert result == "I couldn't find any relevant documents for your query." @pytest.mark.asyncio - async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when embeddings service fails""" # Arrange mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "Embeddings service unavailable" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_not_called() mock_prompt_client.document_prompt.assert_not_called() @pytest.mark.asyncio - async def test_document_rag_document_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_document_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when document service fails""" # Arrange mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "Document service connection failed" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() mock_prompt_client.document_prompt.assert_not_called() @pytest.mark.asyncio - async def test_document_rag_prompt_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_prompt_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when prompt service fails""" # Arrange mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "LLM service rate limited" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() mock_prompt_client.document_prompt.assert_called_once() @pytest.mark.asyncio - async def test_document_rag_with_different_document_limits(self, document_rag, + async def test_document_rag_with_different_document_limits(self, document_rag, mock_doc_embeddings_client): """Test DocumentRAG with various document limit configurations""" # Test different document limits test_cases = [1, 5, 10, 25, 50] - + for limit in test_cases: # Reset mock call history mock_doc_embeddings_client.reset_mock() - + # Act await document_rag.query(f"query with limit {limit}", doc_limit=limit) - + # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args @@ -230,14 +262,14 @@ class TestDocumentRagIntegration: for user, collection in test_scenarios: # Reset mock call history mock_doc_embeddings_client.reset_mock() - + # Act await document_rag.query( f"query from {user} in {collection}", user=user, collection=collection ) - + # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args @@ -245,19 +277,21 @@ class TestDocumentRagIntegration: assert call_args.kwargs['collection'] == collection @pytest.mark.asyncio - async def test_document_rag_verbose_logging(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client, + async def test_document_rag_verbose_logging(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk, caplog): """Test DocumentRAG verbose logging functionality""" import logging - + # Arrange - Configure logging to capture debug messages caplog.set_level(logging.DEBUG) - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @@ -269,25 +303,25 @@ class TestDocumentRagIntegration: assert "DocumentRag initialized" in log_messages assert "Constructing prompt..." in log_messages assert "Computing embeddings..." in log_messages - assert "Getting documents..." in log_messages + assert "chunks" in log_messages.lower() assert "Invoking LLM..." in log_messages assert "Query processing complete" in log_messages @pytest.mark.asyncio @pytest.mark.slow - async def test_document_rag_performance_with_large_document_set(self, document_rag, + async def test_document_rag_performance_with_large_document_set(self, document_rag, mock_doc_embeddings_client): """Test DocumentRAG performance with large document retrieval""" - # Arrange - Mock large document set (100 documents) - large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)] - mock_doc_embeddings_client.query.return_value = large_doc_set + # Arrange - Mock large chunk match set (100 chunks) + large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)] + mock_doc_embeddings_client.query.return_value = large_chunk_matches # Act import time start_time = time.time() - + result = await document_rag.query("performance test query", doc_limit=100) - + end_time = time.time() execution_time = end_time - start_time @@ -309,4 +343,4 @@ class TestDocumentRagIntegration: call_args = mock_doc_embeddings_client.query.call_args assert call_args.kwargs['user'] == "trustgraph" assert call_args.kwargs['collection'] == "default" - assert call_args.kwargs['limit'] == 20 \ No newline at end of file + assert call_args.kwargs['limit'] == 20 diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index 84061add..dad30a8f 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -8,12 +8,21 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, ) +# Sample chunk content for testing - maps chunk_id to content +CHUNK_CONTENT = { + "doc/c1": "Machine learning is a subset of AI.", + "doc/c2": "Deep learning uses neural networks.", + "doc/c3": "Supervised learning needs labeled data.", +} + + @pytest.mark.integration class TestDocumentRagStreaming: """Integration tests for DocumentRAG streaming""" @@ -22,20 +31,29 @@ class TestDocumentRagStreaming: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + # New batch format: [[[vectors_for_text1]]] + client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]] return client @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() + # Returns ChunkMatch objects with chunk_id and score client.query.return_value = [ - "Machine learning is a subset of AI.", - "Deep learning uses neural networks.", - "Supervised learning needs labeled data." + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) ] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + @pytest.fixture def mock_streaming_prompt_client(self, mock_streaming_llm_response): """Mock prompt client with streaming support""" @@ -66,12 +84,13 @@ class TestDocumentRagStreaming: @pytest.fixture def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client, - mock_streaming_prompt_client): + mock_streaming_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with streaming support""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_streaming_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @@ -190,7 +209,7 @@ class TestDocumentRagStreaming: mock_doc_embeddings_client): """Test streaming with no documents found""" # Arrange - mock_doc_embeddings_client.query.return_value = [] # No documents + mock_doc_embeddings_client.query.return_value = [] # No chunk_ids callback = AsyncMock() # Act diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index a0608819..5e3279e3 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term, IRI @pytest.mark.integration @@ -21,8 +22,12 @@ class TestGraphRagIntegration: def mock_embeddings_client(self): """Mock embeddings client that returns realistic vector embeddings""" client = AsyncMock() + # New batch format: [[[vectors_for_text1], ...]] + # One text input returns one vector set containing one vector client.embed.return_value = [ - [0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding + [ + [0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text + ] ] return client @@ -31,9 +36,9 @@ class TestGraphRagIntegration: """Mock graph embeddings client that returns realistic entities""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", - "http://trustgraph.ai/e/artificial-intelligence", - "http://trustgraph.ai/e/neural-networks" + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85) ] return client @@ -43,7 +48,7 @@ class TestGraphRagIntegration: client = AsyncMock() # Mock different queries return different triples - async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None): + async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20): # Mock label queries if p == "http://www.w3.org/2000/01/rdf-schema#label": if s == "http://trustgraph.ai/e/machine-learning": @@ -71,18 +76,37 @@ class TestGraphRagIntegration: return [] - client.query.side_effect = query_side_effect + client.query_stream.side_effect = query_stream_side_effect + # Also mock query for label lookups (maybe_label uses query, not query_stream) + client.query.side_effect = query_stream_side_effect return client @pytest.fixture def mock_prompt_client(self): - """Mock prompt client that generates realistic responses""" + """Mock prompt client that generates realistic responses for two-step process""" client = AsyncMock() - client.kg_prompt.return_value = ( - "Machine learning is a subset of artificial intelligence that enables computers " - "to learn from data without being explicitly programmed. It uses algorithms " - "and statistical models to find patterns in data." - ) + + # Mock responses for the multi-step process: + # 1. extract-concepts extracts key concepts from the query + # 2. kg-edge-scoring scores edges for relevance + # 3. kg-edge-reasoning provides reasoning for selected edges + # 4. kg-synthesis returns the final answer + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "extract-concepts": + return "" # Falls back to raw query + elif prompt_name == "kg-edge-scoring": + return "" # No edges scored + elif prompt_name == "kg-edge-reasoning": + return "" # No reasoning + elif prompt_name == "kg-synthesis": + return ( + "Machine learning is a subset of artificial intelligence that enables computers " + "to learn from data without being explicitly programmed. It uses algorithms " + "and statistical models to find patterns in data." + ) + return "" + + client.prompt.side_effect = mock_prompt return client @pytest.fixture @@ -101,7 +125,7 @@ class TestGraphRagIntegration: async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client, mock_graph_embeddings_client, mock_triples_client, mock_prompt_client): - """Test complete GraphRAG pipeline from query to response""" + """Test complete GraphRAG pipeline from query to response with real-time provenance""" # Arrange query = "What is machine learning?" user = "test_user" @@ -109,41 +133,51 @@ class TestGraphRagIntegration: entity_limit = 50 triple_limit = 30 + # Collect provenance events + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + # Act - result = await graph_rag.query( + response = await graph_rag.query( query=query, user=user, collection=collection, entity_limit=entity_limit, - triple_limit=triple_limit + triple_limit=triple_limit, + explain_callback=collect_provenance ) # Assert - Verify service coordination - # 1. Should compute embeddings for query - mock_embeddings_client.embed.assert_called_once_with(query) + # 1. Should compute embeddings for query (now expects list of texts) + mock_embeddings_client.embed.assert_called_once_with([query]) # 2. Should query graph embeddings to find relevant entities mock_graph_embeddings_client.query.assert_called_once() call_args = mock_graph_embeddings_client.query.call_args - assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] + assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['limit'] == entity_limit assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection # 3. Should query triples to build knowledge subgraph - assert mock_triples_client.query.call_count > 0 + assert mock_triples_client.query_stream.call_count > 0 - # 4. Should call prompt with knowledge graph - mock_prompt_client.kg_prompt.assert_called_once() - call_args = mock_prompt_client.kg_prompt.call_args - assert call_args.args[0] == query # First arg is query - assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples) + # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis) + assert mock_prompt_client.prompt.call_count == 4 # Verify final response - assert result is not None - assert isinstance(result, str) - assert "machine learning" in result.lower() + assert response is not None + assert isinstance(response, str) + assert "machine learning" in response.lower() + + # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis) + assert len(provenance_events) == 5 + for triples, prov_id in provenance_events: + assert isinstance(triples, list) + assert prov_id.startswith("urn:trustgraph:") @pytest.mark.asyncio async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client, @@ -197,21 +231,27 @@ class TestGraphRagIntegration: """Test GraphRAG handles empty knowledge graph gracefully""" # Arrange mock_graph_embeddings_client.query.return_value = [] # No entities found - mock_triples_client.query.return_value = [] # No triples found + mock_triples_client.query_stream.return_value = [] # No triples found + + # Collect provenance + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) # Act - result = await graph_rag.query( + response = await graph_rag.query( query="unknown topic", user="test_user", - collection="test_collection" + collection="test_collection", + explain_callback=collect_provenance ) # Assert - # Should still call prompt client with empty knowledge graph - mock_prompt_client.kg_prompt.assert_called_once() - call_args = mock_prompt_client.kg_prompt.call_args - assert isinstance(call_args.args[1], list) # kg should be a list - assert result is not None + # Should still call prompt client + assert response is not None + # Provenance should still be emitted (5 events) + assert len(provenance_events) == 5 @pytest.mark.asyncio async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client): @@ -226,7 +266,7 @@ class TestGraphRagIntegration: collection="test_collection" ) - first_call_count = mock_triples_client.query.call_count + first_call_count = mock_triples_client.query_stream.call_count mock_triples_client.reset_mock() # Second identical query @@ -236,7 +276,7 @@ class TestGraphRagIntegration: collection="test_collection" ) - second_call_count = mock_triples_client.query.call_count + second_call_count = mock_triples_client.query_stream.call_count # Assert - Second query should make fewer triple queries due to caching # Note: This is a weak assertion because caching behavior depends on diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 47dd84b6..b66c5289 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -8,6 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term, IRI from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -24,7 +25,8 @@ class TestGraphRagStreaming: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + # New batch format: [[[vectors_for_text1]]] + client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]] return client @pytest.fixture @@ -32,7 +34,7 @@ class TestGraphRagStreaming: """Mock graph embeddings client""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), ] return client @@ -51,30 +53,38 @@ class TestGraphRagStreaming: @pytest.fixture def mock_streaming_prompt_client(self, mock_streaming_llm_response): - """Mock prompt client with streaming support""" + """Mock prompt client with streaming support for two-stage GraphRAG""" client = AsyncMock() - async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): - # Both modes return the same text - full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." + # Full synthesis text + full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." - if streaming and chunk_callback: - # Simulate streaming chunks with end_of_stream flags - chunks = [] - async for chunk in mock_streaming_llm_response(): - chunks.append(chunk) + async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): + if prompt_id == "extract-concepts": + return "" # Falls back to raw query + elif prompt_id == "kg-edge-scoring": + # Edge scoring returns JSONL with IDs and scores + return '{"id": "abc12345", "score": 0.9}\n' + elif prompt_id == "kg-edge-reasoning": + return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' + elif prompt_id == "kg-synthesis": + if streaming and chunk_callback: + # Simulate streaming chunks with end_of_stream flags + chunks = [] + async for chunk in mock_streaming_llm_response(): + chunks.append(chunk) - # Send all chunks with end_of_stream=False except the last - for i, chunk in enumerate(chunks): - is_final = (i == len(chunks) - 1) - await chunk_callback(chunk, is_final) + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) - return full_text - else: - # Non-streaming response - same text - return full_text + return full_text + else: + return full_text + return "" - client.kg_prompt.side_effect = kg_prompt_side_effect + client.prompt.side_effect = prompt_side_effect return client @pytest.fixture @@ -91,18 +101,25 @@ class TestGraphRagStreaming: @pytest.mark.asyncio async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector): - """Test basic GraphRAG streaming functionality""" + """Test basic GraphRAG streaming functionality with real-time provenance""" # Arrange query = "What is machine learning?" collector = streaming_chunk_collector() - # Act - result = await graph_rag_streaming.query( + # Collect provenance events + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + + # Act - query() returns response, provenance via callback + response = await graph_rag_streaming.query( query=query, user="test_user", collection="test_collection", streaming=True, - chunk_callback=collector.collect + chunk_callback=collector.collect, + explain_callback=collect_provenance ) # Assert @@ -114,10 +131,15 @@ class TestGraphRagStreaming: # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() - assert result == full_from_chunks + assert response == full_from_chunks # Verify content is reasonable - assert "machine" in result.lower() or "learning" in result.lower() + assert "machine" in response.lower() or "learning" in response.lower() + + # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis) + assert len(provenance_events) == 5 + for triples, prov_id in provenance_events: + assert prov_id.startswith("urn:trustgraph:") @pytest.mark.asyncio async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming): @@ -128,7 +150,7 @@ class TestGraphRagStreaming: collection = "test_collection" # Act - Non-streaming - non_streaming_result = await graph_rag_streaming.query( + non_streaming_response = await graph_rag_streaming.query( query=query, user=user, collection=collection, @@ -141,7 +163,7 @@ class TestGraphRagStreaming: async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) - streaming_result = await graph_rag_streaming.query( + streaming_response = await graph_rag_streaming.query( query=query, user=user, collection=collection, @@ -150,9 +172,9 @@ class TestGraphRagStreaming: ) # Assert - Results should be equivalent - assert streaming_result == non_streaming_result + assert streaming_response == non_streaming_response assert len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_result + assert "".join(streaming_chunks) == streaming_response @pytest.mark.asyncio async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming): @@ -161,7 +183,7 @@ class TestGraphRagStreaming: callback = AsyncMock() # Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="test query", user="test_user", collection="test_collection", @@ -171,7 +193,7 @@ class TestGraphRagStreaming: # Assert assert callback.call_count > 0 - assert result is not None + assert response is not None # Verify all callback invocations had string arguments for call in callback.call_args_list: @@ -181,7 +203,7 @@ class TestGraphRagStreaming: async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming): """Test streaming parameter without callback (should fall back to non-streaming)""" # Arrange & Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="test query", user="test_user", collection="test_collection", @@ -190,8 +212,8 @@ class TestGraphRagStreaming: ) # Assert - Should complete without error - assert result is not None - assert isinstance(result, str) + assert response is not None + assert isinstance(response, str) @pytest.mark.asyncio async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming, @@ -202,7 +224,7 @@ class TestGraphRagStreaming: callback = AsyncMock() # Act - result = await graph_rag_streaming.query( + response = await graph_rag_streaming.query( query="unknown topic", user="test_user", collection="test_collection", @@ -211,7 +233,7 @@ class TestGraphRagStreaming: ) # Assert - Should still produce streamed response - assert result is not None + assert response is not None assert callback.call_count > 0 @pytest.mark.asyncio diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py index 13a851df..a3771b80 100644 --- a/tests/integration/test_import_export_graceful_shutdown.py +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend): triples_obj = Triples( metadata=Metadata( id=f"export-msg-{i}", - metadata=to_subgraph(msg_data["metadata"]["metadata"]), user=msg_data["metadata"]["user"], collection=msg_data["metadata"]["collection"], ), diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 2baa1d4d..4d8b60ad 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -17,7 +17,7 @@ from trustgraph.extract.kg.relationships.extract import Processor as Relationshi from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings -from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF +from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL @pytest.mark.integration @@ -92,7 +92,6 @@ class TestKnowledgeGraphPipelineIntegration: id="doc-123", user="test_user", collection="test_collection", - metadata=[] ), chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns." ) @@ -243,13 +242,12 @@ class TestKnowledgeGraphPipelineIntegration: id="test-doc", user="test_user", collection="test_collection", - metadata=[] ) # Act triples = [] entities = [] - + for defn in sample_definitions_response: s = defn["entity"] o = defn["definition"] @@ -302,12 +300,11 @@ class TestKnowledgeGraphPipelineIntegration: id="test-doc", user="test_user", collection="test_collection", - metadata=[] ) # Act triples = [] - + for rel in sample_relationships_response: s = rel["subject"] p = rel["predicate"] @@ -373,7 +370,6 @@ class TestKnowledgeGraphPipelineIntegration: id="test-doc", user="test_user", collection="test_collection", - metadata=[] ), triples=[ Triple( @@ -406,12 +402,11 @@ class TestKnowledgeGraphPipelineIntegration: id="test-doc", user="test_user", collection="test_collection", - metadata=[] ), entities=[ EntityEmbeddings( entity=Term(type=IRI, iri="http://example.org/entity"), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) ] ) @@ -542,7 +537,7 @@ class TestKnowledgeGraphPipelineIntegration: ] sample_chunk = Chunk( - metadata=Metadata(id="test", user="user", collection="collection", metadata=[]), + metadata=Metadata(id="test", user="user", collection="collection"), chunk=b"Test chunk" ) @@ -569,7 +564,7 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange large_chunk_batch = [ Chunk( - metadata=Metadata(id=f"doc-{i}", user="user", collection="collection", metadata=[]), + metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"), chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8") ) for i in range(100) # Large batch @@ -608,15 +603,8 @@ class TestKnowledgeGraphPipelineIntegration: id="test-doc-123", user="test_user", collection="test_collection", - metadata=[ - Triple( - s=Term(type=IRI, iri="doc:test"), - p=Term(type=IRI, iri="dc:title"), - o=Term(type=LITERAL, value="Test Document") - ) - ] ) - + sample_chunk = Chunk( metadata=original_metadata, chunk=b"Test content for metadata propagation" diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index dd48affe..faa63381 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -231,7 +231,6 @@ class TestObjectExtractionServiceIntegration: id="customer-doc-001", user="integration_test", collection="test_documents", - metadata=[] ) chunk_text = """ @@ -299,7 +298,6 @@ class TestObjectExtractionServiceIntegration: id="product-doc-001", user="integration_test", collection="test_documents", - metadata=[] ) chunk_text = """ @@ -373,7 +371,6 @@ class TestObjectExtractionServiceIntegration: id=chunk_id, user="concurrent_test", collection="test_collection", - metadata=[] ) chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8')) chunks.append(chunk) @@ -470,7 +467,7 @@ class TestObjectExtractionServiceIntegration: await processor.on_schema_config(integration_config, version=1) # Create test chunk - metadata = Metadata(id="error-test", user="test", collection="test", metadata=[]) + metadata = Metadata(id="error-test", user="test", collection="test") chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process") mock_msg = MagicMock() @@ -507,7 +504,6 @@ class TestObjectExtractionServiceIntegration: id="metadata-test-chunk", user="test_user", collection="test_collection", - metadata=[] # Could include source document metadata ) chunk = Chunk( diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index d2ceea95..f5fe14b5 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, call from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI class TestGraphRagStreamingProtocol: @@ -18,14 +19,17 @@ class TestGraphRagStreamingProtocol: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3]] + client.embed.return_value = [[[0.1, 0.2, 0.3]]] return client @pytest.fixture def mock_graph_embeddings_client(self): """Mock graph embeddings client""" client = AsyncMock() - client.query.return_value = ["entity1", "entity2"] + client.query.return_value = [ + EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90) + ] return client @pytest.fixture @@ -40,18 +44,23 @@ class TestGraphRagStreamingProtocol: """Mock prompt client that simulates realistic streaming with end_of_stream flags""" client = AsyncMock() - async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): - if streaming and chunk_callback: - # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True - await chunk_callback("The", False) - await chunk_callback(" answer", False) - await chunk_callback(" is here.", False) - await chunk_callback("", True) # Empty final chunk with end_of_stream=True - return "" # Return value not used since callback handles everything - else: - return "The answer is here." + async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": + # Edge selection returns empty (no edges selected) + return "" + elif prompt_name == "kg-synthesis": + if streaming and chunk_callback: + # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True + await chunk_callback("The", False) + await chunk_callback(" answer", False) + await chunk_callback(" is here.", False) + await chunk_callback("", True) # Empty final chunk with end_of_stream=True + return "" # Return value not used since callback handles everything + else: + return "The answer is here." + return "" - client.kg_prompt.side_effect = kg_prompt_side_effect + client.prompt.side_effect = prompt_side_effect return client @pytest.fixture @@ -197,16 +206,26 @@ class TestDocumentRagStreamingProtocol: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3]] + client.embed.return_value = [[[0.1, 0.2, 0.3]]] return client @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - client.query.return_value = ["doc1", "doc2"] + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90) + ] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return f"Content for {chunk_id}" + return fetch + @pytest.fixture def mock_streaming_prompt_client(self): """Mock prompt client with streaming support""" @@ -227,12 +246,13 @@ class TestDocumentRagStreamingProtocol: @pytest.fixture def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, - mock_streaming_prompt_client): + mock_streaming_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with mocked dependencies""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_streaming_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) @@ -312,20 +332,24 @@ class TestStreamingProtocolEdgeCases: # Arrange client = AsyncMock() - async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None): - if streaming and chunk_callback: - await chunk_callback("text", False) - await chunk_callback("", False) # Empty but not final - await chunk_callback("more", False) - await chunk_callback("", True) # Empty and final + async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "kg-edge-selection": return "" - else: - return "textmore" + elif prompt_name == "kg-synthesis": + if streaming and chunk_callback: + await chunk_callback("text", False) + await chunk_callback("", False) # Empty but not final + await chunk_callback("more", False) + await chunk_callback("", True) # Empty and final + return "" + else: + return "textmore" + return "" - client.kg_prompt.side_effect = kg_prompt_with_empties + client.prompt.side_effect = prompt_with_empties rag = GraphRag( - embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])), + embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])), graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), triples_client=AsyncMock(query=AsyncMock(return_value=[])), prompt_client=client, diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index 2cb973a7..9067816a 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -120,7 +120,6 @@ class TestRowsCassandraIntegration: id="doc-001", user="test_user", collection="import_2024", - metadata=[] ), schema_name="customer_records", values=[{ @@ -201,7 +200,7 @@ class TestRowsCassandraIntegration: # Process objects for different schemas product_obj = ExtractedObject( - metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]), + metadata=Metadata(id="p1", user="shop", collection="catalog"), schema_name="products", values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], confidence=0.9, @@ -209,7 +208,7 @@ class TestRowsCassandraIntegration: ) order_obj = ExtractedObject( - metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), + metadata=Metadata(id="o1", user="shop", collection="sales"), schema_name="orders", values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], confidence=0.85, @@ -254,7 +253,7 @@ class TestRowsCassandraIntegration: ) test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + metadata=Metadata(id="t1", user="test", collection="test"), schema_name="indexed_data", values=[{ "id": "123", @@ -337,7 +336,6 @@ class TestRowsCassandraIntegration: id="batch-001", user="test_user", collection="batch_import", - metadata=[] ), schema_name="batch_customers", values=[ @@ -391,7 +389,7 @@ class TestRowsCassandraIntegration: # Process empty batch object empty_obj = ExtractedObject( - metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), + metadata=Metadata(id="empty-1", user="test", collection="empty"), schema_name="empty_test", values=[], # Empty batch confidence=1.0, @@ -426,7 +424,7 @@ class TestRowsCassandraIntegration: ) test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + metadata=Metadata(id="t1", user="test", collection="test"), schema_name="map_test", values=[{"id": "123", "name": "Test Item", "count": "42"}], confidence=0.9, @@ -470,7 +468,7 @@ class TestRowsCassandraIntegration: ) test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]), + metadata=Metadata(id="t1", user="test", collection="my_collection"), schema_name="partition_test", values=[{"id": "123", "category": "test"}], confidence=0.9, diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index 0fd2060d..2ef64e96 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] @@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] @@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming: max_iterations=10 ) + # Mock librarian to avoid hanging on save operations + processor.save_answer_content = AsyncMock(return_value=None) + # Track all responses sent sent_responses = [] diff --git a/tests/unit/test_agent/test_tool_service.py b/tests/unit/test_agent/test_tool_service.py new file mode 100644 index 00000000..8bcf39ce --- /dev/null +++ b/tests/unit/test_agent/test_tool_service.py @@ -0,0 +1,495 @@ +""" +Unit tests for Tool Service functionality + +Tests the dynamically pluggable tool services feature including: +- Tool service configuration parsing +- ToolServiceImpl initialization +- Request/response format +- Config parameter handling +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +import json + + +class TestToolServiceConfigParsing: + """Test cases for tool service configuration parsing""" + + def test_tool_service_config_structure(self): + """Test that tool-service config has required fields""" + # Arrange + valid_config = { + "id": "joke-service", + "request-queue": "non-persistent://tg/request/joke", + "response-queue": "non-persistent://tg/response/joke", + "config-params": [ + {"name": "style", "required": False} + ] + } + + # Act & Assert + assert "id" in valid_config + assert "request-queue" in valid_config + assert "response-queue" in valid_config + assert valid_config["request-queue"].startswith("non-persistent://") + assert valid_config["response-queue"].startswith("non-persistent://") + + def test_tool_service_config_without_queues_is_invalid(self): + """Test that tool-service config requires request-queue and response-queue""" + # Arrange + invalid_config = { + "id": "joke-service", + "config-params": [] + } + + # Act & Assert + def validate_config(config): + request_queue = config.get("request-queue") + response_queue = config.get("response-queue") + if not request_queue or not response_queue: + raise RuntimeError("Tool-service must define 'request-queue' and 'response-queue'") + return True + + with pytest.raises(RuntimeError) as exc_info: + validate_config(invalid_config) + assert "request-queue" in str(exc_info.value) + + def test_tool_config_references_tool_service(self): + """Test that tool config correctly references a tool-service""" + # Arrange + tool_services = { + "joke-service": { + "id": "joke-service", + "request-queue": "non-persistent://tg/request/joke", + "response-queue": "non-persistent://tg/response/joke", + "config-params": [{"name": "style", "required": False}] + } + } + + tool_config = { + "type": "tool-service", + "name": "tell-joke", + "description": "Tell a joke on a given topic", + "service": "joke-service", + "style": "pun", + "arguments": [ + {"name": "topic", "type": "string", "description": "The topic for the joke"} + ] + } + + # Act + service_ref = tool_config.get("service") + service_config = tool_services.get(service_ref) + + # Assert + assert service_ref == "joke-service" + assert service_config is not None + assert service_config["request-queue"] == "non-persistent://tg/request/joke" + + def test_tool_config_extracts_config_values(self): + """Test that config values are extracted from tool config""" + # Arrange + tool_services = { + "joke-service": { + "id": "joke-service", + "request-queue": "non-persistent://tg/request/joke", + "response-queue": "non-persistent://tg/response/joke", + "config-params": [ + {"name": "style", "required": False}, + {"name": "max-length", "required": False} + ] + } + } + + tool_config = { + "type": "tool-service", + "name": "tell-joke", + "description": "Tell a joke", + "service": "joke-service", + "style": "pun", + "max-length": 100, + "arguments": [] + } + + # Act - simulate config extraction + service_config = tool_services[tool_config["service"]] + config_params = service_config.get("config-params", []) + config_values = {} + for param in config_params: + param_name = param.get("name") if isinstance(param, dict) else param + if param_name in tool_config: + config_values[param_name] = tool_config[param_name] + + # Assert + assert config_values == {"style": "pun", "max-length": 100} + + def test_required_config_param_validation(self): + """Test that required config params are validated""" + # Arrange + tool_services = { + "custom-service": { + "id": "custom-service", + "request-queue": "non-persistent://tg/request/custom", + "response-queue": "non-persistent://tg/response/custom", + "config-params": [ + {"name": "collection", "required": True}, + {"name": "optional-param", "required": False} + ] + } + } + + tool_config_missing_required = { + "type": "tool-service", + "name": "custom-tool", + "description": "Custom tool", + "service": "custom-service", + # Missing required "collection" param + "optional-param": "value" + } + + # Act & Assert + def validate_config_params(tool_config, service_config): + config_params = service_config.get("config-params", []) + for param in config_params: + param_name = param.get("name") + if param.get("required", False) and param_name not in tool_config: + raise RuntimeError(f"Missing required config param '{param_name}'") + return True + + service_config = tool_services["custom-service"] + with pytest.raises(RuntimeError) as exc_info: + validate_config_params(tool_config_missing_required, service_config) + assert "collection" in str(exc_info.value) + + +class TestToolServiceRequest: + """Test cases for tool service request format""" + + def test_request_format(self): + """Test that request is properly formatted with user, config, and arguments""" + # Arrange + user = "alice" + config_values = {"style": "pun", "collection": "jokes"} + arguments = {"topic": "programming"} + + # Act - simulate request building + request = { + "user": user, + "config": json.dumps(config_values), + "arguments": json.dumps(arguments) + } + + # Assert + assert request["user"] == "alice" + assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"} + assert json.loads(request["arguments"]) == {"topic": "programming"} + + def test_request_with_empty_config(self): + """Test request when no config values are provided""" + # Arrange + user = "bob" + config_values = {} + arguments = {"query": "test"} + + # Act + request = { + "user": user, + "config": json.dumps(config_values) if config_values else "{}", + "arguments": json.dumps(arguments) if arguments else "{}" + } + + # Assert + assert request["config"] == "{}" + assert json.loads(request["arguments"]) == {"query": "test"} + + +class TestToolServiceResponse: + """Test cases for tool service response handling""" + + def test_success_response_handling(self): + """Test handling of successful response""" + # Arrange + response = { + "error": None, + "response": "Hey alice! Here's a pun for you:\n\nWhy do programmers prefer dark mode?", + "end_of_stream": True + } + + # Act & Assert + assert response["error"] is None + assert "pun" in response["response"] + assert response["end_of_stream"] is True + + def test_error_response_handling(self): + """Test handling of error response""" + # Arrange + response = { + "error": { + "type": "tool-service-error", + "message": "Service unavailable" + }, + "response": "", + "end_of_stream": True + } + + # Act & Assert + assert response["error"] is not None + assert response["error"]["type"] == "tool-service-error" + assert response["error"]["message"] == "Service unavailable" + + def test_string_response_passthrough(self): + """Test that string responses are passed through directly""" + # Arrange + response_text = "This is a joke response" + + # Act - simulate response handling + def handle_response(response): + if isinstance(response, str): + return response + else: + return json.dumps(response) + + result = handle_response(response_text) + + # Assert + assert result == response_text + + def test_dict_response_json_serialization(self): + """Test that dict responses are JSON serialized""" + # Arrange + response_data = {"joke": "Why did the chicken cross the road?", "category": "classic"} + + # Act + def handle_response(response): + if isinstance(response, str): + return response + else: + return json.dumps(response) + + result = handle_response(response_data) + + # Assert + assert result == json.dumps(response_data) + assert json.loads(result) == response_data + + +class TestToolServiceImpl: + """Test cases for ToolServiceImpl class""" + + def test_tool_service_impl_initialization(self): + """Test ToolServiceImpl stores queues and config correctly""" + # Arrange + class MockArgument: + def __init__(self, name, type, description): + self.name = name + self.type = type + self.description = description + + # Simulate ToolServiceImpl initialization + class MockToolServiceImpl: + def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None): + self.context = context + self.request_queue = request_queue + self.response_queue = response_queue + self.config_values = config_values or {} + self.arguments = arguments or [] + self.processor = processor + self._client = None + + def get_arguments(self): + return self.arguments + + # Act + arguments = [ + MockArgument("topic", "string", "The topic for the joke") + ] + + impl = MockToolServiceImpl( + context=lambda x: None, + request_queue="non-persistent://tg/request/joke", + response_queue="non-persistent://tg/response/joke", + config_values={"style": "pun"}, + arguments=arguments, + processor=Mock() + ) + + # Assert + assert impl.request_queue == "non-persistent://tg/request/joke" + assert impl.response_queue == "non-persistent://tg/response/joke" + assert impl.config_values == {"style": "pun"} + assert len(impl.get_arguments()) == 1 + assert impl.get_arguments()[0].name == "topic" + + def test_tool_service_impl_client_caching(self): + """Test that client is cached and reused""" + # Arrange + client_key = "non-persistent://tg/request/joke|non-persistent://tg/response/joke" + + # Simulate client caching behavior + tool_service_clients = {} + + def get_or_create_client(request_queue, response_queue, clients_cache): + client_key = f"{request_queue}|{response_queue}" + if client_key in clients_cache: + return clients_cache[client_key], False # False = not created + client = Mock() + clients_cache[client_key] = client + return client, True # True = newly created + + # Act + client1, created1 = get_or_create_client( + "non-persistent://tg/request/joke", + "non-persistent://tg/response/joke", + tool_service_clients + ) + client2, created2 = get_or_create_client( + "non-persistent://tg/request/joke", + "non-persistent://tg/response/joke", + tool_service_clients + ) + + # Assert + assert created1 is True + assert created2 is False + assert client1 is client2 + + +class TestJokeServiceLogic: + """Test cases for the joke service example""" + + def test_topic_to_category_mapping(self): + """Test that topics are mapped to categories correctly""" + # Arrange + def map_topic_to_category(topic): + topic = topic.lower() + if "program" in topic or "code" in topic or "computer" in topic or "software" in topic: + return "programming" + elif "llama" in topic: + return "llama" + elif "animal" in topic or "dog" in topic or "cat" in topic or "bird" in topic: + return "animals" + elif "food" in topic or "eat" in topic or "cook" in topic or "drink" in topic: + return "food" + else: + return "default" + + # Act & Assert + assert map_topic_to_category("programming") == "programming" + assert map_topic_to_category("software engineering") == "programming" + assert map_topic_to_category("llamas") == "llama" + assert map_topic_to_category("llama") == "llama" + assert map_topic_to_category("animals") == "animals" + assert map_topic_to_category("my dog") == "animals" + assert map_topic_to_category("food") == "food" + assert map_topic_to_category("cooking recipes") == "food" + assert map_topic_to_category("random topic") == "default" + assert map_topic_to_category("") == "default" + + def test_joke_response_personalization(self): + """Test that joke responses include user personalization""" + # Arrange + user = "alice" + style = "pun" + joke = "Why do programmers prefer dark mode? Because light attracts bugs!" + + # Act + response = f"Hey {user}! Here's a {style} for you:\n\n{joke}" + + # Assert + assert "Hey alice!" in response + assert "pun" in response + assert joke in response + + def test_style_normalization(self): + """Test that invalid styles fall back to valid ones""" + import random + + # Arrange + valid_styles = ["pun", "dad-joke", "one-liner"] + + def normalize_style(style): + if style not in valid_styles: + return random.choice(valid_styles) + return style + + # Act & Assert + assert normalize_style("pun") == "pun" + assert normalize_style("dad-joke") == "dad-joke" + assert normalize_style("one-liner") == "one-liner" + assert normalize_style("invalid-style") in valid_styles + assert normalize_style("") in valid_styles + + +class TestDynamicToolServiceBase: + """Test cases for DynamicToolService base class behavior""" + + def test_topic_to_pulsar_path_conversion(self): + """Test that topic names are converted to full Pulsar paths""" + # Arrange + topic = "joke" + + # Act + request_topic = f"non-persistent://tg/request/{topic}" + response_topic = f"non-persistent://tg/response/{topic}" + + # Assert + assert request_topic == "non-persistent://tg/request/joke" + assert response_topic == "non-persistent://tg/response/joke" + + def test_request_parsing(self): + """Test parsing of incoming request""" + # Arrange + request_data = { + "user": "alice", + "config": '{"style": "pun"}', + "arguments": '{"topic": "programming"}' + } + + # Act + user = request_data.get("user", "trustgraph") + config = json.loads(request_data["config"]) if request_data["config"] else {} + arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {} + + # Assert + assert user == "alice" + assert config == {"style": "pun"} + assert arguments == {"topic": "programming"} + + def test_response_building(self): + """Test building of response message""" + # Arrange + response_text = "Hey alice! Here's a joke" + error = None + + # Act + response = { + "error": error, + "response": response_text if isinstance(response_text, str) else json.dumps(response_text), + "end_of_stream": True + } + + # Assert + assert response["error"] is None + assert response["response"] == "Hey alice! Here's a joke" + assert response["end_of_stream"] is True + + def test_error_response_building(self): + """Test building of error response""" + # Arrange + error_message = "Service temporarily unavailable" + + # Act + response = { + "error": { + "type": "tool-service-error", + "message": error_message + }, + "response": "", + "end_of_stream": True + } + + # Assert + assert response["error"]["type"] == "tool-service-error" + assert response["error"]["message"] == error_message + assert response["response"] == "" diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py new file mode 100644 index 00000000..65cdb542 --- /dev/null +++ b/tests/unit/test_agent/test_tool_service_lifecycle.py @@ -0,0 +1,624 @@ +""" +Tests for tool service lifecycle, invoke contract, streaming responses, +multi-tenancy, and error propagation. + +Tests the actual DynamicToolService, ToolService, and ToolServiceClient +classes rather than plain dicts. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.schema import ( + ToolServiceRequest, ToolServiceResponse, Error, + ToolRequest, ToolResponse, +) +from trustgraph.exceptions import TooManyRequests + + +# --------------------------------------------------------------------------- +# DynamicToolService tests +# --------------------------------------------------------------------------- + +class TestDynamicToolServiceInvokeContract: + + @pytest.mark.asyncio + async def test_base_invoke_raises_not_implemented(self): + """Base class invoke() should raise NotImplementedError.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + + with pytest.raises(NotImplementedError): + await svc.invoke("user", {}, {}) + + @pytest.mark.asyncio + async def test_on_request_calls_invoke_with_parsed_args(self): + """on_request should JSON-parse config/arguments and pass to invoke.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + calls = [] + + async def tracking_invoke(user, config, arguments): + calls.append({"user": user, "config": config, "arguments": arguments}) + return "ok" + + svc.invoke = tracking_invoke + + # Ensure the class-level metric exists + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest( + user="alice", + config='{"style": "pun"}', + arguments='{"topic": "cats"}', + ) + msg.properties.return_value = {"id": "req-1"} + + await svc.on_request(msg, MagicMock(), None) + + assert len(calls) == 1 + assert calls[0]["user"] == "alice" + assert calls[0]["config"] == {"style": "pun"} + assert calls[0]["arguments"] == {"topic": "cats"} + + @pytest.mark.asyncio + async def test_on_request_empty_user_defaults_to_trustgraph(self): + """Empty user field should default to 'trustgraph'.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + received_user = None + + async def capture_invoke(user, config, arguments): + nonlocal received_user + received_user = user + return "ok" + + svc.invoke = capture_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="", config="", arguments="") + msg.properties.return_value = {"id": "req-2"} + + await svc.on_request(msg, MagicMock(), None) + + assert received_user == "trustgraph" + + @pytest.mark.asyncio + async def test_on_request_string_response_sent_directly(self): + """String return from invoke → response field is the string.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def string_invoke(user, config, arguments): + return "hello world" + + svc.invoke = string_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r1"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert isinstance(sent, ToolServiceResponse) + assert sent.response == "hello world" + assert sent.end_of_stream is True + assert sent.error is None + + @pytest.mark.asyncio + async def test_on_request_dict_response_json_encoded(self): + """Dict return from invoke → response field is JSON-encoded.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def dict_invoke(user, config, arguments): + return {"result": 42} + + svc.invoke = dict_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r2"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert json.loads(sent.response) == {"result": 42} + + @pytest.mark.asyncio + async def test_on_request_error_sends_error_response(self): + """Exception in invoke → error response sent.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def failing_invoke(user, config, arguments): + raise ValueError("bad input") + + svc.invoke = failing_invoke + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r3"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "tool-service-error" + assert "bad input" in sent.error.message + assert sent.response == "" + + @pytest.mark.asyncio + async def test_on_request_too_many_requests_propagates(self): + """TooManyRequests should propagate (not caught as error response).""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def rate_limited_invoke(user, config, arguments): + raise TooManyRequests("rate limited") + + svc.invoke = rate_limited_invoke + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r4"} + + with pytest.raises(TooManyRequests): + await svc.on_request(msg, MagicMock(), None) + + @pytest.mark.asyncio + async def test_on_request_preserves_message_id(self): + """Response should include the original message id in properties.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def ok_invoke(user, config, arguments): + return "ok" + + svc.invoke = ok_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "unique-42"} + + await svc.on_request(msg, MagicMock(), None) + + props = svc.producer.send.call_args[1]["properties"] + assert props["id"] == "unique-42" + + +# --------------------------------------------------------------------------- +# ToolService (flow-based) tests +# --------------------------------------------------------------------------- + +class TestToolServiceOnRequest: + + @pytest.mark.asyncio + async def test_string_response_sent_as_text(self): + """String return from invoke_tool → ToolResponse.text is set.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def mock_invoke(name, params): + return "tool result" + + svc.invoke_tool = mock_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_response_pub = AsyncMock() + flow = MagicMock() + flow.name = "test-flow" + + def flow_callable(name): + if name == "response": + return mock_response_pub + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') + msg.properties.return_value = {"id": "t1"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert isinstance(sent, ToolResponse) + assert sent.text == "tool result" + assert sent.object is None + + @pytest.mark.asyncio + async def test_dict_response_sent_as_json_object(self): + """Dict return from invoke_tool → ToolResponse.object is JSON.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def mock_invoke(name, params): + return {"data": [1, 2, 3]} + + svc.invoke_tool = mock_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_response_pub = AsyncMock() + flow = MagicMock() + + def flow_callable(name): + if name == "response": + return mock_response_pub + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t2"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert sent.text is None + assert json.loads(sent.object) == {"data": [1, 2, 3]} + + @pytest.mark.asyncio + async def test_error_sends_error_response(self): + """Exception in invoke_tool → error response via flow producer.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def failing_invoke(name, params): + raise RuntimeError("tool broke") + + svc.invoke_tool = failing_invoke + + mock_response_pub = AsyncMock() + flow = MagicMock() + + def flow_callable(name): + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t3"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "tool-error" + assert "tool broke" in sent.error.message + + @pytest.mark.asyncio + async def test_too_many_requests_propagates(self): + """TooManyRequests should propagate from ToolService.on_request.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def rate_limited(name, params): + raise TooManyRequests("slow down") + + svc.invoke_tool = rate_limited + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t4"} + + flow = MagicMock() + flow.producer = {"response": AsyncMock()} + flow.name = "test-flow" + + with pytest.raises(TooManyRequests): + await svc.on_request(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_parameters_json_parsed(self): + """Parameters should be JSON-parsed before passing to invoke_tool.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + received = {} + + async def capture_invoke(name, params): + received["name"] = name + received["params"] = params + return "ok" + + svc.invoke_tool = capture_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + flow.producer = {"response": mock_pub} + flow.name = "f" + + msg = MagicMock() + msg.value.return_value = ToolRequest( + name="search", + parameters='{"query": "test", "limit": 10}', + ) + msg.properties.return_value = {"id": "t5"} + + await svc.on_request(msg, MagicMock(), flow) + + assert received["name"] == "search" + assert received["params"] == {"query": "test", "limit": 10} + + +# --------------------------------------------------------------------------- +# ToolServiceClient tests +# --------------------------------------------------------------------------- + +class TestToolServiceClientCall: + + @pytest.mark.asyncio + async def test_call_sends_request_and_returns_response(self): + """call() should send ToolServiceRequest and return response string.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="joke result", end_of_stream=True, + )) + + result = await client.call( + user="alice", + config={"style": "pun"}, + arguments={"topic": "cats"}, + ) + + assert result == "joke result" + + req = client.request.call_args[0][0] + assert isinstance(req, ToolServiceRequest) + assert req.user == "alice" + assert json.loads(req.config) == {"style": "pun"} + assert json.loads(req.arguments) == {"topic": "cats"} + + @pytest.mark.asyncio + async def test_call_raises_on_error(self): + """call() should raise RuntimeError when response has error.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=Error(type="tool-service-error", message="service down"), + response="", + )) + + with pytest.raises(RuntimeError, match="service down"): + await client.call(user="u", config={}, arguments={}) + + @pytest.mark.asyncio + async def test_call_empty_config_sends_empty_json(self): + """Empty config/arguments should be sent as '{}'.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="u", config=None, arguments=None) + + req = client.request.call_args[0][0] + assert req.config == "{}" + assert req.arguments == "{}" + + @pytest.mark.asyncio + async def test_call_passes_timeout(self): + """call() should forward timeout to underlying request.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="u", config={}, arguments={}, timeout=30) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 30 + + +class TestToolServiceClientStreaming: + + @pytest.mark.asyncio + async def test_call_streaming_collects_chunks(self): + """call_streaming should accumulate chunks and return full result.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + # Simulate streaming: request() calls recipient with each chunk + chunks = [ + ToolServiceResponse(error=None, response="chunk1", end_of_stream=False), + ToolServiceResponse(error=None, response="chunk2", end_of_stream=True), + ] + + async def mock_request(req, timeout=600, recipient=None): + for chunk in chunks: + done = await recipient(chunk) + if done: + break + + client.request = mock_request + + received = [] + + async def callback(text): + received.append(text) + + result = await client.call_streaming( + user="u", config={}, arguments={}, callback=callback, + ) + + assert result == "chunk1chunk2" + assert received == ["chunk1", "chunk2"] + + @pytest.mark.asyncio + async def test_call_streaming_raises_on_error(self): + """call_streaming should raise RuntimeError on error chunk.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + async def mock_request(req, timeout=600, recipient=None): + error_resp = ToolServiceResponse( + error=Error(type="tool-service-error", message="stream failed"), + response="", + end_of_stream=True, + ) + await recipient(error_resp) + + client.request = mock_request + + with pytest.raises(RuntimeError, match="stream failed"): + await client.call_streaming( + user="u", config={}, arguments={}, + callback=AsyncMock(), + ) + + @pytest.mark.asyncio + async def test_call_streaming_skips_empty_response(self): + """Empty response chunks should not be added to result.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + chunks = [ + ToolServiceResponse(error=None, response="", end_of_stream=False), + ToolServiceResponse(error=None, response="data", end_of_stream=True), + ] + + async def mock_request(req, timeout=600, recipient=None): + for chunk in chunks: + done = await recipient(chunk) + if done: + break + + client.request = mock_request + + received = [] + + async def callback(text): + received.append(text) + + result = await client.call_streaming( + user="u", config={}, arguments={}, callback=callback, + ) + + # Empty response is falsy, so callback shouldn't be called for it + assert result == "data" + assert received == ["data"] + + +# --------------------------------------------------------------------------- +# Multi-tenancy +# --------------------------------------------------------------------------- + +class TestMultiTenancy: + + @pytest.mark.asyncio + async def test_user_propagated_to_invoke(self): + """User from request should reach the invoke method.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test" + svc.producer = AsyncMock() + + users_seen = [] + + async def tracking(user, config, arguments): + users_seen.append(user) + return "ok" + + svc.invoke = tracking + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + for u in ["tenant-a", "tenant-b", "tenant-c"]: + msg = MagicMock() + msg.value.return_value = ToolServiceRequest( + user=u, config="{}", arguments="{}", + ) + msg.properties.return_value = {"id": f"req-{u}"} + await svc.on_request(msg, MagicMock(), None) + + assert users_seen == ["tenant-a", "tenant-b", "tenant-c"] + + @pytest.mark.asyncio + async def test_client_sends_user_in_request(self): + """ToolServiceClient.call should include user in request.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="isolated-tenant", config={}, arguments={}) + + req = client.request.call_args[0][0] + assert req.user == "isolated-tenant" diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 1c91408d..705f2bd1 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -23,27 +23,27 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None mock_response.chunks = ["chunk1", "chunk2", "chunk3"] - + # Mock the request method client.request = AsyncMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = await client.query( - vectors=vectors, + vector=vector, limit=10, user="test_user", collection="test_collection", timeout=30 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.request.assert_called_once() call_args = client.request.call_args[0][0] assert isinstance(call_args, DocumentEmbeddingsRequest) - assert call_args.vectors == vectors + assert call_args.vector == vector assert call_args.limit == 10 assert call_args.user == "test_user" assert call_args.collection == "test_collection" @@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(RuntimeError, match="Database connection failed"): await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -76,12 +76,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None mock_response.chunks = [] - + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -94,11 +94,11 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None mock_response.chunks = ["test_chunk"] - + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + result = await client.query(vector=[0.1, 0.2, 0.3]) # Assert client.request.assert_called_once() @@ -116,15 +116,15 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None mock_response.chunks = ["chunk1"] - + client.request = AsyncMock(return_value=mock_response) - + # Act await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=60 ) - + # Assert assert client.request.call_args[1]["timeout"] == 60 @@ -137,13 +137,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None mock_response.chunks = ["test_chunk"] - + client.request = AsyncMock(return_value=mock_response) - + # Act with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger: - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert mock_logger.debug.assert_called_once() assert "Document embeddings response" in str(mock_logger.debug.call_args) diff --git a/tests/unit/test_chunking/conftest.py b/tests/unit/test_chunking/conftest.py index c01f73d8..31dab77d 100644 --- a/tests/unit/test_chunking/conftest.py +++ b/tests/unit/test_chunking/conftest.py @@ -28,7 +28,6 @@ def sample_text_document(): """Sample document with moderate length text.""" metadata = Metadata( id="test-doc-1", - metadata=[], user="test-user", collection="test-collection" ) @@ -44,7 +43,6 @@ def long_text_document(): """Long document for testing multiple chunks.""" metadata = Metadata( id="test-doc-long", - metadata=[], user="test-user", collection="test-collection" ) @@ -61,7 +59,6 @@ def unicode_text_document(): """Document with various unicode characters.""" metadata = Metadata( id="test-doc-unicode", - metadata=[], user="test-user", collection="test-collection" ) @@ -87,7 +84,6 @@ def empty_text_document(): """Empty document for edge case testing.""" metadata = Metadata( id="test-doc-empty", - metadata=[], user="test-user", collection="test-collection" ) diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index 8f91d95f..ae05d22c 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -17,13 +17,17 @@ class MockAsyncProcessor: self.config_handlers = [] self.id = params.get('id', 'test-service') self.specifications = [] + self.pubsub = MagicMock() + self.taskgroup = params.get('taskgroup', MagicMock()) class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): """Test Recursive chunker functionality""" + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - def test_processor_initialization_basic(self): + def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" # Arrange config = { @@ -47,8 +51,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_chunk_size_override(self): + async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" # Arrange config = { @@ -79,8 +85,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 2000 # Should use overridden value assert chunk_overlap == 100 # Should use default value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_chunk_overlap_override(self): + async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" # Arrange config = { @@ -111,8 +119,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1000 # Should use default value assert chunk_overlap == 200 # Should use overridden value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_both_parameters_override(self): + async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" # Arrange config = { @@ -143,9 +153,11 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1500 # Should use overridden value assert chunk_overlap == 150 # Should use overridden value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_uses_flow_parameters(self, mock_splitter_class): + async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): """Test that on_message method uses parameters from flow""" # Arrange mock_splitter = MagicMock() @@ -164,26 +176,31 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) + # Mock save_child_document to avoid waiting for librarian response + processor.save_child_document = AsyncMock(return_value="mock-doc-id") + # Mock message with TextDocument mock_message = MagicMock() mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-123", - metadata=[], user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content" + mock_text_doc.document_id = "" # No librarian fetch needed mock_message.value.return_value = mock_text_doc # Mock consumer and flow with parameter overrides mock_consumer = MagicMock() mock_producer = AsyncMock() + mock_triples_producer = AsyncMock() mock_flow = MagicMock() mock_flow.side_effect = lambda param: { "chunk-size": 1500, "chunk-overlap": 150, - "output": mock_producer + "output": mock_producer, + "triples": mock_triples_producer, }.get(param) # Act @@ -202,8 +219,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_no_overrides(self): + async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" # Arrange config = { diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index 600df930..2ed37391 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -17,13 +17,17 @@ class MockAsyncProcessor: self.config_handlers = [] self.id = params.get('id', 'test-service') self.specifications = [] + self.pubsub = MagicMock() + self.taskgroup = params.get('taskgroup', MagicMock()) class TestTokenChunkerSimple(IsolatedAsyncioTestCase): """Test Token chunker functionality""" + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - def test_processor_initialization_basic(self): + def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" # Arrange config = { @@ -47,8 +51,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_chunk_size_override(self): + async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" # Arrange config = { @@ -79,8 +85,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 400 # Should use overridden value assert chunk_overlap == 15 # Should use default value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_chunk_overlap_override(self): + async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" # Arrange config = { @@ -111,8 +119,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 25 # Should use overridden value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_both_parameters_override(self): + async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" # Arrange config = { @@ -143,9 +153,11 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 350 # Should use overridden value assert chunk_overlap == 30 # Should use overridden value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.chunking.token.chunker.TokenTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_uses_flow_parameters(self, mock_splitter_class): + async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): """Test that on_message method uses parameters from flow""" # Arrange mock_splitter = MagicMock() @@ -164,26 +176,31 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) + # Mock save_child_document to avoid librarian producer interactions + processor.save_child_document = AsyncMock(return_value="chunk-id") + # Mock message with TextDocument mock_message = MagicMock() mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-456", - metadata=[], user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content for token chunking" + mock_text_doc.document_id = "" # No librarian fetch needed mock_message.value.return_value = mock_text_doc # Mock consumer and flow with parameter overrides mock_consumer = MagicMock() mock_producer = AsyncMock() + mock_triples_producer = AsyncMock() mock_flow = MagicMock() mock_flow.side_effect = lambda param: { "chunk-size": 400, "chunk-overlap": 40, - "output": mock_producer + "output": mock_producer, + "triples": mock_triples_producer, }.get(param) # Act @@ -206,8 +223,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_chunk_document_with_no_overrides(self): + async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" # Arrange config = { @@ -235,8 +254,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 15 # Should use default value + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - def test_token_chunker_uses_different_defaults(self): + def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer): """Test that token chunker has different defaults than recursive chunker""" # Arrange & Act config = { diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index 5873d81c..ce758f66 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1", "chunk2", "chunk3"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = client.request( - vectors=vectors, + vector=vector, user="test_user", collection="test_collection", limit=10, timeout=300 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.call.assert_called_once_with( user="test_user", collection="test_collection", - vectors=vectors, + vector=vector, limit=10, timeout=300 ) @@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["test_chunk"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3]] - + + vector = [0.1, 0.2, 0.3] + # Act - result = client.request(vectors=vectors) - + result = client.request(vector=vector) + # Assert assert result == ["test_chunk"] client.call.assert_called_once_with( user="trustgraph", collection="default", - vectors=vectors, + vector=vector, limit=10, timeout=300 ) @@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = [] client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = None client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result is None @@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1"] client.call = MagicMock(return_value=mock_response) - + # Act client.request( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=600 ) - + # Assert assert client.call.call_args[1]["timeout"] == 600 \ No newline at end of file diff --git a/tests/unit/test_concurrency/__init__.py b/tests/unit/test_concurrency/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_concurrency/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py new file mode 100644 index 00000000..32a6559b --- /dev/null +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -0,0 +1,286 @@ +""" +Tests for Consumer concurrency: TaskGroup-based concurrent message processing, +rate-limit retry with backpressure, and message acknowledgement. +""" + +import asyncio +import time + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.base.consumer import Consumer +from trustgraph.exceptions import TooManyRequests + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_consumer( + concurrency=1, + handler=None, + rate_limit_retry_time=0.01, + rate_limit_timeout=1, +): + """Create a Consumer with mocked infrastructure.""" + taskgroup = MagicMock() + flow = MagicMock() + backend = MagicMock() + schema = MagicMock() + handler = handler or AsyncMock() + + consumer = Consumer( + taskgroup=taskgroup, + flow=flow, + backend=backend, + topic="test-topic", + subscriber="test-sub", + schema=schema, + handler=handler, + rate_limit_retry_time=rate_limit_retry_time, + rate_limit_timeout=rate_limit_timeout, + concurrency=concurrency, + ) + + return consumer + + +def _make_msg(): + """Create a mock Pulsar message.""" + return MagicMock() + + +# --------------------------------------------------------------------------- +# Concurrency configuration tests +# --------------------------------------------------------------------------- + +class TestConcurrencyConfiguration: + + def test_default_concurrency_is_1(self): + consumer = _make_consumer() + assert consumer.concurrency == 1 + + def test_custom_concurrency(self): + consumer = _make_consumer(concurrency=10) + assert consumer.concurrency == 10 + + def test_concurrency_stored(self): + for n in [1, 5, 20, 100]: + consumer = _make_consumer(concurrency=n) + assert consumer.concurrency == n + + +class TestTaskGroupConcurrency: + + @pytest.mark.asyncio + async def test_creates_n_concurrent_tasks(self): + """consumer_run should create exactly N concurrent consume_from_queue tasks.""" + concurrency = 5 + consumer = _make_consumer(concurrency=concurrency) + + # Track how many consume_from_queue calls are made + call_count = 0 + original_running = True + + async def mock_consume(): + nonlocal call_count + call_count += 1 + # Wait a bit to let all tasks start, then signal stop + await asyncio.sleep(0.05) + consumer.running = False + + consumer.consume_from_queue = mock_consume + + # Mock the backend.create_consumer + consumer.backend.create_consumer = MagicMock(return_value=MagicMock()) + + # Run consumer_run - it will create TaskGroup with N tasks + consumer.running = True + await consumer.consumer_run() + + assert call_count == concurrency + + @pytest.mark.asyncio + async def test_single_concurrency_creates_one_task(self): + """With concurrency=1, only one consume_from_queue task is created.""" + consumer = _make_consumer(concurrency=1) + call_count = 0 + + async def mock_consume(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + consumer.running = False + + consumer.consume_from_queue = mock_consume + consumer.backend.create_consumer = MagicMock(return_value=MagicMock()) + + consumer.running = True + await consumer.consumer_run() + + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# Rate-limit retry tests +# --------------------------------------------------------------------------- + +class TestRateLimitRetry: + + @pytest.mark.asyncio + async def test_rate_limit_retries_then_succeeds(self): + """TooManyRequests should cause retry, then succeed on next attempt.""" + call_count = 0 + + async def handler_with_retry(msg, consumer_ref, flow): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise TooManyRequests("rate limited") + # Second call succeeds + + consumer = _make_consumer( + handler=handler_with_retry, + rate_limit_retry_time=0.01, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + assert call_count == 2 + consumer.consumer.acknowledge.assert_called_once_with(mock_msg) + + @pytest.mark.asyncio + async def test_rate_limit_timeout_negative_acks(self): + """If rate limit retries exhaust the timeout, message is negative-acked.""" + async def always_rate_limited(msg, consumer_ref, flow): + raise TooManyRequests("rate limited") + + consumer = _make_consumer( + handler=always_rate_limited, + rate_limit_retry_time=0.01, + rate_limit_timeout=0.05, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + consumer.consumer.negative_acknowledge.assert_called_with(mock_msg) + consumer.consumer.acknowledge.assert_not_called() + + @pytest.mark.asyncio + async def test_non_rate_limit_error_negative_acks_immediately(self): + """Non-TooManyRequests errors should negative-ack immediately (no retry).""" + call_count = 0 + + async def failing_handler(msg, consumer_ref, flow): + nonlocal call_count + call_count += 1 + raise ValueError("bad data") + + consumer = _make_consumer(handler=failing_handler) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + assert call_count == 1 + consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg) + + @pytest.mark.asyncio + async def test_successful_message_acknowledged(self): + """Successfully processed messages are acknowledged.""" + consumer = _make_consumer(handler=AsyncMock()) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + consumer.consumer.acknowledge.assert_called_once_with(mock_msg) + + +# --------------------------------------------------------------------------- +# Metrics integration +# --------------------------------------------------------------------------- + +class TestMetricsIntegration: + + @pytest.mark.asyncio + async def test_success_metric_on_success(self): + consumer = _make_consumer(handler=AsyncMock()) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + mock_metrics.record_time.return_value.__enter__ = MagicMock() + mock_metrics.record_time.return_value.__exit__ = MagicMock() + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.process.assert_called_once_with("success") + + @pytest.mark.asyncio + async def test_error_metric_on_failure(self): + async def failing(msg, c, f): + raise ValueError("fail") + + consumer = _make_consumer(handler=failing) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.process.assert_called_once_with("error") + + @pytest.mark.asyncio + async def test_rate_limit_metric_on_too_many_requests(self): + call_count = 0 + + async def handler(msg, c, f): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise TooManyRequests("limited") + + consumer = _make_consumer( + handler=handler, + rate_limit_retry_time=0.01, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + mock_metrics.record_time.return_value.__enter__ = MagicMock() + mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False) + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.rate_limit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Stop / running flag +# --------------------------------------------------------------------------- + +class TestStopBehaviour: + + @pytest.mark.asyncio + async def test_stop_sets_running_false(self): + consumer = _make_consumer() + consumer.running = True + + await consumer.stop() + + assert consumer.running is False + + def test_initial_running_state(self): + consumer = _make_consumer() + assert consumer.running is True diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py new file mode 100644 index 00000000..6a1ae8ab --- /dev/null +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -0,0 +1,136 @@ +""" +Tests for MessageDispatcher semaphore-based concurrency enforcement. + +Verifies that the dispatcher limits concurrent message processing to +max_workers via asyncio.Semaphore. +""" + +import asyncio + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.rev_gateway.dispatcher import MessageDispatcher + + +class TestSemaphoreEnforcement: + + @pytest.mark.asyncio + async def test_semaphore_limits_concurrent_processing(self): + """Only max_workers messages should be processed concurrently.""" + max_workers = 2 + dispatcher = MessageDispatcher(max_workers=max_workers) + + concurrent_count = 0 + max_concurrent = 0 + processing_event = asyncio.Event() + + async def slow_process(message): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + await asyncio.sleep(0.05) + concurrent_count -= 1 + return {"id": message.get("id"), "response": {"ok": True}} + + dispatcher._process_message = slow_process + + # Launch more tasks than max_workers + messages = [ + {"id": f"msg-{i}", "service": "test", "request": {}} + for i in range(5) + ] + + tasks = [ + asyncio.create_task(dispatcher.handle_message(m)) + for m in messages + ] + + await asyncio.gather(*tasks) + + # At no point should more than max_workers have been active + assert max_concurrent <= max_workers + + @pytest.mark.asyncio + async def test_semaphore_value_matches_max_workers(self): + for n in [1, 5, 20]: + dispatcher = MessageDispatcher(max_workers=n) + assert dispatcher.semaphore._value == n + + @pytest.mark.asyncio + async def test_active_tasks_tracked(self): + """Active tasks should be added/removed during processing.""" + dispatcher = MessageDispatcher(max_workers=5) + + task_was_tracked = False + + original_process = dispatcher._process_message + + async def tracking_process(message): + nonlocal task_was_tracked + # During processing, our task should be in active_tasks + if len(dispatcher.active_tasks) > 0: + task_was_tracked = True + return {"id": message.get("id"), "response": {"ok": True}} + + dispatcher._process_message = tracking_process + + await dispatcher.handle_message( + {"id": "test", "service": "test", "request": {}} + ) + + assert task_was_tracked + # After completion, task should be discarded + assert len(dispatcher.active_tasks) == 0 + + @pytest.mark.asyncio + async def test_semaphore_released_on_error(self): + """Semaphore should be released even if processing raises.""" + dispatcher = MessageDispatcher(max_workers=2) + + async def failing_process(message): + raise RuntimeError("process failed") + + dispatcher._process_message = failing_process + + # Should not deadlock — semaphore must be released on error + with pytest.raises(RuntimeError): + await dispatcher.handle_message( + {"id": "test", "service": "test", "request": {}} + ) + + # Semaphore should be back at max + assert dispatcher.semaphore._value == 2 + + @pytest.mark.asyncio + async def test_single_worker_serializes_processing(self): + """With max_workers=1, messages are processed one at a time.""" + dispatcher = MessageDispatcher(max_workers=1) + + order = [] + + async def ordered_process(message): + msg_id = message["id"] + order.append(f"start-{msg_id}") + await asyncio.sleep(0.02) + order.append(f"end-{msg_id}") + return {"id": msg_id, "response": {"ok": True}} + + dispatcher._process_message = ordered_process + + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] + tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + await asyncio.gather(*tasks) + + # With semaphore=1, each message should complete before next starts + # Check that no two "start" entries appear without an intervening "end" + active = 0 + max_active = 0 + for event in order: + if event.startswith("start"): + active += 1 + max_active = max(max_active, active) + elif event.startswith("end"): + active -= 1 + + assert max_active == 1 diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py new file mode 100644 index 00000000..8287427b --- /dev/null +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -0,0 +1,268 @@ +""" +Tests for Graph RAG concurrent query execution. + +Covers: execute_batch_triple_queries concurrent task spawning, +exception handling in gather, and result aggregation. +""" + +import asyncio + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_query( + triples_client=None, + entity_limit=50, + triple_limit=30, + max_subgraph_size=1000, + max_path_length=2, +): + """Create a Query object with mocked rag dependencies.""" + rag = MagicMock() + rag.triples_client = triples_client or AsyncMock() + rag.label_cache = LRUCacheWithTTL() + + query = Query( + rag=rag, + user="test-user", + collection="test-collection", + verbose=False, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, + ) + return query + + +def _make_triple(s, p, o): + """Create a simple mock triple.""" + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestBatchTripleQueries: + + @pytest.mark.asyncio + async def test_three_queries_per_entity(self): + """Each entity should generate 3 concurrent queries (s, p, o positions).""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + entities = ["entity-1"] + await query.execute_batch_triple_queries(entities, limit_per_entity=10) + + assert client.query_stream.call_count == 3 + + @pytest.mark.asyncio + async def test_multiple_entities_multiply_queries(self): + """N entities should produce N*3 concurrent queries.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + entities = ["e1", "e2", "e3"] + await query.execute_batch_triple_queries(entities, limit_per_entity=10) + + assert client.query_stream.call_count == 9 # 3 * 3 + + @pytest.mark.asyncio + async def test_queries_executed_concurrently(self): + """All queries should run concurrently via asyncio.gather.""" + concurrent_count = 0 + max_concurrent = 0 + + async def tracking_query(**kwargs): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + await asyncio.sleep(0.02) + concurrent_count -= 1 + return [] + + client = AsyncMock() + client.query_stream = tracking_query + query = _make_query(triples_client=client) + + entities = ["e1", "e2", "e3"] + await query.execute_batch_triple_queries(entities, limit_per_entity=5) + + # All 9 queries should have run concurrently + assert max_concurrent == 9 + + @pytest.mark.asyncio + async def test_results_aggregated(self): + """Results from all queries should be combined into a single list.""" + triple_a = _make_triple("a", "p", "b") + triple_b = _make_triple("c", "p", "d") + + call_count = 0 + + async def alternating_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count % 2 == 0: + return [triple_a] + return [triple_b] + + client = AsyncMock() + client.query_stream = alternating_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries, alternating results + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_exception_in_one_query_does_not_block_others(self): + """If one query raises, other results are still collected.""" + good_triple = _make_triple("a", "p", "b") + call_count = 0 + + async def mixed_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("query failed") + return [good_triple] + + client = AsyncMock() + client.query_stream = mixed_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries: 2 succeed, 1 fails → 2 triples + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_none_results_filtered(self): + """None results from queries should be filtered out.""" + call_count = 0 + + async def sometimes_none(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return None + return [_make_triple("a", "p", "b")] + + client = AsyncMock() + client.query_stream = sometimes_none + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries: 1 returns None, 2 return triples + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_empty_entities_no_queries(self): + """Empty entity list should produce no queries.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries([], limit_per_entity=10) + + assert result == [] + client.query_stream.assert_not_called() + + @pytest.mark.asyncio + async def test_query_params_correct(self): + """Each query should use correct s/p/o positions and params.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + entities = ["ent-1"] + await query.execute_batch_triple_queries(entities, limit_per_entity=15) + + calls = client.query_stream.call_args_list + assert len(calls) == 3 + + # First call: s=entity, p=None, o=None + assert calls[0].kwargs["s"] == "ent-1" + assert calls[0].kwargs["p"] is None + assert calls[0].kwargs["o"] is None + assert calls[0].kwargs["limit"] == 15 + assert calls[0].kwargs["user"] == "test-user" + assert calls[0].kwargs["collection"] == "test-collection" + assert calls[0].kwargs["batch_size"] == 20 + + # Second call: s=None, p=entity, o=None + assert calls[1].kwargs["s"] is None + assert calls[1].kwargs["p"] == "ent-1" + assert calls[1].kwargs["o"] is None + + # Third call: s=None, p=None, o=entity + assert calls[2].kwargs["s"] is None + assert calls[2].kwargs["p"] is None + assert calls[2].kwargs["o"] == "ent-1" + + +class TestLRUCacheWithTTL: + + def test_put_and_get(self): + cache = LRUCacheWithTTL(max_size=10, ttl=60) + cache.put("key1", "value1") + assert cache.get("key1") == "value1" + + def test_get_missing_returns_none(self): + cache = LRUCacheWithTTL() + assert cache.get("nonexistent") is None + + def test_max_size_eviction(self): + cache = LRUCacheWithTTL(max_size=2, ttl=60) + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) # Should evict "a" + assert cache.get("a") is None + assert cache.get("b") == 2 + assert cache.get("c") == 3 + + def test_lru_order(self): + cache = LRUCacheWithTTL(max_size=2, ttl=60) + cache.put("a", 1) + cache.put("b", 2) + cache.get("a") # Access "a" — now "b" is LRU + cache.put("c", 3) # Should evict "b" + assert cache.get("a") == 1 + assert cache.get("b") is None + assert cache.get("c") == 3 + + def test_ttl_expiration(self): + cache = LRUCacheWithTTL(max_size=10, ttl=0) # TTL=0 means instant expiry + cache.put("key", "value") + # With TTL=0, any time check > 0 means expired + import time + time.sleep(0.01) + assert cache.get("key") is None + + def test_update_existing_key(self): + cache = LRUCacheWithTTL(max_size=10, ttl=60) + cache.put("key", "v1") + cache.put("key", "v2") + assert cache.get("key") == "v2" diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 96c9c427..76095690 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -73,7 +73,6 @@ def sample_triples(): id="test-doc-id", user="test-user", collection="default", # This should be overridden - metadata=[] ), triples=[ Triple( @@ -93,12 +92,11 @@ def sample_graph_embeddings(): id="test-doc-id", user="test-user", collection="default", # This should be overridden - metadata=[] ), entities=[ EntityEmbeddings( entity=Term(type=IRI, iri="http://example.org/john"), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) ] ) diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py index b40accdf..a3ca3514 100644 --- a/tests/unit/test_decoding/test_pdf_decoder.py +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -12,218 +12,184 @@ from trustgraph.decoding.pdf.pdf_decoder import Processor from trustgraph.schema import Document, TextDocument, Metadata +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] + self.pubsub = MagicMock() + self.taskgroup = params.get('taskgroup', MagicMock()) + + class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): """Test PDF decoder processor functionality""" - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_processor_initialization(self, mock_flow_init): + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): """Test PDF decoder processor initialization""" - # Arrange - mock_flow_init.return_value = None - config = { 'id': 'test-pdf-decoder', 'taskgroup': AsyncMock() } - # Act - with patch.object(Processor, 'register_specification') as mock_register: - processor = Processor(**config) + processor = Processor(**config) - # Assert - mock_flow_init.assert_called_once() - # Verify register_specification was called twice (consumer and producer) - assert mock_register.call_count == 2 - # Check consumer spec - consumer_call = mock_register.call_args_list[0] - consumer_spec = consumer_call[0][0] - assert consumer_spec.name == "input" - assert consumer_spec.schema == Document - assert consumer_spec.handler == processor.on_message - - # Check producer spec - producer_call = mock_register.call_args_list[1] - producer_spec = producer_call[0][0] - assert producer_spec.name == "output" - assert producer_spec.schema == TextDocument + consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')] + assert len(consumer_specs) >= 1 + assert consumer_specs[0].name == "input" + assert consumer_specs[0].schema == Document + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_on_message_success(self, mock_flow_init, mock_pdf_loader_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): """Test successful PDF processing""" - # Arrange - mock_flow_init.return_value = None - # Mock PDF content pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') - + # Mock PyPDFLoader mock_loader = MagicMock() mock_page1 = MagicMock(page_content="Page 1 content") mock_page2 = MagicMock(page_content="Page 2 content") mock_loader.load.return_value = [mock_page1, mock_page2] mock_pdf_loader_class.return_value = mock_loader - + # Mock message mock_metadata = Metadata(id="test-doc") mock_document = Document(metadata=mock_metadata, data=pdf_base64) mock_msg = MagicMock() mock_msg.value.return_value = mock_document - - # Mock flow - needs to be a callable that returns an object with send method + + # Mock flow - separate mocks for output and triples mock_output_flow = AsyncMock() - mock_flow = MagicMock(return_value=mock_output_flow) - + mock_triples_flow = AsyncMock() + mock_flow = MagicMock(side_effect=lambda name: { + "output": mock_output_flow, + "triples": mock_triples_flow, + }.get(name)) + config = { 'id': 'test-pdf-decoder', 'taskgroup': AsyncMock() } - with patch.object(Processor, 'register_specification'): - processor = Processor(**config) + processor = Processor(**config) + + # Mock save_child_document to avoid waiting for librarian response + processor.save_child_document = AsyncMock(return_value="mock-doc-id") - # Act await processor.on_message(mock_msg, None, mock_flow) - # Assert - # Verify PyPDFLoader was called - mock_pdf_loader_class.assert_called_once() - mock_loader.load.assert_called_once() - # Verify output was sent for each page assert mock_output_flow.send.call_count == 2 - - # Check first page output - first_call = mock_output_flow.send.call_args_list[0] - first_output = first_call[0][0] - assert isinstance(first_output, TextDocument) - assert first_output.metadata == mock_metadata - assert first_output.text == b"Page 1 content" - - # Check second page output - second_call = mock_output_flow.send.call_args_list[1] - second_output = second_call[0][0] - assert isinstance(second_output, TextDocument) - assert second_output.metadata == mock_metadata - assert second_output.text == b"Page 2 content" + # Verify triples were sent for each page (provenance) + assert mock_triples_flow.send.call_count == 2 + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_on_message_empty_pdf(self, mock_flow_init, mock_pdf_loader_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): """Test handling of empty PDF""" - # Arrange - mock_flow_init.return_value = None - - # Mock PDF content pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') - - # Mock PyPDFLoader with no pages + mock_loader = MagicMock() mock_loader.load.return_value = [] mock_pdf_loader_class.return_value = mock_loader - - # Mock message + mock_metadata = Metadata(id="test-doc") mock_document = Document(metadata=mock_metadata, data=pdf_base64) mock_msg = MagicMock() mock_msg.value.return_value = mock_document - - # Mock flow - needs to be a callable that returns an object with send method + mock_output_flow = AsyncMock() mock_flow = MagicMock(return_value=mock_output_flow) - + config = { 'id': 'test-pdf-decoder', 'taskgroup': AsyncMock() } - with patch.object(Processor, 'register_specification'): - processor = Processor(**config) + processor = Processor(**config) - # Act await processor.on_message(mock_msg, None, mock_flow) - # Assert - # Verify PyPDFLoader was called - mock_pdf_loader_class.assert_called_once() - mock_loader.load.assert_called_once() - - # Verify no output was sent mock_output_flow.send.assert_not_called() + @patch('trustgraph.base.chunking_service.Consumer') + @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') + @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_on_message_unicode_content(self, mock_flow_init, mock_pdf_loader_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): """Test handling of unicode content in PDF""" - # Arrange - mock_flow_init.return_value = None - - # Mock PDF content pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') - - # Mock PyPDFLoader with unicode content + mock_loader = MagicMock() mock_page = MagicMock(page_content="Page with unicode: 你好世界 🌍") mock_loader.load.return_value = [mock_page] mock_pdf_loader_class.return_value = mock_loader - - # Mock message + mock_metadata = Metadata(id="test-doc") mock_document = Document(metadata=mock_metadata, data=pdf_base64) mock_msg = MagicMock() mock_msg.value.return_value = mock_document - - # Mock flow - needs to be a callable that returns an object with send method + + # Mock flow - separate mocks for output and triples mock_output_flow = AsyncMock() - mock_flow = MagicMock(return_value=mock_output_flow) - + mock_triples_flow = AsyncMock() + mock_flow = MagicMock(side_effect=lambda name: { + "output": mock_output_flow, + "triples": mock_triples_flow, + }.get(name)) + config = { 'id': 'test-pdf-decoder', 'taskgroup': AsyncMock() } - with patch.object(Processor, 'register_specification'): - processor = Processor(**config) + processor = Processor(**config) + + # Mock save_child_document to avoid waiting for librarian response + processor.save_child_document = AsyncMock(return_value="mock-doc-id") - # Act await processor.on_message(mock_msg, None, mock_flow) - # Assert - # Verify output was sent mock_output_flow.send.assert_called_once() - - # Check output call_args = mock_output_flow.send.call_args[0][0] - assert isinstance(call_args, TextDocument) - assert call_args.text == "Page with unicode: 你好世界 🌍".encode('utf-8') + # PDF decoder now forwards document_id, chunker fetches content from librarian + assert call_args.document_id == "test-doc/p1" + assert call_args.text == b"" # Content stored in librarian, not inline @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') def test_add_args(self, mock_parent_add_args): """Test add_args calls parent method""" - # Arrange mock_parser = MagicMock() - - # Act Processor.add_args(mock_parser) - - # Assert mock_parent_add_args.assert_called_once_with(mock_parser) @patch('trustgraph.decoding.pdf.pdf_decoder.Processor.launch') def test_run(self, mock_launch): """Test run function""" - # Act from trustgraph.decoding.pdf.pdf_decoder import run run() - - # Assert - mock_launch.assert_called_once_with("pdf-decoder", - "\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n") + mock_launch.assert_called_once_with("pdf-decoder", + "\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n\nSupports both inline document data and fetching from librarian via Pulsar\nfor large documents.\n") if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/unit/test_direct/test_entity_centric_kg.py b/tests/unit/test_direct/test_entity_centric_kg.py index 5f64b581..4a74b35b 100644 --- a/tests/unit/test_direct/test_entity_centric_kg.py +++ b/tests/unit/test_direct/test_entity_centric_kg.py @@ -305,9 +305,8 @@ class TestEntityCentricKnowledgeGraph: mock_session.execute.assert_called() - def test_graph_wildcard_returns_all_graphs(self, entity_kg): - """Test that g='*' returns quads from all graphs""" - from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD + def test_graph_none_returns_all_graphs(self, entity_kg): + """Test that g=None returns quads from all graphs""" kg, mock_session = entity_kg mock_result = [ @@ -320,7 +319,7 @@ class TestEntityCentricKnowledgeGraph: ] mock_session.execute.return_value = mock_result - results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD) + results = kg.get_s('test_collection', 'http://example.org/Alice', g=None) # Should return quads from both graphs assert len(results) == 2 @@ -547,21 +546,21 @@ class TestServiceHelperFunctions: """Test cases for helper functions in service.py""" def test_create_term_with_uri_otype(self): - """Test create_term creates IRI Term for otype='u'""" + """Test create_term creates IRI Term for term_type='u'""" from trustgraph.query.triples.cassandra.service import create_term from trustgraph.schema import IRI - term = create_term('http://example.org/Alice', otype='u') + term = create_term('http://example.org/Alice', term_type='u') assert term.type == IRI assert term.iri == 'http://example.org/Alice' def test_create_term_with_literal_otype(self): - """Test create_term creates LITERAL Term for otype='l'""" + """Test create_term creates LITERAL Term for term_type='l'""" from trustgraph.query.triples.cassandra.service import create_term from trustgraph.schema import LITERAL - term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en') + term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en') assert term.type == LITERAL assert term.value == 'Alice Smith' @@ -569,14 +568,24 @@ class TestServiceHelperFunctions: assert term.language == 'en' def test_create_term_with_triple_otype(self): - """Test create_term creates IRI Term for otype='t'""" + """Test create_term creates TRIPLE Term for term_type='t' with valid JSON""" from trustgraph.query.triples.cassandra.service import create_term - from trustgraph.schema import IRI + from trustgraph.schema import TRIPLE, IRI + import json - term = create_term('http://example.org/statement1', otype='t') + # Valid JSON triple data + triple_json = json.dumps({ + "s": {"type": "i", "iri": "http://example.org/Alice"}, + "p": {"type": "i", "iri": "http://example.org/knows"}, + "o": {"type": "i", "iri": "http://example.org/Bob"}, + }) - assert term.type == IRI - assert term.iri == 'http://example.org/statement1' + term = create_term(triple_json, term_type='t') + + assert term.type == TRIPLE + assert term.triple is not None + assert term.triple.s.type == IRI + assert term.triple.s.iri == "http://example.org/Alice" def test_create_term_heuristic_fallback_uri(self): """Test create_term uses URL heuristic when otype not provided""" diff --git a/tests/unit/test_direct/test_entity_centric_write_amplification.py b/tests/unit/test_direct/test_entity_centric_write_amplification.py new file mode 100644 index 00000000..1c9ad1a8 --- /dev/null +++ b/tests/unit/test_direct/test_entity_centric_write_amplification.py @@ -0,0 +1,441 @@ +""" +Tests for entity-centric KG write amplification, delete collection batching, +in-partition filtering, and term type metadata round-trips. + +Complements test_entity_centric_kg.py with deeper verification of the +2-table schema mechanics. +""" + +import pytest +from unittest.mock import MagicMock, patch, call + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_cassandra(): + """Provide mocked Cassandra cluster, session, and BatchStatement.""" + with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \ + patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls: + + mock_cluster = MagicMock() + mock_session = MagicMock() + mock_cluster.connect.return_value = mock_session + mock_cls.return_value = mock_cluster + + # Track batch.add calls per batch instance + batches = [] + + def make_batch(): + batch = MagicMock() + batch._adds = [] + original_add = batch.add + + def tracking_add(stmt, params): + batch._adds.append((stmt, params)) + + batch.add = tracking_add + batches.append(batch) + return batch + + mock_batch_cls.side_effect = make_batch + + yield { + "cluster_cls": mock_cls, + "cluster": mock_cluster, + "session": mock_session, + "batch_cls": mock_batch_cls, + "batches": batches, + } + + +@pytest.fixture +def entity_kg(mock_cassandra): + """Create an EntityCentricKnowledgeGraph with mocked Cassandra.""" + from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph + kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks') + return kg, mock_cassandra + + +# --------------------------------------------------------------------------- +# Write amplification: row count verification +# --------------------------------------------------------------------------- + +class TestWriteAmplification: + + def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg): + """URI object → S + P + O + G-if-non-default entity rows + 1 collection row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/Alice', + p='http://ex.org/knows', + o='http://ex.org/Bob', + g='http://ex.org/g1', + otype='u', + ) + + # Should be exactly one batch + assert len(ctx["batches"]) == 1 + batch = ctx["batches"][0] + + # 4 entity rows (S, P, O, G) + 1 collection row = 5 + assert len(batch._adds) == 5 + + # Check roles assigned + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'S' in roles + assert 'P' in roles + assert 'O' in roles + assert 'G' in roles + + def test_literal_object_produces_3_entity_rows(self, entity_kg): + """Literal object → S + P entity rows (no O row) + collection row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/Alice', + p='http://ex.org/name', + o='Alice Smith', + g=None, # default graph + otype='l', + ) + + batch = ctx["batches"][0] + + # S + P entity rows + 1 collection = 3 (no O row for literal, no G for default) + assert len(batch._adds) == 3 + + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'S' in roles + assert 'P' in roles + assert 'O' not in roles + assert 'G' not in roles + + def test_triple_otype_gets_object_entity_row(self, entity_kg): + """otype='t' (quoted triple) → object gets entity row like URI.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='{"s":{},"p":{},"o":{}}', + g=None, + otype='t', + ) + + batch = ctx["batches"][0] + + # S + P + O entity rows + collection = 4 (no G for default graph) + assert len(batch._adds) == 4 + + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'O' in roles + + def test_default_graph_no_g_row(self, entity_kg): + """Default graph (g=None) → no G entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + g=None, + otype='u', + ) + + batch = ctx["batches"][0] + + # S + P + O entity rows + collection = 4 (no G) + assert len(batch._adds) == 4 + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'G' not in roles + + def test_non_default_graph_gets_g_row(self, entity_kg): + """Non-default graph → gets G entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + g='http://ex.org/graph1', + otype='u', + ) + + batch = ctx["batches"][0] + + # S + P + O + G entity rows + collection = 5 + assert len(batch._adds) == 5 + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'G' in roles + + def test_dtype_and_lang_passed_to_all_rows(self, entity_kg): + """dtype and lang should be stored in every entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/label', + o='thing', + g=None, + otype='l', + dtype='xsd:string', + lang='en', + ) + + batch = ctx["batches"][0] + + # Check entity rows carry dtype and lang + for _, params in batch._adds: + if len(params) == 10: + # Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang) + assert params[8] == 'xsd:string' + assert params[9] == 'en' + + +# --------------------------------------------------------------------------- +# In-partition filtering: get_os, get_spo +# --------------------------------------------------------------------------- + +class TestInPartitionFiltering: + + def test_get_os_filters_by_object(self, entity_kg): + """get_os should filter results by matching object value.""" + kg, ctx = entity_kg + + # Simulate rows returned from subject partition (all have same s) + mock_rows = [ + MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob', + d='', otype='u', dtype='', lang='', + s='http://ex.org/Alice'), + MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie', + d='', otype='u', dtype='', lang='', + s='http://ex.org/Alice'), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice') + + # Only the Bob row should pass the filter + assert len(results) == 1 + assert results[0].o == 'http://ex.org/Bob' + assert results[0].p == 'http://ex.org/knows' + + def test_get_os_returns_empty_when_no_match(self, entity_kg): + """get_os should return empty list when object doesn't match any row.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob', + d='', otype='u', dtype='', lang='', + s='http://ex.org/Alice'), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice') + + assert len(results) == 0 + + def test_get_spo_filters_by_object(self, entity_kg): + """get_spo should filter results by matching object value.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''), + MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_spo( + 'col', 'http://ex.org/Alice', 'http://ex.org/knows', + 'http://ex.org/Bob', + ) + + assert len(results) == 1 + assert results[0].o == 'http://ex.org/Bob' + + def test_get_os_with_graph_filter(self, entity_kg): + """get_os with specific graph should filter both object and graph.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob', + d='http://ex.org/g1', otype='u', dtype='', lang='', + s='http://ex.org/Alice'), + MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob', + d='http://ex.org/g2', otype='u', dtype='', lang='', + s='http://ex.org/Alice'), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_os( + 'col', 'http://ex.org/Bob', 'http://ex.org/Alice', + g='http://ex.org/g1', + ) + + assert len(results) == 1 + assert results[0].g == 'http://ex.org/g1' + + +# --------------------------------------------------------------------------- +# Delete collection batching +# --------------------------------------------------------------------------- + +class TestDeleteCollectionBatching: + + def test_extracts_unique_entities_from_quads(self, entity_kg): + """delete_collection should extract s, p, and URI o as entities.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows', + o='http://ex.org/B', otype='u', dtype='', lang=''), + MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name', + o='Alice', otype='l', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + ctx["batches"].clear() + + kg.delete_collection('col') + + # Unique entities: A, knows, B, name (literal 'Alice' excluded) + # The batches should include entity partition deletes + all_adds = [] + for batch in ctx["batches"]: + all_adds.extend(batch._adds) + + # We expect entity deletes + collection row deletes + metadata delete + # Just verify the function completes and calls execute + assert ctx["session"].execute.called + + def test_literal_objects_not_treated_as_entities(self, entity_kg): + """Literal objects (otype='l') should not get entity partition deletes.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name', + o='Alice', otype='l', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + ctx["batches"].clear() + + kg.delete_collection('col') + + # Entity partition deletes should only include A and name, not Alice + entity_deletes = [] + for batch in ctx["batches"]: + for _, params in batch._adds: + if len(params) == 2: # delete_entity_partition takes (collection, entity) + entity_deletes.append(params[1]) + + assert 'http://ex.org/A' in entity_deletes + assert 'http://ex.org/name' in entity_deletes + assert 'Alice' not in entity_deletes + + def test_non_default_graph_treated_as_entity(self, entity_kg): + """Non-default graphs should get entity partition deletes.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(d='http://ex.org/g1', s='http://ex.org/A', + p='http://ex.org/p', o='http://ex.org/B', + otype='u', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + ctx["batches"].clear() + + kg.delete_collection('col') + + entity_deletes = [] + for batch in ctx["batches"]: + for _, params in batch._adds: + if len(params) == 2: + entity_deletes.append(params[1]) + + assert 'http://ex.org/g1' in entity_deletes + + def test_empty_collection_delete_completes(self, entity_kg): + """Deleting an empty collection should not error.""" + kg, ctx = entity_kg + + ctx["session"].execute.return_value = [] + ctx["batches"].clear() + + # Should not raise + kg.delete_collection('empty-col') + + +# --------------------------------------------------------------------------- +# Term type metadata round-trip +# --------------------------------------------------------------------------- + +class TestTermTypeMetadata: + + def test_query_results_include_otype(self, entity_kg): + """Query results should include otype from Cassandra rows.""" + kg, ctx = entity_kg + from trustgraph.direct.cassandra_kg import QuadResult + + mock_rows = [ + MagicMock(p='http://ex.org/name', o='Alice', + d='', otype='l', dtype='xsd:string', lang='en', + s='http://ex.org/Alice'), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_s('col', 'http://ex.org/Alice') + + assert len(results) == 1 + assert results[0].otype == 'l' + assert results[0].dtype == 'xsd:string' + assert results[0].lang == 'en' + + def test_auto_detect_otype_uri(self, entity_kg): + """Auto-detect should classify http:// as URI.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + ) + + batch = ctx["batches"][0] + # Check otype in entity rows (position 4) + for _, params in batch._adds: + if len(params) == 10: + assert params[4] == 'u' + + def test_auto_detect_otype_literal(self, entity_kg): + """Auto-detect should classify non-http:// as literal.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='plain text', + ) + + batch = ctx["batches"][0] + for _, params in batch._adds: + if len(params) == 10: + assert params[4] == 'l' diff --git a/tests/unit/test_embeddings/test_document_embeddings_processor.py b/tests/unit/test_embeddings/test_document_embeddings_processor.py new file mode 100644 index 00000000..9cd93c4f --- /dev/null +++ b/tests/unit/test_embeddings/test_document_embeddings_processor.py @@ -0,0 +1,164 @@ +""" +Tests for document embeddings processor — single-chunk embedding via batch API. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.embeddings.document_embeddings.embeddings import Processor +from trustgraph.schema import ( + Chunk, DocumentEmbeddings, ChunkEmbeddings, + EmbeddingsRequest, EmbeddingsResponse, Metadata, +) + + +@pytest.fixture +def processor(): + return Processor( + taskgroup=AsyncMock(), + id="test-doc-embeddings", + ) + + +def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", + user="test", collection="default"): + metadata = Metadata(id=doc_id, user=user, collection=collection) + value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id) + msg = MagicMock() + msg.value.return_value = value + return msg + + +class TestDocumentEmbeddingsProcessor: + + @pytest.mark.asyncio + async def test_sends_single_text_as_list(self, processor): + """Document embeddings should wrap single chunk in a list for the API.""" + msg = _make_chunk_message("test chunk text") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.1, 0.2, 0.3]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # Should send EmbeddingsRequest with texts=[chunk] + mock_request.assert_called_once() + req = mock_request.call_args[0][0] + assert isinstance(req, EmbeddingsRequest) + assert req.texts == ["test chunk text"] + + @pytest.mark.asyncio + async def test_extracts_first_vector(self, processor): + """Should use vectors[0] from the response.""" + msg = _make_chunk_message("chunk") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[1.0, 2.0, 3.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert isinstance(result, DocumentEmbeddings) + assert len(result.chunks) == 1 + assert result.chunks[0].vector == [1.0, 2.0, 3.0] + + @pytest.mark.asyncio + async def test_empty_vectors_response(self, processor): + """Should handle empty vectors response gracefully.""" + msg = _make_chunk_message("chunk") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.chunks[0].vector == [] + + @pytest.mark.asyncio + async def test_chunk_id_is_document_id(self, processor): + """ChunkEmbeddings should use document_id as chunk_id.""" + msg = _make_chunk_message(doc_id="my-doc-42") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.chunks[0].chunk_id == "my-doc-42" + + @pytest.mark.asyncio + async def test_metadata_preserved(self, processor): + """Output should carry the original metadata.""" + msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.metadata.user == "alice" + assert result.metadata.collection == "reports" + assert result.metadata.id == "d1" + + @pytest.mark.asyncio + async def test_error_propagates(self, processor): + """Embedding errors should propagate for retry.""" + msg = _make_chunk_message() + + mock_request = AsyncMock(side_effect=RuntimeError("service down")) + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + return MagicMock() + + with pytest.raises(RuntimeError, match="service down"): + await processor.on_message(msg, MagicMock(), flow) diff --git a/tests/unit/test_embeddings/test_embeddings_client.py b/tests/unit/test_embeddings/test_embeddings_client.py new file mode 100644 index 00000000..84305d7a --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_client.py @@ -0,0 +1,109 @@ +""" +Tests for EmbeddingsClient — the client interface for batch embeddings. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.base.embeddings_client import EmbeddingsClient +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error + + +class TestEmbeddingsClient: + + @pytest.mark.asyncio + async def test_embed_sends_request_and_returns_vectors(self): + """embed() should send an EmbeddingsRequest and return vectors.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, + vectors=[[0.1, 0.2], [0.3, 0.4]], + )) + + result = await client.embed(texts=["hello", "world"]) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + client.request.assert_called_once() + req = client.request.call_args[0][0] + assert isinstance(req, EmbeddingsRequest) + assert req.texts == ["hello", "world"] + + @pytest.mark.asyncio + async def test_embed_single_text(self): + """embed() should work with a single text.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, + vectors=[[1.0, 2.0, 3.0]], + )) + + result = await client.embed(texts=["single"]) + + assert result == [[1.0, 2.0, 3.0]] + + @pytest.mark.asyncio + async def test_embed_raises_on_error_response(self): + """embed() should raise RuntimeError when response contains an error.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=Error(type="embeddings-error", message="model not found"), + vectors=[], + )) + + with pytest.raises(RuntimeError, match="model not found"): + await client.embed(texts=["test"]) + + @pytest.mark.asyncio + async def test_embed_passes_timeout(self): + """embed() should pass timeout to the underlying request.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]], + )) + + await client.embed(texts=["test"], timeout=60) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 60 + + @pytest.mark.asyncio + async def test_embed_default_timeout(self): + """embed() should use 300s default timeout.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]], + )) + + await client.embed(texts=["test"]) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 300 + + @pytest.mark.asyncio + async def test_embed_empty_texts(self): + """embed() with empty list should still make the request.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[], + )) + + result = await client.embed(texts=[]) + + assert result == [] + + @pytest.mark.asyncio + async def test_embed_large_batch(self): + """embed() should handle large batches.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + n = 100 + vectors = [[float(i)] for i in range(n)] + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=vectors, + )) + + texts = [f"text {i}" for i in range(n)] + result = await client.embed(texts=texts) + + assert len(result) == n + req = client.request.call_args[0][0] + assert len(req.texts) == n diff --git a/tests/unit/test_embeddings/test_embeddings_service_request.py b/tests/unit/test_embeddings/test_embeddings_service_request.py new file mode 100644 index 00000000..c57fae16 --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_service_request.py @@ -0,0 +1,135 @@ +""" +Tests for EmbeddingsService.on_request — the request handler that dispatches +to on_embeddings and sends responses. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.base import EmbeddingsService +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error +from trustgraph.exceptions import TooManyRequests + + +class StubEmbeddingsService(EmbeddingsService): + """Minimal concrete implementation for testing on_request.""" + + def __init__(self, embed_result=None, embed_error=None): + # Skip super().__init__ to avoid taskgroup/registration + self.embed_result = embed_result or [[0.1, 0.2]] + self.embed_error = embed_error + + async def on_embeddings(self, texts, model=None): + if self.embed_error: + raise self.embed_error + return self.embed_result + + +def _make_msg(texts, msg_id="req-1"): + request = EmbeddingsRequest(texts=texts) + msg = MagicMock() + msg.value.return_value = request + msg.properties.return_value = {"id": msg_id} + return msg + + +def _make_flow(model="test-model"): + mock_response_producer = AsyncMock() + mock_flow = MagicMock() + + def flow_callable(name): + if name == "model": + return model + if name == "response": + return mock_response_producer + return MagicMock() + + flow_callable.producer = {"response": mock_response_producer} + return flow_callable, mock_response_producer + + +class TestEmbeddingsServiceOnRequest: + + @pytest.mark.asyncio + async def test_successful_request(self): + """on_request should call on_embeddings and send response.""" + service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]]) + msg = _make_msg(["hello", "world"], msg_id="r1") + flow, mock_response = _make_flow(model="my-model") + + await service.on_request(msg, MagicMock(), flow) + + mock_response.send.assert_called_once() + resp = mock_response.send.call_args[0][0] + assert isinstance(resp, EmbeddingsResponse) + assert resp.error is None + assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]] + + # Check id is passed through + props = mock_response.send.call_args[1]["properties"] + assert props["id"] == "r1" + + @pytest.mark.asyncio + async def test_passes_model_from_flow(self): + """on_request should pass model parameter from flow to on_embeddings.""" + calls = [] + + class TrackingService(EmbeddingsService): + def __init__(self): + pass + + async def on_embeddings(self, texts, model=None): + calls.append({"texts": texts, "model": model}) + return [[0.0]] + + service = TrackingService() + msg = _make_msg(["test"]) + flow, _ = _make_flow(model="custom-model-v2") + + await service.on_request(msg, MagicMock(), flow) + + assert len(calls) == 1 + assert calls[0]["model"] == "custom-model-v2" + assert calls[0]["texts"] == ["test"] + + @pytest.mark.asyncio + async def test_error_sends_error_response(self): + """Non-rate-limit errors should send an error response.""" + service = StubEmbeddingsService( + embed_error=ValueError("dimension mismatch") + ) + msg = _make_msg(["test"], msg_id="r2") + flow, mock_response = _make_flow() + + await service.on_request(msg, MagicMock(), flow) + + mock_response.send.assert_called_once() + resp = mock_response.send.call_args[0][0] + assert resp.error is not None + assert resp.error.type == "embeddings-error" + assert "dimension mismatch" in resp.error.message + assert resp.vectors == [] + + @pytest.mark.asyncio + async def test_rate_limit_propagates(self): + """TooManyRequests should propagate (not caught as error response).""" + service = StubEmbeddingsService( + embed_error=TooManyRequests("rate limited") + ) + msg = _make_msg(["test"]) + flow, _ = _make_flow() + + with pytest.raises(TooManyRequests): + await service.on_request(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_message_id_preserved(self): + """The request message id should be forwarded in the response properties.""" + service = StubEmbeddingsService() + msg = _make_msg(["test"], msg_id="unique-id-42") + flow, mock_response = _make_flow() + + await service.on_request(msg, MagicMock(), flow) + + props = mock_response.send.call_args[1]["properties"] + assert props["id"] == "unique-id-42" diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py index 1c1fb883..f4e456cb 100644 --- a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -103,7 +103,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): mock_text_embedding_class.reset_mock() # Act - result = await processor.on_embeddings("test text") + result = await processor.on_embeddings(["test text"]) # Assert mock_fastembed_instance.embed.assert_called_once_with(["test text"]) @@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): mock_text_embedding_class.reset_mock() # Act - result = await processor.on_embeddings("test text", model="custom-model") + result = await processor.on_embeddings(["test text"], model="custom-model") # Assert mock_text_embedding_class.assert_called_once_with(model_name="custom-model") @@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): initial_call_count = mock_text_embedding_class.call_count # Act - switch between models - await processor.on_embeddings("text1", model="model-a") + await processor.on_embeddings(["text1"], model="model-a") call_count_after_a = mock_text_embedding_class.call_count - await processor.on_embeddings("text2", model="model-a") # Same, no reload + await processor.on_embeddings(["text2"], model="model-a") # Same, no reload call_count_after_a_repeat = mock_text_embedding_class.call_count - await processor.on_embeddings("text3", model="model-b") # Different, reload + await processor.on_embeddings(["text3"], model="model-b") # Different, reload call_count_after_b = mock_text_embedding_class.call_count - await processor.on_embeddings("text4", model="model-a") # Back to A, reload + await processor.on_embeddings(["text4"], model="model-a") # Back to A, reload call_count_after_a_again = mock_text_embedding_class.call_count # Assert @@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): initial_count = mock_text_embedding_class.call_count # Act - result = await processor.on_embeddings("test text", model=None) + result = await processor.on_embeddings(["test text"], model=None) # Assert # No reload, using cached default diff --git a/tests/unit/test_embeddings/test_graph_embeddings_processor.py b/tests/unit/test_embeddings/test_graph_embeddings_processor.py new file mode 100644 index 00000000..5d535349 --- /dev/null +++ b/tests/unit/test_embeddings/test_graph_embeddings_processor.py @@ -0,0 +1,233 @@ +""" +Tests for graph embeddings processor — batch embedding of entity contexts. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.embeddings.graph_embeddings.embeddings import Processor +from trustgraph.schema import ( + EntityContexts, EntityEmbeddings, GraphEmbeddings, + Term, IRI, Metadata, +) + + +@pytest.fixture +def processor(): + return Processor( + taskgroup=AsyncMock(), + id="test-graph-embeddings", + batch_size=3, + ) + + +def _make_entity_context(name, context, chunk_id="chunk-1"): + """Create an entity context for testing.""" + entity = Term(type=IRI, iri=f"urn:entity:{name}") + return MagicMock(entity=entity, context=context, chunk_id=chunk_id) + + +def _make_message(entities, doc_id="doc-1", user="test", collection="default"): + metadata = Metadata(id=doc_id, user=user, collection=collection) + value = EntityContexts(metadata=metadata, entities=entities) + msg = MagicMock() + msg.value.return_value = value + return msg + + +class TestGraphEmbeddingsInit: + + def test_default_batch_size(self): + p = Processor(taskgroup=AsyncMock(), id="test") + assert p.batch_size == 5 + + def test_custom_batch_size(self): + p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20) + assert p.batch_size == 20 + + +class TestGraphEmbeddingsBatchProcessing: + + @pytest.mark.asyncio + async def test_single_batch_call_for_all_entities(self, processor): + """All entity contexts should be embedded in a single API call.""" + entities = [ + _make_entity_context("Alice", "Alice is a person"), + _make_entity_context("Bob", "Bob is a developer"), + _make_entity_context("Acme", "Acme is a company"), + ] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[ + [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], + ]) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # Single batch call with all three texts + mock_embed.assert_called_once_with( + texts=["Alice is a person", "Bob is a developer", "Acme is a company"] + ) + + @pytest.mark.asyncio + async def test_vectors_paired_with_correct_entities(self, processor): + """Each vector should be paired with its corresponding entity.""" + entities = [ + _make_entity_context("Alice", "ctx-A", chunk_id="c1"), + _make_entity_context("Bob", "ctx-B", chunk_id="c2"), + ] + msg = _make_message(entities) + + vectors = [[1.0, 2.0], [3.0, 4.0]] + mock_embed = AsyncMock(return_value=vectors) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # With batch_size=3, all 2 entities fit in one output message + mock_output.send.assert_called_once() + result = mock_output.send.call_args[0][0] + assert isinstance(result, GraphEmbeddings) + assert len(result.entities) == 2 + assert result.entities[0].vector == [1.0, 2.0] + assert result.entities[0].entity.iri == "urn:entity:Alice" + assert result.entities[0].chunk_id == "c1" + assert result.entities[1].vector == [3.0, 4.0] + assert result.entities[1].entity.iri == "urn:entity:Bob" + + @pytest.mark.asyncio + async def test_output_batching(self, processor): + """Output should be split into batches of batch_size.""" + # batch_size=3, 7 entities -> 3 output messages (3+3+1) + entities = [ + _make_entity_context(f"E{i}", f"context {i}") + for i in range(7) + ] + msg = _make_message(entities) + + vectors = [[float(i)] for i in range(7)] + mock_embed = AsyncMock(return_value=vectors) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + assert mock_output.send.call_count == 3 + # First batch has 3 entities + batch1 = mock_output.send.call_args_list[0][0][0] + assert len(batch1.entities) == 3 + # Second batch has 3 entities + batch2 = mock_output.send.call_args_list[1][0][0] + assert len(batch2.entities) == 3 + # Third batch has 1 entity + batch3 = mock_output.send.call_args_list[2][0][0] + assert len(batch3.entities) == 1 + + @pytest.mark.asyncio + async def test_output_batches_preserve_metadata(self, processor): + """Each output batch should carry the original metadata.""" + entities = [ + _make_entity_context(f"E{i}", f"ctx {i}") + for i in range(5) + ] + msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main") + + mock_embed = AsyncMock(return_value=[[0.0]] * 5) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + for call in mock_output.send.call_args_list: + result = call[0][0] + assert result.metadata.id == "doc-42" + assert result.metadata.user == "alice" + assert result.metadata.collection == "main" + + @pytest.mark.asyncio + async def test_single_entity(self, processor): + """Single entity should work with one embed call and one output.""" + entities = [_make_entity_context("Solo", "solo context")] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]]) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + mock_embed.assert_called_once_with(texts=["solo context"]) + mock_output.send.assert_called_once() + + @pytest.mark.asyncio + async def test_embed_error_propagates(self, processor): + """Embedding service errors should propagate for retry.""" + entities = [_make_entity_context("E", "ctx")] + msg = _make_message(entities) + + mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed")) + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + return MagicMock() + + with pytest.raises(RuntimeError, match="embedding failed"): + await processor.on_message(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_exact_batch_size(self, processor): + """When entity count equals batch_size, exactly one output message.""" + entities = [ + _make_entity_context(f"E{i}", f"ctx {i}") + for i in range(3) # batch_size=3 + ] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[[0.0]] * 3) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + mock_output.send.assert_called_once() + assert len(mock_output.send.call_args[0][0].entities) == 3 diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py index ca0f44bf..d52a58c6 100644 --- a/tests/unit/test_embeddings/test_ollama_dynamic_model.py +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -53,12 +53,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text") + result = await processor.on_embeddings(["test text"]) # Assert mock_ollama_client.embed.assert_called_once_with( model="test-model", - input="test text" + input=["test text"] ) assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @@ -79,12 +79,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text", model="custom-model") + result = await processor.on_embeddings(["test text"], model="custom-model") # Assert mock_ollama_client.embed.assert_called_once_with( model="custom-model", - input="test text" + input=["test text"] ) assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - switch between different models - await processor.on_embeddings("text1", model="model-a") - await processor.on_embeddings("text2", model="model-b") - await processor.on_embeddings("text3", model="model-a") - await processor.on_embeddings("text4") # Use default + await processor.on_embeddings(["text1"], model="model-a") + await processor.on_embeddings(["text2"], model="model-b") + await processor.on_embeddings(["text3"], model="model-a") + await processor.on_embeddings(["text4"]) # Use default # Assert calls = mock_ollama_client.embed.call_args_list @@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text", model=None) + result = await processor.on_embeddings(["test text"], model=None) # Assert mock_ollama_client.embed.assert_called_once_with( model="test-model", - input="test text" + input=["test text"] ) @patch('trustgraph.embeddings.ollama.processor.Client') diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py index 47405431..45a22e48 100644 --- a/tests/unit/test_embeddings/test_row_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): # Mock the flow mock_embeddings_request = AsyncMock() - mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]] + # Return batch of vector sets (one per text) + # 4 unique texts: CUST001, John Doe, CUST002, Jane Smith + mock_embeddings_request.embed.return_value = [ + [[0.1, 0.2, 0.3]], # vectors for text 1 + [[0.2, 0.3, 0.4]], # vectors for text 2 + [[0.3, 0.4, 0.5]], # vectors for text 3 + [[0.4, 0.5, 0.6]], # vectors for text 4 + ] mock_output = AsyncMock() @@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): await processor.on_message(mock_msg, MagicMock(), mock_flow) - # Should have called embed for each unique text - # 4 values: CUST001, John Doe, CUST002, Jane Smith - assert mock_embeddings_request.embed.call_count == 4 + # Should have called embed once with all texts in a batch + assert mock_embeddings_request.embed.call_count == 1 + # Verify it was called with a list of texts + call_args = mock_embeddings_request.embed.call_args + assert 'texts' in call_args.kwargs + assert len(call_args.kwargs['texts']) == 4 # Should have sent output mock_output.send.assert_called() diff --git a/tests/unit/test_extract/test_streaming_triples/__init__.py b/tests/unit/test_extract/test_streaming_triples/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py new file mode 100644 index 00000000..b651b59e --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py @@ -0,0 +1,407 @@ +""" +Tests for streaming triple and entity context batching in the definitions +KG extractor. + +Covers: triples batch splitting, entity context batch splitting, +metadata preservation, provenance, and empty/null filtering. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.extract.kg.definitions.extract import ( + Processor, default_triples_batch_size, default_entity_batch_size, +) +from trustgraph.schema import ( + Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(triples_batch_size=default_triples_batch_size, + entity_batch_size=default_entity_batch_size): + proc = Processor.__new__(Processor) + proc.triples_batch_size = triples_batch_size + proc.entity_batch_size = entity_batch_size + return proc + + +def _make_defn(entity, definition): + return {"entity": entity, "definition": definition} + + +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", + user="user-1", collection="col-1", document_id=""): + chunk = Chunk( + metadata=Metadata( + id=meta_id, root=root, user=user, collection=collection, + ), + chunk=text.encode("utf-8"), + document_id=document_id, + ) + msg = MagicMock() + msg.value.return_value = chunk + return msg + + +def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): + mock_triples_pub = AsyncMock() + mock_ecs_pub = AsyncMock() + mock_prompt_client = AsyncMock() + mock_prompt_client.extract_definitions = AsyncMock( + return_value=prompt_result + ) + + def flow(name): + if name == "prompt-request": + return mock_prompt_client + if name == "triples": + return mock_triples_pub + if name == "entity-contexts": + return mock_ecs_pub + if name == "llm-model": + return llm_model + if name == "ontology": + return ontology_uri + return MagicMock() + + return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client + + +def _sent_triples(mock_pub): + return [call.args[0] for call in mock_pub.send.call_args_list] + + +def _sent_ecs(mock_pub): + return [call.args[0] for call in mock_pub.send.call_args_list] + + +def _all_triples_flat(mock_pub): + result = [] + for triples_msg in _sent_triples(mock_pub): + result.extend(triples_msg.triples) + return result + + +def _all_entities_flat(mock_pub): + result = [] + for ecs_msg in _sent_ecs(mock_pub): + result.extend(ecs_msg.entities) + return result + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDefaults: + + def test_default_triples_batch_size(self): + assert default_triples_batch_size == 50 + + def test_default_entity_batch_size(self): + assert default_entity_batch_size == 5 + + +class TestTriplesBatching: + + @pytest.mark.asyncio + async def test_single_batch_when_under_limit(self): + proc = _make_processor(triples_batch_size=100) + defs = [_make_defn("Cat", "A feline animal")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_triples_batches(self): + proc = _make_processor(triples_batch_size=2) + defs = [ + _make_defn("Cat", "A feline"), + _make_defn("Dog", "A canine"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 2 defs → 2 labels + 2 definitions = 4 triples + provenance + # With batch_size=2, should produce multiple batches + assert triples_pub.send.call_count > 1 + + @pytest.mark.asyncio + async def test_triples_batch_sizes_within_limit(self): + batch_size = 3 + proc = _make_processor(triples_batch_size=batch_size) + defs = [ + _make_defn("A", "def A"), + _make_defn("B", "def B"), + _make_defn("C", "def C"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(triples_pub): + assert len(triples_msg.triples) <= batch_size + + +class TestEntityContextBatching: + + @pytest.mark.asyncio + async def test_single_entity_batch_when_under_limit(self): + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 1 def → 2 entity contexts (name + definition) + assert ecs_pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_entity_batches(self): + proc = _make_processor(entity_batch_size=2) + defs = [ + _make_defn("Cat", "A feline"), + _make_defn("Dog", "A canine"), + ] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 2 defs → 4 entity contexts, batch_size=2 → 2 batches + assert ecs_pub.send.call_count == 2 + + @pytest.mark.asyncio + async def test_entity_batch_sizes_within_limit(self): + batch_size = 3 + proc = _make_processor(entity_batch_size=batch_size) + defs = [ + _make_defn("A", "def A"), + _make_defn("B", "def B"), + _make_defn("C", "def C"), + ] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for ecs_msg in _sent_ecs(ecs_pub): + assert len(ecs_msg.entities) <= batch_size + + @pytest.mark.asyncio + async def test_entity_contexts_have_name_and_definition(self): + """Each definition produces 2 entity contexts: name and definition.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline animal")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + assert len(entities) == 2 + contexts = {e.context for e in entities} + assert "Cat" in contexts + assert "A feline animal" in contexts + + +class TestMetadataPreservation: + + @pytest.mark.asyncio + async def test_triples_metadata(self): + proc = _make_processor(triples_batch_size=2) + defs = [_make_defn("X", "def X")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg( + "text", meta_id="c-1", root="r-1", + user="u-1", collection="coll-1", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(triples_pub): + assert triples_msg.metadata.id == "c-1" + assert triples_msg.metadata.root == "r-1" + assert triples_msg.metadata.user == "u-1" + assert triples_msg.metadata.collection == "coll-1" + + @pytest.mark.asyncio + async def test_entity_contexts_metadata(self): + proc = _make_processor(entity_batch_size=1) + defs = [_make_defn("X", "def X")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg( + "text", meta_id="c-2", root="r-2", + user="u-2", collection="coll-2", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for ecs_msg in _sent_ecs(ecs_pub): + assert ecs_msg.metadata.id == "c-2" + assert ecs_msg.metadata.root == "r-2" + + +class TestEmptyAndNullFiltering: + + @pytest.mark.asyncio + async def test_empty_entity_skipped(self): + proc = _make_processor() + defs = [ + _make_defn("", "some definition"), + _make_defn("Valid", "a valid definition"), + ] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(triples_pub) + all_e = _all_entities_flat(ecs_pub) + # Only "Valid" should be present + entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")} + assert any("valid" in iri for iri in entity_iris) + assert len(all_e) == 2 # name + definition for "Valid" only + + @pytest.mark.asyncio + async def test_empty_definition_skipped(self): + proc = _make_processor() + defs = [ + _make_defn("Entity", ""), + _make_defn("Good", "good definition"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(triples_pub) + entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")} + assert any("good" in iri for iri in entity_iris) + # "Entity" with empty def should have been skipped + assert not any("entity" in iri and "good" not in iri for iri in entity_iris) + + @pytest.mark.asyncio + async def test_none_fields_skipped(self): + proc = _make_processor() + defs = [ + _make_defn(None, "some definition"), + _make_defn("Entity", None), + ] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_all_filtered_no_output(self): + proc = _make_processor() + defs = [_make_defn("", ""), _make_defn(None, None)] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_prompt_response(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, _ = _make_flow([]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + +class TestProvenanceInclusion: + + @pytest.mark.asyncio + async def test_provenance_triples_present(self): + proc = _make_processor(triples_batch_size=200) + defs = [_make_defn("Cat", "A feline")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(triples_pub) + # 1 def → 1 label + 1 definition = 2 content triples + # Provenance adds more + assert len(all_t) > 2 + + +class TestErrorHandling: + + @pytest.mark.asyncio + async def test_prompt_error_caught(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, prompt = _make_flow([]) + prompt.extract_definitions = AsyncMock( + side_effect=RuntimeError("LLM error") + ) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_non_list_response_caught(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, prompt = _make_flow("not a list") + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + +class TestDocumentIdProvenance: + + @pytest.mark.asyncio + async def test_document_id_used_for_chunk_id(self): + """When document_id is set, entity contexts should use it as chunk_id.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text", document_id="doc-123") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + for e in entities: + assert e.chunk_id == "doc-123" + + @pytest.mark.asyncio + async def test_metadata_id_fallback_for_chunk_id(self): + """When document_id is empty, metadata.id is used as chunk_id.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + for e in entities: + assert e.chunk_id == "chunk-42" diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py new file mode 100644 index 00000000..cf3b1fb0 --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py @@ -0,0 +1,408 @@ +""" +Tests for streaming triple batching in the relationships KG extractor. + +Covers: batch size configuration, output splitting, metadata preservation, +provenance inclusion, empty/null filtering, and error propagation. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.extract.kg.relationships.extract import ( + Processor, default_triples_batch_size, +) +from trustgraph.schema import ( + Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(triples_batch_size=default_triples_batch_size): + """Create a Processor without triggering FlowProcessor.__init__.""" + proc = Processor.__new__(Processor) + proc.triples_batch_size = triples_batch_size + return proc + + +def _make_rel(subject, predicate, obj, object_entity=True): + """Build a relationship dict as returned by the prompt client.""" + return { + "subject": subject, + "predicate": predicate, + "object": obj, + "object-entity": object_entity, + } + + +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", + user="user-1", collection="col-1", document_id=""): + """Build a mock message wrapping a Chunk.""" + chunk = Chunk( + metadata=Metadata( + id=meta_id, root=root, user=user, collection=collection, + ), + chunk=text.encode("utf-8"), + document_id=document_id, + ) + msg = MagicMock() + msg.value.return_value = chunk + return msg + + +def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): + """Build a mock flow callable that provides prompt client, triples + producer, and parameter specs.""" + mock_triples_pub = AsyncMock() + mock_prompt_client = AsyncMock() + mock_prompt_client.extract_relationships = AsyncMock( + return_value=prompt_result + ) + + def flow(name): + if name == "prompt-request": + return mock_prompt_client + if name == "triples": + return mock_triples_pub + if name == "llm-model": + return llm_model + if name == "ontology": + return ontology_uri + return MagicMock() + + return flow, mock_triples_pub, mock_prompt_client + + +def _sent_triples(mock_pub): + """Collect all Triples objects sent to a mock publisher.""" + return [call.args[0] for call in mock_pub.send.call_args_list] + + +def _all_triples_flat(mock_pub): + """Flatten all batches into one list of Triple objects.""" + result = [] + for triples_msg in _sent_triples(mock_pub): + result.extend(triples_msg.triples) + return result + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDefaultBatchSize: + + def test_default_is_50(self): + assert default_triples_batch_size == 50 + + def test_processor_uses_default(self): + proc = _make_processor() + assert proc.triples_batch_size == 50 + + +class TestBatchSplitting: + + @pytest.mark.asyncio + async def test_single_batch_when_under_limit(self): + """Few triples → single send call.""" + proc = _make_processor(triples_batch_size=50) + rels = [_make_rel("A", "knows", "B")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("some text") + + await proc.on_message(msg, MagicMock(), flow) + + # One relationship produces: rel triple + 3 labels + provenance + # All should fit in one batch of 50 + assert pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_batches_with_small_batch_size(self): + """With batch_size=3 and many triples, multiple batches are sent.""" + proc = _make_processor(triples_batch_size=3) + # 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance + rels = [ + _make_rel("A", "knows", "B"), + _make_rel("C", "likes", "D"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("some text") + + await proc.on_message(msg, MagicMock(), flow) + + # Should have more than one batch + assert pub.send.call_count > 1 + + @pytest.mark.asyncio + async def test_batch_sizes_respect_limit(self): + """No batch should exceed the configured batch size.""" + batch_size = 3 + proc = _make_processor(triples_batch_size=batch_size) + rels = [ + _make_rel("A", "knows", "B"), + _make_rel("C", "likes", "D"), + _make_rel("E", "has", "F"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(pub): + assert len(triples_msg.triples) <= batch_size + + @pytest.mark.asyncio + async def test_all_triples_present_across_batches(self): + """Total triples across batches equals expected count.""" + proc = _make_processor(triples_batch_size=2) + # 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples + # + provenance triples + rels = [_make_rel("A", "knows", "B", object_entity=True)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # At minimum: 1 rel + 3 labels = 4 content triples + assert len(all_t) >= 4 + + @pytest.mark.asyncio + async def test_custom_batch_size(self): + """Processor respects custom triples_batch_size parameter.""" + proc = _make_processor(triples_batch_size=100) + assert proc.triples_batch_size == 100 + + +class TestMetadataPreservation: + + @pytest.mark.asyncio + async def test_metadata_forwarded_to_all_batches(self): + """Every batch should carry the original chunk metadata.""" + proc = _make_processor(triples_batch_size=2) + rels = [_make_rel("X", "rel", "Y")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg( + "text", meta_id="c-1", root="r-1", + user="u-1", collection="coll-1", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(pub): + assert triples_msg.metadata.id == "c-1" + assert triples_msg.metadata.root == "r-1" + assert triples_msg.metadata.user == "u-1" + assert triples_msg.metadata.collection == "coll-1" + + +class TestRelationshipTriples: + + @pytest.mark.asyncio + async def test_entity_object_produces_iri(self): + """object-entity=True → object is an IRI, with label triple.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Find the relationship triple (not a label) + rel_triples = [ + t for t in all_t + if t.o.type == IRI and "bob" in t.o.iri + ] + assert len(rel_triples) >= 1 + + @pytest.mark.asyncio + async def test_literal_object_produces_literal(self): + """object-entity=False → object is a LITERAL, no label for object.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "age", "30", object_entity=False)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Find the relationship triple with literal object + lit_triples = [ + t for t in all_t + if t.o.type == LITERAL and t.o.value == "30" + ] + assert len(lit_triples) == 1 + + @pytest.mark.asyncio + async def test_labels_emitted_for_subject_and_predicate(self): + """Every relationship should produce label triples for s and p.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "knows", "Bob")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + label_triples = [ + t for t in all_t + if t.p.type == IRI and "label" in t.p.iri.lower() + ] + labels = {t.o.value for t in label_triples} + assert "Alice" in labels + assert "knows" in labels + assert "Bob" in labels # object-entity default is True + + +class TestEmptyAndNullFiltering: + + @pytest.mark.asyncio + async def test_empty_string_fields_skipped(self): + """Relationships with empty string s/p/o are skipped.""" + proc = _make_processor(triples_batch_size=200) + rels = [ + _make_rel("", "knows", "Bob"), + _make_rel("Alice", "", "Bob"), + _make_rel("Alice", "knows", ""), + _make_rel("Good", "triple", "Here"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Only the "Good triple Here" relationship should produce content triples + rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri} + assert any("good" in iri for iri in rel_iris) + assert not any("alice" in iri for iri in rel_iris) + + @pytest.mark.asyncio + async def test_none_fields_skipped(self): + """Relationships with None s/p/o are skipped.""" + proc = _make_processor(triples_batch_size=200) + rels = [ + _make_rel(None, "knows", "Bob"), + _make_rel("Alice", None, "Bob"), + _make_rel("Alice", "knows", None), + _make_rel("Valid", "rel", "Here"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri} + assert any("valid" in iri for iri in rel_iris) + assert not any("alice" in iri for iri in rel_iris) + + @pytest.mark.asyncio + async def test_all_filtered_produces_no_output(self): + """If all relationships are empty/null, nothing is emitted.""" + proc = _make_processor(triples_batch_size=200) + rels = [ + _make_rel("", "", ""), + _make_rel(None, None, None), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_prompt_response_produces_no_output(self): + """Empty relationship list from prompt → no triples emitted.""" + proc = _make_processor() + flow, pub, _ = _make_flow([]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + +class TestProvenanceInclusion: + + @pytest.mark.asyncio + async def test_provenance_triples_present(self): + """Extracted relationships should include provenance triples.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("A", "knows", "B")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Provenance triples use GRAPH_SOURCE graph context + # They contain terms referencing prov: namespace or subgraph URIs + # We just check that total count > 4 (1 rel + 3 labels) + assert len(all_t) > 4 + + @pytest.mark.asyncio + async def test_no_provenance_when_no_extracted_triples(self): + """Empty relationships → no provenance generated.""" + proc = _make_processor() + flow, pub, _ = _make_flow([_make_rel("", "x", "y")]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + +class TestErrorPropagation: + + @pytest.mark.asyncio + async def test_prompt_error_is_caught(self): + """Errors from the prompt client are caught (logged, not raised).""" + proc = _make_processor() + flow, pub, prompt = _make_flow([]) + prompt.extract_relationships = AsyncMock( + side_effect=RuntimeError("LLM unavailable") + ) + msg = _make_chunk_msg("text") + + # The outer try/except in on_message catches and logs + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_non_list_response_is_caught(self): + """Non-list prompt response triggers RuntimeError, caught by handler.""" + proc = _make_processor() + flow, pub, prompt = _make_flow("not a list") + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + +class TestToUri: + + def test_spaces_replaced_with_hyphens(self): + proc = _make_processor() + uri = proc.to_uri("hello world") + assert "hello-world" in uri + + def test_lowercased(self): + proc = _make_processor() + uri = proc.to_uri("Hello World") + assert "hello-world" in uri + + def test_special_chars_encoded(self): + proc = _make_processor() + # urllib.parse.quote keeps / as safe by default + uri = proc.to_uri("a/b") + assert "a/b" in uri + # Characters like spaces are encoded (handled via replace → hyphen) + uri2 = proc.to_uri("hello world") + assert " " not in uri2 diff --git a/tests/unit/test_gateway/test_rows_import_dispatcher.py b/tests/unit/test_gateway/test_rows_import_dispatcher.py index ab72cae1..f029e9a2 100644 --- a/tests/unit/test_gateway/test_rows_import_dispatcher.py +++ b/tests/unit/test_gateway/test_rows_import_dispatcher.py @@ -55,13 +55,6 @@ def sample_objects_message(): return { "metadata": { "id": "obj-123", - "metadata": [ - { - "s": {"v": "obj-123", "e": False}, - "p": {"v": "source", "e": False}, - "o": {"v": "test", "e": False} - } - ], "user": "testuser", "collection": "testcollection" }, @@ -244,7 +237,6 @@ class TestRowsImportMessageProcessing: assert sent_object.metadata.id == "obj-123" assert sent_object.metadata.user == "testuser" assert sent_object.metadata.collection == "testcollection" - assert len(sent_object.metadata.metadata) == 1 # One triple in metadata @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio @@ -277,7 +269,6 @@ class TestRowsImportMessageProcessing: assert sent_object.values[0]["field1"] == "value1" assert sent_object.confidence == 1.0 # Default value assert sent_object.source_span == "" # Default value - assert len(sent_object.metadata.metadata) == 0 # Default empty list @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio diff --git a/tests/unit/test_gateway/test_streaming_translators.py b/tests/unit/test_gateway/test_streaming_translators.py index e767edd4..e190fe68 100644 --- a/tests/unit/test_gateway/test_streaming_translators.py +++ b/tests/unit/test_gateway/test_streaming_translators.py @@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator: assert is_final is False assert result["end_of_stream"] is False - # Test final chunk with empty content + # Test final message with end_of_session=True final_response = GraphRagResponse( response="", end_of_stream=True, + end_of_session=True, error=None ) # Act result, is_final = translator.from_response_with_completion(final_response) - # Assert + # Assert - is_final is based on end_of_session, not end_of_stream assert is_final is True assert result["response"] == "" - assert result["end_of_stream"] is True + assert result["end_of_session"] is True class TestDocumentRagResponseTranslator: diff --git a/tests/unit/test_knowledge_graph/conftest.py b/tests/unit/test_knowledge_graph/conftest.py index e7f83b58..8e8d9e43 100644 --- a/tests/unit/test_knowledge_graph/conftest.py +++ b/tests/unit/test_knowledge_graph/conftest.py @@ -29,11 +29,11 @@ class Triple: self.o = o class Metadata: - def __init__(self, id, user, collection, metadata): + def __init__(self, id, user, collection, root=""): self.id = id + self.root = root self.user = user self.collection = collection - self.metadata = metadata class Triples: def __init__(self, metadata, triples): @@ -110,7 +110,6 @@ def sample_triples(sample_triple): id="test-doc-123", user="test_user", collection="test_collection", - metadata=[] ) return Triples( @@ -126,7 +125,6 @@ def sample_chunk(): id="test-chunk-456", user="test_user", collection="test_collection", - metadata=[] ) return Chunk( diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction.py b/tests/unit/test_knowledge_graph/test_agent_extraction.py index a3a0f9a7..ec985e3b 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction.py @@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts -from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF +from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL from trustgraph.template.prompt_manager import PromptManager @@ -51,13 +51,6 @@ class TestAgentKgExtractor: """Sample metadata for testing""" return Metadata( id="doc123", - metadata=[ - Triple( - s=Term(type=IRI, iri="doc123"), - p=Term(type=IRI, iri="http://example.org/type"), - o=Term(type=LITERAL, value="document") - ) - ] ) @pytest.fixture @@ -175,7 +168,7 @@ This is not JSON at all } ] - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata) # Check entity label triple label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None) @@ -190,12 +183,6 @@ This is not JSON at all assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" assert def_triple.o.value == "A subset of AI that enables learning from data." - # Check subject-of triple - subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None) - assert subject_of_triple is not None - assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" - assert subject_of_triple.o.iri == "doc123" - # Check entity context assert len(entity_contexts) == 1 assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" @@ -213,7 +200,7 @@ This is not JSON at all } ] - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata) # Check that subject, predicate, and object labels are created subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" @@ -235,10 +222,6 @@ This is not JSON at all assert rel_triple.o.iri == object_uri assert rel_triple.o.type == IRI - # Check subject-of relationships - subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"] - assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations - def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata): """Test processing of relationships with literal objects""" data = [ @@ -251,7 +234,7 @@ This is not JSON at all } ] - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata) # Check that object labels are not created for literal objects object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"] @@ -260,7 +243,7 @@ This is not JSON at all def test_process_extraction_data_combined(self, agent_extractor, sample_metadata, sample_extraction_data): """Test processing of combined definitions and relationships""" - triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata) # Check that we have both definition and relationship triples definition_triples = [t for t in triples if t.p.iri == DEFINITION] @@ -274,16 +257,12 @@ This is not JSON at all def test_process_extraction_data_no_metadata_id(self, agent_extractor): """Test processing when metadata has no ID""" - metadata = Metadata(id=None, metadata=[]) + metadata = Metadata(id=None) data = [ {"type": "definition", "entity": "Test Entity", "definition": "Test definition"} ] - triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata) - - # Should not create subject-of relationships when no metadata ID - subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF] - assert len(subject_of_triples) == 0 + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should still create entity contexts assert len(entity_contexts) == 1 @@ -292,7 +271,7 @@ This is not JSON at all """Test processing of empty extraction data""" data = [] - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata) # Should have no entity contexts assert len(entity_contexts) == 0 @@ -307,7 +286,7 @@ This is not JSON at all {"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True} ] - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata) # Should process valid items and ignore unknown types assert len(entity_contexts) == 1 # Only the definition creates entity context @@ -345,8 +324,6 @@ This is not JSON at all assert sent_triples.metadata.id == sample_metadata.id assert sent_triples.metadata.user == sample_metadata.user assert sent_triples.metadata.collection == sample_metadata.collection - # Note: metadata.metadata is now empty array in the new implementation - assert sent_triples.metadata.metadata == [] assert len(sent_triples.triples) == 1 assert sent_triples.triples[0].s.iri == "test:subject" @@ -371,8 +348,6 @@ This is not JSON at all assert sent_contexts.metadata.id == sample_metadata.id assert sent_contexts.metadata.user == sample_metadata.user assert sent_contexts.metadata.collection == sample_metadata.collection - # Note: metadata.metadata is now empty array in the new implementation - assert sent_contexts.metadata.metadata == [] assert len(sent_contexts.entities) == 1 assert sent_contexts.entities[0].entity.iri == "test:entity" diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py b/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py index f66e5da6..b0be3f06 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py @@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts -from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF +from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL @pytest.mark.unit @@ -168,7 +168,7 @@ class TestAgentKgExtractionEdgeCases: """Test processing with empty or minimal metadata""" # Test with None metadata - may not raise AttributeError depending on implementation try: - triples, contexts = agent_extractor.process_extraction_data([], None) + triples, contexts, _ = agent_extractor.process_extraction_data([], None) # If it doesn't raise, check the results assert len(triples) == 0 assert len(contexts) == 0 @@ -177,23 +177,19 @@ class TestAgentKgExtractionEdgeCases: pass # Test with metadata without ID - metadata = Metadata(id=None, metadata=[]) - triples, contexts = agent_extractor.process_extraction_data([], metadata) + metadata = Metadata(id=None) + triples, contexts, _ = agent_extractor.process_extraction_data([], metadata) assert len(triples) == 0 assert len(contexts) == 0 # Test with metadata with empty string ID - metadata = Metadata(id="", metadata=[]) + metadata = Metadata(id="") data = [{"type": "definition", "entity": "Test", "definition": "Test def"}] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) - - # Should not create subject-of triples when ID is empty string - subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF] - assert len(subject_of_triples) == 0 + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) def test_process_extraction_data_special_entity_names(self, agent_extractor): """Test processing with special characters in entity names""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") special_entities = [ "Entity with spaces", @@ -213,7 +209,7 @@ class TestAgentKgExtractionEdgeCases: for entity in special_entities ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Verify all entities were processed assert len(contexts) == len(special_entities) @@ -225,7 +221,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_very_long_definitions(self, agent_extractor): """Test processing with very long entity definitions""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") # Create very long definition long_definition = "This is a very long definition. " * 1000 @@ -234,7 +230,7 @@ class TestAgentKgExtractionEdgeCases: {"type": "definition", "entity": "Test Entity", "definition": long_definition} ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should handle long definitions without issues assert len(contexts) == 1 @@ -247,7 +243,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_duplicate_entities(self, agent_extractor): """Test processing with duplicate entity names""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") data = [ {"type": "definition", "entity": "Machine Learning", "definition": "First definition"}, @@ -256,7 +252,7 @@ class TestAgentKgExtractionEdgeCases: {"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should process all entries (including duplicates) assert len(contexts) == 4 @@ -269,7 +265,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_empty_strings(self, agent_extractor): """Test processing with empty strings in data""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") data = [ {"type": "definition", "entity": "", "definition": "Definition for empty entity"}, @@ -280,7 +276,7 @@ class TestAgentKgExtractionEdgeCases: {"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True}, ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should handle empty strings by creating URIs (even if empty) assert len(contexts) == 3 @@ -291,7 +287,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_nested_json_in_strings(self, agent_extractor): """Test processing when definitions contain JSON-like strings""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") data = [ { @@ -306,7 +302,7 @@ class TestAgentKgExtractionEdgeCases: } ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should handle JSON strings in definitions without parsing them assert len(contexts) == 2 @@ -315,7 +311,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor): """Test processing with various boolean values for object-entity""" - metadata = Metadata(id="doc123", metadata=[]) + metadata = Metadata(id="doc123") data = [ # Explicit True @@ -334,16 +330,16 @@ class TestAgentKgExtractionEdgeCases: {"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1}, ] - triples, contexts = agent_extractor.process_extraction_data(data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata) # Should process all relationships # Note: The current implementation has some logic issues that these tests document - assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7 + assert len([t for t in triples if t.p.iri != RDF_LABEL]) >= 7 @pytest.mark.asyncio async def test_emit_empty_collections(self, agent_extractor): """Test emitting empty triples and entity contexts""" - metadata = Metadata(id="test", metadata=[]) + metadata = Metadata(id="test") # Test emitting empty triples mock_publisher = AsyncMock() @@ -389,7 +385,7 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_performance_large_dataset(self, agent_extractor): """Test performance with large extraction datasets""" - metadata = Metadata(id="large-doc", metadata=[]) + metadata = Metadata(id="large-doc") # Create large dataset in JSONL format num_definitions = 1000 @@ -416,7 +412,7 @@ class TestAgentKgExtractionEdgeCases: import time start_time = time.time() - triples, contexts = agent_extractor.process_extraction_data(large_data, metadata) + triples, contexts, _ = agent_extractor.process_extraction_data(large_data, metadata) end_time = time.time() processing_time = end_time - start_time diff --git a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py index 525f595d..f82e4cc8 100644 --- a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py +++ b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py @@ -314,7 +314,6 @@ class TestObjectExtractionBusinessLogic: id="test-extraction-001", user="test_user", collection="test_collection", - metadata=[] ) values = [{ diff --git a/tests/unit/test_knowledge_graph/test_triple_construction.py b/tests/unit/test_knowledge_graph/test_triple_construction.py index 10bae2e7..e45c69aa 100644 --- a/tests/unit/test_knowledge_graph/test_triple_construction.py +++ b/tests/unit/test_knowledge_graph/test_triple_construction.py @@ -373,7 +373,6 @@ class TestTripleConstructionLogic: id="test-doc-123", user="test_user", collection="test_collection", - metadata=[] ) # Act diff --git a/tests/unit/test_librarian/__init__.py b/tests/unit/test_librarian/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py new file mode 100644 index 00000000..2d09bcd4 --- /dev/null +++ b/tests/unit/test_librarian/test_chunked_upload.py @@ -0,0 +1,716 @@ +""" +Tests for librarian chunked upload operations: +begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status, +list_uploads, and stream_document. +""" + +import base64 +import json +import math +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE +from trustgraph.exceptions import RequestError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_librarian(min_chunk_size=1): + """Create a Librarian with mocked blob_store and table_store.""" + lib = Librarian.__new__(Librarian) + lib.blob_store = MagicMock() + lib.table_store = AsyncMock() + lib.load_document = AsyncMock() + lib.min_chunk_size = min_chunk_size + return lib + + +def _make_doc_metadata( + doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc" +): + meta = MagicMock() + meta.id = doc_id + meta.kind = kind + meta.user = user + meta.title = title + meta.time = 1700000000 + meta.comments = "" + meta.tags = [] + return meta + + +def _make_begin_request( + doc_id="doc-1", kind="application/pdf", user="alice", + total_size=10_000_000, chunk_size=0 +): + req = MagicMock() + req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user) + req.total_size = total_size + req.chunk_size = chunk_size + return req + + +def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"): + req = MagicMock() + req.upload_id = upload_id + req.chunk_index = chunk_index + req.user = user + req.content = base64.b64encode(content) + return req + + +def _make_session( + user="alice", total_chunks=5, chunk_size=2_000_000, + total_size=10_000_000, chunks_received=None, object_id="obj-1", + s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1", +): + if chunks_received is None: + chunks_received = {} + if document_metadata is None: + document_metadata = json.dumps({ + "id": document_id, "kind": "application/pdf", + "user": user, "title": "Test", "time": 1700000000, + "comments": "", "tags": [], + }) + return { + "user": user, + "total_chunks": total_chunks, + "chunk_size": chunk_size, + "total_size": total_size, + "chunks_received": chunks_received, + "object_id": object_id, + "s3_upload_id": s3_upload_id, + "document_metadata": document_metadata, + "document_id": document_id, + } + + +# --------------------------------------------------------------------------- +# begin_upload +# --------------------------------------------------------------------------- + +class TestBeginUpload: + + @pytest.mark.asyncio + async def test_creates_session(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-upload-id" + + req = _make_begin_request(total_size=10_000_000) + resp = await lib.begin_upload(req) + + assert resp.error is None + assert resp.upload_id is not None + assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE) + assert resp.chunk_size == DEFAULT_CHUNK_SIZE + + @pytest.mark.asyncio + async def test_custom_chunk_size(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-id" + + req = _make_begin_request(total_size=10_000, chunk_size=3000) + resp = await lib.begin_upload(req) + + assert resp.chunk_size == 3000 + assert resp.total_chunks == math.ceil(10_000 / 3000) + + @pytest.mark.asyncio + async def test_rejects_invalid_kind(self): + lib = _make_librarian() + req = _make_begin_request(kind="image/png") + + with pytest.raises(RequestError, match="Invalid document kind"): + await lib.begin_upload(req) + + @pytest.mark.asyncio + async def test_rejects_duplicate_document(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = True + + req = _make_begin_request() + with pytest.raises(RequestError, match="already exists"): + await lib.begin_upload(req) + + @pytest.mark.asyncio + async def test_rejects_zero_size(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + + req = _make_begin_request(total_size=0) + with pytest.raises(RequestError, match="positive"): + await lib.begin_upload(req) + + @pytest.mark.asyncio + async def test_rejects_chunk_below_minimum(self): + lib = _make_librarian(min_chunk_size=1024) + lib.table_store.document_exists.return_value = False + + req = _make_begin_request(total_size=10_000, chunk_size=512) + with pytest.raises(RequestError, match="below minimum"): + await lib.begin_upload(req) + + @pytest.mark.asyncio + async def test_calls_s3_create_multipart(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-id" + + req = _make_begin_request(kind="application/pdf") + await lib.begin_upload(req) + + lib.blob_store.create_multipart_upload.assert_called_once() + # create_multipart_upload(object_id, kind) — positional args + args = lib.blob_store.create_multipart_upload.call_args[0] + assert args[1] == "application/pdf" + + @pytest.mark.asyncio + async def test_stores_session_in_cassandra(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-id" + + req = _make_begin_request(total_size=5_000_000) + resp = await lib.begin_upload(req) + + lib.table_store.create_upload_session.assert_called_once() + kwargs = lib.table_store.create_upload_session.call_args[1] + assert kwargs["upload_id"] == resp.upload_id + assert kwargs["total_size"] == 5_000_000 + assert kwargs["total_chunks"] == resp.total_chunks + + @pytest.mark.asyncio + async def test_accepts_text_plain(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-id" + + req = _make_begin_request(kind="text/plain", total_size=1000) + resp = await lib.begin_upload(req) + assert resp.error is None + + +# --------------------------------------------------------------------------- +# upload_chunk +# --------------------------------------------------------------------------- + +class TestUploadChunk: + + @pytest.mark.asyncio + async def test_successful_chunk_upload(self): + lib = _make_librarian() + session = _make_session(total_chunks=5, chunks_received={}) + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "etag-1" + + req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data") + resp = await lib.upload_chunk(req) + + assert resp.error is None + assert resp.chunk_index == 0 + assert resp.total_chunks == 5 + # The chunk is added to the dict (len=1), then +1 applied => 2 + assert resp.chunks_received == 2 + + @pytest.mark.asyncio + async def test_s3_part_number_is_1_indexed(self): + lib = _make_librarian() + session = _make_session() + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "etag" + + req = _make_upload_chunk_request(chunk_index=0) + await lib.upload_chunk(req) + + kwargs = lib.blob_store.upload_part.call_args[1] + assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part + + @pytest.mark.asyncio + async def test_chunk_index_3_becomes_part_4(self): + lib = _make_librarian() + session = _make_session() + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "etag" + + req = _make_upload_chunk_request(chunk_index=3) + await lib.upload_chunk(req) + + kwargs = lib.blob_store.upload_part.call_args[1] + assert kwargs["part_number"] == 4 + + @pytest.mark.asyncio + async def test_rejects_expired_session(self): + lib = _make_librarian() + lib.table_store.get_upload_session.return_value = None + + req = _make_upload_chunk_request() + with pytest.raises(RequestError, match="not found"): + await lib.upload_chunk(req) + + @pytest.mark.asyncio + async def test_rejects_wrong_user(self): + lib = _make_librarian() + session = _make_session(user="alice") + lib.table_store.get_upload_session.return_value = session + + req = _make_upload_chunk_request(user="bob") + with pytest.raises(RequestError, match="Not authorized"): + await lib.upload_chunk(req) + + @pytest.mark.asyncio + async def test_rejects_negative_chunk_index(self): + lib = _make_librarian() + session = _make_session(total_chunks=5) + lib.table_store.get_upload_session.return_value = session + + req = _make_upload_chunk_request(chunk_index=-1) + with pytest.raises(RequestError, match="Invalid chunk index"): + await lib.upload_chunk(req) + + @pytest.mark.asyncio + async def test_rejects_out_of_range_chunk_index(self): + lib = _make_librarian() + session = _make_session(total_chunks=5) + lib.table_store.get_upload_session.return_value = session + + req = _make_upload_chunk_request(chunk_index=5) + with pytest.raises(RequestError, match="Invalid chunk index"): + await lib.upload_chunk(req) + + @pytest.mark.asyncio + async def test_progress_tracking(self): + lib = _make_librarian() + session = _make_session( + total_chunks=4, chunk_size=1000, total_size=3500, + chunks_received={0: "e1", 1: "e2"}, + ) + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "e3" + + req = _make_upload_chunk_request(chunk_index=2) + resp = await lib.upload_chunk(req) + + # Dict gets chunk 2 added (len=3), then +1 => 4 + assert resp.chunks_received == 4 + assert resp.total_chunks == 4 + assert resp.total_bytes == 3500 + + @pytest.mark.asyncio + async def test_bytes_capped_at_total_size(self): + """bytes_received should not exceed total_size for the final chunk.""" + lib = _make_librarian() + session = _make_session( + total_chunks=2, chunk_size=3000, total_size=5000, + chunks_received={0: "e1"}, + ) + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "e2" + + req = _make_upload_chunk_request(chunk_index=1) + resp = await lib.upload_chunk(req) + + # 3 chunks × 3000 = 9000 > 5000, so capped + assert resp.bytes_received <= 5000 + + @pytest.mark.asyncio + async def test_base64_decodes_content(self): + lib = _make_librarian() + session = _make_session() + lib.table_store.get_upload_session.return_value = session + lib.blob_store.upload_part.return_value = "etag" + + raw = b"hello world binary data" + req = _make_upload_chunk_request(content=raw) + await lib.upload_chunk(req) + + kwargs = lib.blob_store.upload_part.call_args[1] + assert kwargs["data"] == raw + + +# --------------------------------------------------------------------------- +# complete_upload +# --------------------------------------------------------------------------- + +class TestCompleteUpload: + + @pytest.mark.asyncio + async def test_successful_completion(self): + lib = _make_librarian() + session = _make_session( + total_chunks=3, + chunks_received={0: "e1", 1: "e2", 2: "e3"}, + ) + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + resp = await lib.complete_upload(req) + + assert resp.error is None + assert resp.document_id == "doc-1" + lib.blob_store.complete_multipart_upload.assert_called_once() + lib.table_store.add_document.assert_called_once() + lib.table_store.delete_upload_session.assert_called_once_with("up-1") + + @pytest.mark.asyncio + async def test_parts_sorted_by_index(self): + lib = _make_librarian() + # Chunks received out of order + session = _make_session( + total_chunks=3, + chunks_received={2: "e3", 0: "e1", 1: "e2"}, + ) + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + await lib.complete_upload(req) + + parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"] + part_numbers = [p[0] for p in parts] + assert part_numbers == [1, 2, 3] # Sorted, 1-indexed + + @pytest.mark.asyncio + async def test_rejects_missing_chunks(self): + lib = _make_librarian() + session = _make_session( + total_chunks=3, + chunks_received={0: "e1", 2: "e3"}, # chunk 1 missing + ) + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + with pytest.raises(RequestError, match="Missing chunks"): + await lib.complete_upload(req) + + @pytest.mark.asyncio + async def test_rejects_expired_session(self): + lib = _make_librarian() + lib.table_store.get_upload_session.return_value = None + + req = MagicMock() + req.upload_id = "up-gone" + req.user = "alice" + + with pytest.raises(RequestError, match="not found"): + await lib.complete_upload(req) + + @pytest.mark.asyncio + async def test_rejects_wrong_user(self): + lib = _make_librarian() + session = _make_session(user="alice") + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "bob" + + with pytest.raises(RequestError, match="Not authorized"): + await lib.complete_upload(req) + + +# --------------------------------------------------------------------------- +# abort_upload +# --------------------------------------------------------------------------- + +class TestAbortUpload: + + @pytest.mark.asyncio + async def test_aborts_and_cleans_up(self): + lib = _make_librarian() + session = _make_session() + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + resp = await lib.abort_upload(req) + + assert resp.error is None + lib.blob_store.abort_multipart_upload.assert_called_once_with( + object_id="obj-1", upload_id="s3-up-1" + ) + lib.table_store.delete_upload_session.assert_called_once_with("up-1") + + @pytest.mark.asyncio + async def test_rejects_expired_session(self): + lib = _make_librarian() + lib.table_store.get_upload_session.return_value = None + + req = MagicMock() + req.upload_id = "up-gone" + req.user = "alice" + + with pytest.raises(RequestError, match="not found"): + await lib.abort_upload(req) + + @pytest.mark.asyncio + async def test_rejects_wrong_user(self): + lib = _make_librarian() + session = _make_session(user="alice") + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "bob" + + with pytest.raises(RequestError, match="Not authorized"): + await lib.abort_upload(req) + + +# --------------------------------------------------------------------------- +# get_upload_status +# --------------------------------------------------------------------------- + +class TestGetUploadStatus: + + @pytest.mark.asyncio + async def test_in_progress_status(self): + lib = _make_librarian() + session = _make_session( + total_chunks=5, chunk_size=2000, total_size=10_000, + chunks_received={0: "e1", 2: "e3", 4: "e5"}, + ) + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + resp = await lib.get_upload_status(req) + + assert resp.upload_state == "in-progress" + assert resp.chunks_received == 3 + assert resp.total_chunks == 5 + assert sorted(resp.received_chunks) == [0, 2, 4] + assert sorted(resp.missing_chunks) == [1, 3] + assert resp.total_bytes == 10_000 + + @pytest.mark.asyncio + async def test_expired_session(self): + lib = _make_librarian() + lib.table_store.get_upload_session.return_value = None + + req = MagicMock() + req.upload_id = "up-expired" + req.user = "alice" + + resp = await lib.get_upload_status(req) + + assert resp.upload_state == "expired" + + @pytest.mark.asyncio + async def test_all_chunks_received(self): + lib = _make_librarian() + session = _make_session( + total_chunks=3, chunk_size=1000, total_size=2500, + chunks_received={0: "e1", 1: "e2", 2: "e3"}, + ) + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "alice" + + resp = await lib.get_upload_status(req) + + assert resp.missing_chunks == [] + assert resp.chunks_received == 3 + # 3 * 1000 = 3000 > 2500, so capped + assert resp.bytes_received <= 2500 + + @pytest.mark.asyncio + async def test_rejects_wrong_user(self): + lib = _make_librarian() + session = _make_session(user="alice") + lib.table_store.get_upload_session.return_value = session + + req = MagicMock() + req.upload_id = "up-1" + req.user = "bob" + + with pytest.raises(RequestError, match="Not authorized"): + await lib.get_upload_status(req) + + +# --------------------------------------------------------------------------- +# stream_document +# --------------------------------------------------------------------------- + +class TestStreamDocument: + + @pytest.mark.asyncio + async def test_streams_chunks_with_progress(self): + lib = _make_librarian() + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=5000) + lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000) + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 2000 + + chunks = [] + async for resp in lib.stream_document(req): + chunks.append(resp) + + assert len(chunks) == 3 # ceil(5000/2000) + assert chunks[0].chunk_index == 0 + assert chunks[0].total_chunks == 3 + assert chunks[0].is_final is False + assert chunks[-1].is_final is True + assert chunks[-1].chunk_index == 2 + + @pytest.mark.asyncio + async def test_single_chunk_document(self): + lib = _make_librarian() + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=500) + lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500) + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 2000 + + chunks = [] + async for resp in lib.stream_document(req): + chunks.append(resp) + + assert len(chunks) == 1 + assert chunks[0].is_final is True + assert chunks[0].bytes_received == 500 + assert chunks[0].total_bytes == 500 + + @pytest.mark.asyncio + async def test_byte_ranges_correct(self): + lib = _make_librarian() + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=5000) + lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100) + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 2000 + + chunks = [] + async for resp in lib.stream_document(req): + chunks.append(resp) + + # Verify the byte ranges passed to get_range + calls = lib.blob_store.get_range.call_args_list + assert calls[0][0] == ("obj-1", 0, 2000) + assert calls[1][0] == ("obj-1", 2000, 2000) + assert calls[2][0] == ("obj-1", 4000, 1000) # Last chunk: 5000-4000 + + @pytest.mark.asyncio + async def test_default_chunk_size(self): + lib = _make_librarian() + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=2_000_000) + lib.blob_store.get_range = AsyncMock(return_value=b"x") + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 0 # Should use default 1MB + + chunks = [] + async for resp in lib.stream_document(req): + chunks.append(resp) + + assert len(chunks) == 2 # ceil(2MB / 1MB) + + @pytest.mark.asyncio + async def test_content_is_base64_encoded(self): + lib = _make_librarian() + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=100) + raw = b"hello world" + lib.blob_store.get_range = AsyncMock(return_value=raw) + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 1000 + + chunks = [] + async for resp in lib.stream_document(req): + chunks.append(resp) + + assert chunks[0].content == base64.b64encode(raw) + + @pytest.mark.asyncio + async def test_rejects_chunk_below_minimum(self): + lib = _make_librarian(min_chunk_size=1024) + lib.table_store.get_document_object_id.return_value = "obj-1" + lib.blob_store.get_size = AsyncMock(return_value=5000) + + req = MagicMock() + req.user = "alice" + req.document_id = "doc-1" + req.chunk_size = 512 + + with pytest.raises(RequestError, match="below minimum"): + async for _ in lib.stream_document(req): + pass + + +# --------------------------------------------------------------------------- +# list_uploads +# --------------------------------------------------------------------------- + +class TestListUploads: + + @pytest.mark.asyncio + async def test_returns_sessions(self): + lib = _make_librarian() + lib.table_store.list_upload_sessions.return_value = [ + { + "upload_id": "up-1", + "document_id": "doc-1", + "document_metadata": '{"id":"doc-1"}', + "total_size": 10000, + "chunk_size": 2000, + "total_chunks": 5, + "chunks_received": {0: "e1", 1: "e2"}, + "created_at": "2024-01-01", + }, + ] + + req = MagicMock() + req.user = "alice" + + resp = await lib.list_uploads(req) + + assert resp.error is None + assert len(resp.upload_sessions) == 1 + assert resp.upload_sessions[0].upload_id == "up-1" + assert resp.upload_sessions[0].total_chunks == 5 + + @pytest.mark.asyncio + async def test_empty_uploads(self): + lib = _make_librarian() + lib.table_store.list_upload_sessions.return_value = [] + + req = MagicMock() + req.user = "alice" + + resp = await lib.list_uploads(req) + + assert resp.upload_sessions == [] diff --git a/tests/unit/test_provenance/__init__.py b/tests/unit/test_provenance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py new file mode 100644 index 00000000..4efe24c7 --- /dev/null +++ b/tests/unit/test_provenance/test_agent_provenance.py @@ -0,0 +1,336 @@ +""" +Tests for agent provenance triple builder functions. +""" + +import json +import pytest + +from trustgraph.schema import Triple, Term, IRI, LITERAL + +from trustgraph.provenance.agent import ( + agent_session_triples, + agent_iteration_triples, + agent_final_triples, +) + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_AGENT_QUESTION, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + for t in triples: + if (t.s.iri == subject and t.p.iri == RDF_TYPE + and t.o.type == IRI and t.o.iri == rdf_type): + return True + return False + + +# --------------------------------------------------------------------------- +# agent_session_triples +# --------------------------------------------------------------------------- + +class TestAgentSessionTriples: + + SESSION_URI = "urn:trustgraph:agent:test-session" + + def test_session_types(self): + triples = agent_session_triples( + self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z" + ) + assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY) + assert has_type(triples, self.SESSION_URI, TG_QUESTION) + assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION) + + def test_session_query_text(self): + triples = agent_session_triples( + self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z" + ) + query = find_triple(triples, TG_QUERY, self.SESSION_URI) + assert query is not None + assert query.o.value == "What is X?" + + def test_session_timestamp(self): + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-06-15T10:00:00Z" + ) + ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI) + assert ts is not None + assert ts.o.value == "2024-06-15T10:00:00Z" + + def test_session_default_timestamp(self): + triples = agent_session_triples(self.SESSION_URI, "Q") + ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI) + assert ts is not None + assert len(ts.o.value) > 0 + + def test_session_label(self): + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z" + ) + label = find_triple(triples, RDFS_LABEL, self.SESSION_URI) + assert label is not None + assert label.o.value == "Agent Question" + + def test_session_triple_count(self): + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z" + ) + assert len(triples) == 6 + + +# --------------------------------------------------------------------------- +# agent_iteration_triples +# --------------------------------------------------------------------------- + +class TestAgentIterationTriples: + + ITER_URI = "urn:trustgraph:agent:test-session/i1" + SESSION_URI = "urn:trustgraph:agent:test-session" + PREV_URI = "urn:trustgraph:agent:test-session/i0" + + def test_iteration_types(self): + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + ) + assert has_type(triples, self.ITER_URI, PROV_ENTITY) + assert has_type(triples, self.ITER_URI, TG_ANALYSIS) + + def test_first_iteration_generated_by_question(self): + """First iteration uses wasGeneratedBy to link to question activity.""" + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + ) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) + assert gen is not None + assert gen.o.iri == self.SESSION_URI + # Should NOT have wasDerivedFrom + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) + assert derived is None + + def test_subsequent_iteration_derived_from_previous(self): + """Subsequent iterations use wasDerivedFrom to link to previous iteration.""" + triples = agent_iteration_triples( + self.ITER_URI, previous_uri=self.PREV_URI, + action="search", + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) + assert derived is not None + assert derived.o.iri == self.PREV_URI + # Should NOT have wasGeneratedBy + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) + assert gen is None + + def test_iteration_label_includes_action(self): + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="graph-rag-query", + ) + label = find_triple(triples, RDFS_LABEL, self.ITER_URI) + assert label is not None + assert "graph-rag-query" in label.o.value + + def test_iteration_thought_sub_entity(self): + """Thought is a sub-entity with Reflection and Thought types.""" + thought_uri = "urn:trustgraph:agent:test-session/i1/thought" + thought_doc = "urn:doc:thought-1" + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + thought_uri=thought_uri, + thought_document_id=thought_doc, + ) + # Iteration links to thought sub-entity + thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI) + assert thought_link is not None + assert thought_link.o.iri == thought_uri + # Thought has correct types + assert has_type(triples, thought_uri, TG_REFLECTION_TYPE) + assert has_type(triples, thought_uri, TG_THOUGHT_TYPE) + # Thought was generated by iteration + gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri) + assert gen is not None + assert gen.o.iri == self.ITER_URI + # Thought has document reference + doc = find_triple(triples, TG_DOCUMENT, thought_uri) + assert doc is not None + assert doc.o.iri == thought_doc + + def test_iteration_observation_sub_entity(self): + """Observation is a sub-entity with Reflection and Observation types.""" + obs_uri = "urn:trustgraph:agent:test-session/i1/observation" + obs_doc = "urn:doc:obs-1" + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + observation_uri=obs_uri, + observation_document_id=obs_doc, + ) + # Iteration links to observation sub-entity + obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI) + assert obs_link is not None + assert obs_link.o.iri == obs_uri + # Observation has correct types + assert has_type(triples, obs_uri, TG_REFLECTION_TYPE) + assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE) + # Observation was generated by iteration + gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri) + assert gen is not None + assert gen.o.iri == self.ITER_URI + # Observation has document reference + doc = find_triple(triples, TG_DOCUMENT, obs_uri) + assert doc is not None + assert doc.o.iri == obs_doc + + def test_iteration_action_recorded(self): + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="graph-rag-query", + ) + action = find_triple(triples, TG_ACTION, self.ITER_URI) + assert action is not None + assert action.o.value == "graph-rag-query" + + def test_iteration_arguments_json_encoded(self): + args = {"query": "test query", "limit": 10} + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + arguments=args, + ) + arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI) + assert arguments is not None + parsed = json.loads(arguments.o.value) + assert parsed == args + + def test_iteration_default_arguments_empty_dict(self): + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + ) + arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI) + assert arguments is not None + parsed = json.loads(arguments.o.value) + assert parsed == {} + + def test_iteration_no_thought_or_observation(self): + """Minimal iteration with just action — no thought or observation triples.""" + triples = agent_iteration_triples( + self.ITER_URI, question_uri=self.SESSION_URI, + action="noop", + ) + thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) + obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI) + assert thought is None + assert obs is None + + def test_iteration_chaining(self): + """First iteration uses wasGeneratedBy, second uses wasDerivedFrom.""" + iter1_uri = "urn:trustgraph:agent:sess/i1" + iter2_uri = "urn:trustgraph:agent:sess/i2" + + triples1 = agent_iteration_triples( + iter1_uri, question_uri=self.SESSION_URI, action="step1", + ) + triples2 = agent_iteration_triples( + iter2_uri, previous_uri=iter1_uri, action="step2", + ) + + gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri) + assert gen1.o.iri == self.SESSION_URI + + derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri) + assert derived2.o.iri == iter1_uri + + +# --------------------------------------------------------------------------- +# agent_final_triples +# --------------------------------------------------------------------------- + +class TestAgentFinalTriples: + + FINAL_URI = "urn:trustgraph:agent:test-session/final" + PREV_URI = "urn:trustgraph:agent:test-session/i3" + SESSION_URI = "urn:trustgraph:agent:test-session" + + def test_final_types(self): + triples = agent_final_triples( + self.FINAL_URI, previous_uri=self.PREV_URI, + ) + assert has_type(triples, self.FINAL_URI, PROV_ENTITY) + assert has_type(triples, self.FINAL_URI, TG_CONCLUSION) + assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE) + + def test_final_derived_from_previous(self): + """Conclusion with iterations uses wasDerivedFrom.""" + triples = agent_final_triples( + self.FINAL_URI, previous_uri=self.PREV_URI, + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) + assert derived is not None + assert derived.o.iri == self.PREV_URI + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) + assert gen is None + + def test_final_generated_by_question_when_no_iterations(self): + """When agent answers immediately, final uses wasGeneratedBy.""" + triples = agent_final_triples( + self.FINAL_URI, question_uri=self.SESSION_URI, + ) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) + assert gen is not None + assert gen.o.iri == self.SESSION_URI + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) + assert derived is None + + def test_final_label(self): + triples = agent_final_triples( + self.FINAL_URI, previous_uri=self.PREV_URI, + ) + label = find_triple(triples, RDFS_LABEL, self.FINAL_URI) + assert label is not None + assert label.o.value == "Conclusion" + + def test_final_document_reference(self): + triples = agent_final_triples( + self.FINAL_URI, previous_uri=self.PREV_URI, + document_id="urn:trustgraph:agent:sess/answer", + ) + doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) + assert doc is not None + assert doc.o.type == IRI + assert doc.o.iri == "urn:trustgraph:agent:sess/answer" + + def test_final_no_document(self): + triples = agent_final_triples( + self.FINAL_URI, previous_uri=self.PREV_URI, + ) + doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) + assert doc is None diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py new file mode 100644 index 00000000..62498c61 --- /dev/null +++ b/tests/unit/test_provenance/test_explainability.py @@ -0,0 +1,543 @@ +""" +Tests for the explainability API (entity parsing, wire format conversion, +and ExplainabilityClient). +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.api.explainability import ( + EdgeSelection, + ExplainEntity, + Question, + Grounding, + Exploration, + Focus, + Synthesis, + Reflection, + Analysis, + Conclusion, + parse_edge_selection_triples, + extract_term_value, + wire_triples_to_tuples, + ExplainabilityClient, + TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY, + TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANALYSIS, TG_CONCLUSION, + TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, + PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + RDF_TYPE, RDFS_LABEL, +) + + +# --------------------------------------------------------------------------- +# Entity from_triples parsing +# --------------------------------------------------------------------------- + +class TestExplainEntityFromTriples: + """Test ExplainEntity.from_triples dispatches to correct subclass.""" + + def test_graphrag_question(self): + triples = [ + ("urn:q:1", RDF_TYPE, TG_QUESTION), + ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:q:1", TG_QUERY, "What is AI?"), + ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"), + ] + entity = ExplainEntity.from_triples("urn:q:1", triples) + assert isinstance(entity, Question) + assert entity.query == "What is AI?" + assert entity.timestamp == "2024-01-01T00:00:00Z" + assert entity.question_type == "graph-rag" + + def test_docrag_question(self): + triples = [ + ("urn:q:2", RDF_TYPE, TG_QUESTION), + ("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION), + ("urn:q:2", TG_QUERY, "Find info"), + ] + entity = ExplainEntity.from_triples("urn:q:2", triples) + assert isinstance(entity, Question) + assert entity.question_type == "document-rag" + + def test_agent_question(self): + triples = [ + ("urn:q:3", RDF_TYPE, TG_QUESTION), + ("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION), + ("urn:q:3", TG_QUERY, "Agent query"), + ] + entity = ExplainEntity.from_triples("urn:q:3", triples) + assert isinstance(entity, Question) + assert entity.question_type == "agent" + + def test_grounding(self): + triples = [ + ("urn:gnd:1", RDF_TYPE, TG_GROUNDING), + ("urn:gnd:1", TG_CONCEPT, "machine learning"), + ("urn:gnd:1", TG_CONCEPT, "neural networks"), + ] + entity = ExplainEntity.from_triples("urn:gnd:1", triples) + assert isinstance(entity, Grounding) + assert len(entity.concepts) == 2 + assert "machine learning" in entity.concepts + assert "neural networks" in entity.concepts + + def test_exploration(self): + triples = [ + ("urn:exp:1", RDF_TYPE, TG_EXPLORATION), + ("urn:exp:1", TG_EDGE_COUNT, "15"), + ] + entity = ExplainEntity.from_triples("urn:exp:1", triples) + assert isinstance(entity, Exploration) + assert entity.edge_count == 15 + + def test_exploration_with_chunk_count(self): + triples = [ + ("urn:exp:2", RDF_TYPE, TG_EXPLORATION), + ("urn:exp:2", TG_CHUNK_COUNT, "5"), + ] + entity = ExplainEntity.from_triples("urn:exp:2", triples) + assert isinstance(entity, Exploration) + assert entity.chunk_count == 5 + + def test_exploration_with_entities(self): + triples = [ + ("urn:exp:3", RDF_TYPE, TG_EXPLORATION), + ("urn:exp:3", TG_EDGE_COUNT, "10"), + ("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"), + ("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"), + ] + entity = ExplainEntity.from_triples("urn:exp:3", triples) + assert isinstance(entity, Exploration) + assert len(entity.entities) == 2 + + def test_exploration_invalid_count(self): + triples = [ + ("urn:exp:3", RDF_TYPE, TG_EXPLORATION), + ("urn:exp:3", TG_EDGE_COUNT, "not-a-number"), + ] + entity = ExplainEntity.from_triples("urn:exp:3", triples) + assert isinstance(entity, Exploration) + assert entity.edge_count == 0 + + def test_focus(self): + triples = [ + ("urn:foc:1", RDF_TYPE, TG_FOCUS), + ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"), + ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"), + ] + entity = ExplainEntity.from_triples("urn:foc:1", triples) + assert isinstance(entity, Focus) + assert len(entity.selected_edge_uris) == 2 + assert "urn:edge:1" in entity.selected_edge_uris + assert "urn:edge:2" in entity.selected_edge_uris + + def test_synthesis_with_document(self): + triples = [ + ("urn:syn:1", RDF_TYPE, TG_SYNTHESIS), + ("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"), + ] + entity = ExplainEntity.from_triples("urn:syn:1", triples) + assert isinstance(entity, Synthesis) + assert entity.document == "urn:doc:answer-1" + + def test_synthesis_no_document(self): + triples = [ + ("urn:syn:2", RDF_TYPE, TG_SYNTHESIS), + ] + entity = ExplainEntity.from_triples("urn:syn:2", triples) + assert isinstance(entity, Synthesis) + assert entity.document == "" + + def test_reflection_thought(self): + triples = [ + ("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE), + ("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE), + ("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"), + ] + entity = ExplainEntity.from_triples("urn:ref:1", triples) + assert isinstance(entity, Reflection) + assert entity.reflection_type == "thought" + assert entity.document == "urn:doc:thought-1" + + def test_reflection_observation(self): + triples = [ + ("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE), + ("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE), + ("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"), + ] + entity = ExplainEntity.from_triples("urn:ref:2", triples) + assert isinstance(entity, Reflection) + assert entity.reflection_type == "observation" + assert entity.document == "urn:doc:obs-1" + + def test_analysis(self): + triples = [ + ("urn:ana:1", RDF_TYPE, TG_ANALYSIS), + ("urn:ana:1", TG_ACTION, "graph-rag-query"), + ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'), + ("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"), + ("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"), + ] + entity = ExplainEntity.from_triples("urn:ana:1", triples) + assert isinstance(entity, Analysis) + assert entity.action == "graph-rag-query" + assert entity.arguments == '{"query": "test"}' + assert entity.thought == "urn:ref:thought-1" + assert entity.observation == "urn:ref:obs-1" + + def test_conclusion_with_document(self): + triples = [ + ("urn:conc:1", RDF_TYPE, TG_CONCLUSION), + ("urn:conc:1", TG_DOCUMENT, "urn:doc:final"), + ] + entity = ExplainEntity.from_triples("urn:conc:1", triples) + assert isinstance(entity, Conclusion) + assert entity.document == "urn:doc:final" + + def test_conclusion_no_document(self): + triples = [ + ("urn:conc:2", RDF_TYPE, TG_CONCLUSION), + ] + entity = ExplainEntity.from_triples("urn:conc:2", triples) + assert isinstance(entity, Conclusion) + assert entity.document == "" + + def test_unknown_type(self): + triples = [ + ("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"), + ] + entity = ExplainEntity.from_triples("urn:x:1", triples) + assert isinstance(entity, ExplainEntity) + assert entity.entity_type == "unknown" + + +# --------------------------------------------------------------------------- +# parse_edge_selection_triples +# --------------------------------------------------------------------------- + +class TestParseEdgeSelectionTriples: + + def test_with_edge_and_reasoning(self): + triples = [ + ("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}), + ("urn:edge:1", TG_REASONING, "Alice and Bob are connected"), + ] + result = parse_edge_selection_triples(triples) + assert isinstance(result, EdgeSelection) + assert result.uri == "urn:edge:1" + assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"} + assert result.reasoning == "Alice and Bob are connected" + + def test_with_edge_only(self): + triples = [ + ("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is not None + assert result.reasoning == "" + + def test_with_reasoning_only(self): + triples = [ + ("urn:edge:3", TG_REASONING, "some reason"), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is None + assert result.reasoning == "some reason" + + def test_empty_triples(self): + result = parse_edge_selection_triples([]) + assert result.uri == "" + assert result.edge is None + assert result.reasoning == "" + + def test_edge_must_be_dict(self): + """Non-dict values for TG_EDGE should not be treated as edges.""" + triples = [ + ("urn:edge:4", TG_EDGE, "not-a-dict"), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is None + + +# --------------------------------------------------------------------------- +# extract_term_value +# --------------------------------------------------------------------------- + +class TestExtractTermValue: + + def test_iri_short_format(self): + assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test" + + def test_iri_long_format(self): + assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test" + + def test_literal_short_format(self): + assert extract_term_value({"t": "l", "v": "hello"}) == "hello" + + def test_literal_long_format(self): + assert extract_term_value({"type": "l", "value": "hello"}) == "hello" + + def test_quoted_triple(self): + term = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "urn:s"}, + "p": {"t": "i", "i": "urn:p"}, + "o": {"t": "i", "i": "urn:o"}, + } + } + result = extract_term_value(term) + assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"} + + def test_quoted_triple_long_format(self): + term = { + "type": "t", + "triple": { + "s": {"type": "i", "iri": "urn:s"}, + "p": {"type": "i", "iri": "urn:p"}, + "o": {"type": "l", "value": "val"}, + } + } + result = extract_term_value(term) + assert result == {"s": "urn:s", "p": "urn:p", "o": "val"} + + def test_unknown_type_fallback(self): + result = extract_term_value({"t": "x", "i": "urn:fallback"}) + assert result == "urn:fallback" + + +# --------------------------------------------------------------------------- +# wire_triples_to_tuples +# --------------------------------------------------------------------------- + +class TestWireTriplesToTuples: + + def test_basic_conversion(self): + wire = [ + { + "s": {"t": "i", "i": "urn:s1"}, + "p": {"t": "i", "i": "urn:p1"}, + "o": {"t": "l", "v": "value1"}, + }, + ] + result = wire_triples_to_tuples(wire) + assert len(result) == 1 + assert result[0] == ("urn:s1", "urn:p1", "value1") + + def test_multiple_triples(self): + wire = [ + { + "s": {"t": "i", "i": "urn:s1"}, + "p": {"t": "i", "i": "urn:p1"}, + "o": {"t": "l", "v": "v1"}, + }, + { + "s": {"t": "i", "i": "urn:s2"}, + "p": {"t": "i", "i": "urn:p2"}, + "o": {"t": "i", "i": "urn:o2"}, + }, + ] + result = wire_triples_to_tuples(wire) + assert len(result) == 2 + assert result[0] == ("urn:s1", "urn:p1", "v1") + assert result[1] == ("urn:s2", "urn:p2", "urn:o2") + + def test_empty_list(self): + assert wire_triples_to_tuples([]) == [] + + def test_missing_fields(self): + wire = [{"s": {}, "p": {}, "o": {}}] + result = wire_triples_to_tuples(wire) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# ExplainabilityClient +# --------------------------------------------------------------------------- + +def _make_wire_triples(tuples): + """Convert (s, p, o) tuples to wire format for mocking.""" + result = [] + for s, p, o in tuples: + entry = { + "s": {"t": "i", "i": s}, + "p": {"t": "i", "i": p}, + } + if o.startswith("urn:") or o.startswith("http"): + entry["o"] = {"t": "i", "i": o} + else: + entry["o"] = {"t": "l", "v": o} + result.append(entry) + return result + + +class TestExplainabilityClientFetchEntity: + + def test_fetch_question_entity(self): + wire = _make_wire_triples([ + ("urn:q:1", RDF_TYPE, TG_QUESTION), + ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:q:1", TG_QUERY, "What is AI?"), + ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"), + ]) + + mock_flow = MagicMock() + # Return same results twice for quiescence + mock_flow.triples_query.side_effect = [wire, wire] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval") + + assert isinstance(entity, Question) + assert entity.query == "What is AI?" + assert entity.question_type == "graph-rag" + + def test_fetch_returns_none_when_no_data(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = [] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2) + entity = client.fetch_entity("urn:nonexistent") + + assert entity is None + + def test_fetch_retries_on_empty_results(self): + wire = _make_wire_triples([ + ("urn:q:1", RDF_TYPE, TG_QUESTION), + ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:q:1", TG_QUERY, "Q"), + ]) + + mock_flow = MagicMock() + # Empty, then data, then same data (stable) + mock_flow.triples_query.side_effect = [[], wire, wire] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + entity = client.fetch_entity("urn:q:1") + + assert isinstance(entity, Question) + assert mock_flow.triples_query.call_count == 3 + + +class TestExplainabilityClientResolveLabel: + + def test_resolve_label_found(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = _make_wire_triples([ + ("urn:entity:1", RDFS_LABEL, "Entity One"), + ]) + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + label = client.resolve_label("urn:entity:1") + assert label == "Entity One" + + def test_resolve_label_not_found(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = [] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + label = client.resolve_label("urn:entity:1") + assert label == "urn:entity:1" + + def test_resolve_label_cached(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = _make_wire_triples([ + ("urn:entity:1", RDFS_LABEL, "Entity One"), + ]) + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + client.resolve_label("urn:entity:1") + client.resolve_label("urn:entity:1") + + # Only one query should be made + assert mock_flow.triples_query.call_count == 1 + + def test_resolve_label_non_uri(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + assert client.resolve_label("plain text") == "plain text" + assert client.resolve_label("") == "" + mock_flow.triples_query.assert_not_called() + + def test_resolve_edge_labels(self): + mock_flow = MagicMock() + + def mock_query(s=None, p=None, **kwargs): + labels = { + "urn:e:Alice": "Alice", + "urn:r:knows": "knows", + "urn:e:Bob": "Bob", + } + if s in labels: + return _make_wire_triples([(s, RDFS_LABEL, labels[s])]) + return [] + + mock_flow.triples_query.side_effect = mock_query + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + s, p, o = client.resolve_edge_labels( + {"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"} + ) + assert s == "Alice" + assert p == "knows" + assert o == "Bob" + + +class TestExplainabilityClientContentFetching: + + def test_fetch_document_content_from_librarian(self): + mock_flow = MagicMock() + mock_api = MagicMock() + mock_library = MagicMock() + mock_api.library.return_value = mock_library + mock_library.get_document_content.return_value = b"librarian content" + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + result = client.fetch_document_content( + "urn:document:abc123", api=mock_api + ) + assert result == "librarian content" + + def test_fetch_document_content_truncated(self): + mock_flow = MagicMock() + mock_api = MagicMock() + mock_library = MagicMock() + mock_api.library.return_value = mock_library + mock_library.get_document_content.return_value = b"x" * 20000 + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + result = client.fetch_document_content( + "urn:doc:1", api=mock_api, max_content=100 + ) + assert len(result) < 20000 + assert result.endswith("... [truncated]") + + def test_fetch_document_content_empty_uri(self): + mock_flow = MagicMock() + mock_api = MagicMock() + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + result = client.fetch_document_content("", api=mock_api) + assert result == "" + + +class TestExplainabilityClientDetectSessionType: + + def test_detect_agent_from_uri(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent" + + def test_detect_graphrag_from_uri(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag" + + def test_detect_docrag_from_uri(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag" diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py new file mode 100644 index 00000000..9aff7e4b --- /dev/null +++ b/tests/unit/test_provenance/test_triples.py @@ -0,0 +1,812 @@ +""" +Tests for provenance triple builder functions (extraction-time and query-time). +""" + +import pytest +from unittest.mock import patch + +from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE + +from trustgraph.provenance.triples import ( + set_graph, + document_triples, + derived_entity_triples, + subgraph_provenance_triples, + question_triples, + grounding_triples, + exploration_triples, + focus_triples, + synthesis_triples, + docrag_question_triples, + docrag_exploration_triples, + docrag_synthesis_triples, +) + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER, + TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH, + TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, + TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS, + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, + TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANSWER_TYPE, + TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, + GRAPH_SOURCE, GRAPH_RETRIEVAL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + """Find first triple matching predicate (and optionally subject).""" + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + """Find all triples matching predicate (and optionally subject).""" + return [ + t for t in triples + if t.p.iri == predicate and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + """Check if subject has rdf:type rdf_type.""" + for t in triples: + if (t.s.iri == subject and t.p.iri == RDF_TYPE + and t.o.type == IRI and t.o.iri == rdf_type): + return True + return False + + +# --------------------------------------------------------------------------- +# set_graph +# --------------------------------------------------------------------------- + +class TestSetGraph: + + def test_sets_graph_on_all_triples(self): + triples = [ + Triple( + s=Term(type=IRI, iri="urn:s1"), + p=Term(type=IRI, iri="urn:p1"), + o=Term(type=LITERAL, value="v1"), + ), + Triple( + s=Term(type=IRI, iri="urn:s2"), + p=Term(type=IRI, iri="urn:p2"), + o=Term(type=LITERAL, value="v2"), + ), + ] + result = set_graph(triples, GRAPH_RETRIEVAL) + assert len(result) == 2 + for t in result: + assert t.g == GRAPH_RETRIEVAL + + def test_does_not_modify_originals(self): + original = Triple( + s=Term(type=IRI, iri="urn:s"), + p=Term(type=IRI, iri="urn:p"), + o=Term(type=LITERAL, value="v"), + ) + result = set_graph([original], "urn:graph:test") + assert original.g is None + assert result[0].g == "urn:graph:test" + + def test_empty_list(self): + result = set_graph([], GRAPH_SOURCE) + assert result == [] + + def test_preserves_spo(self): + original = Triple( + s=Term(type=IRI, iri="urn:s"), + p=Term(type=IRI, iri="urn:p"), + o=Term(type=LITERAL, value="hello"), + ) + result = set_graph([original], "urn:g")[0] + assert result.s.iri == "urn:s" + assert result.p.iri == "urn:p" + assert result.o.value == "hello" + + +# --------------------------------------------------------------------------- +# document_triples +# --------------------------------------------------------------------------- + +class TestDocumentTriples: + + DOC_URI = "https://example.com/doc/abc" + + def test_minimal_document(self): + triples = document_triples(self.DOC_URI) + assert has_type(triples, self.DOC_URI, PROV_ENTITY) + assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE) + assert len(triples) == 2 + + def test_with_title(self): + triples = document_triples(self.DOC_URI, title="My Doc") + title_t = find_triple(triples, DC_TITLE) + assert title_t is not None + assert title_t.o.value == "My Doc" + # Title also creates an rdfs:label + label_t = find_triple(triples, RDFS_LABEL) + assert label_t is not None + assert label_t.o.value == "My Doc" + + def test_with_source(self): + triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf") + source_t = find_triple(triples, DC_SOURCE) + assert source_t is not None + assert source_t.o.type == IRI + assert source_t.o.iri == "https://source.com/f.pdf" + + def test_with_date(self): + triples = document_triples(self.DOC_URI, date="2024-01-15") + date_t = find_triple(triples, DC_DATE) + assert date_t is not None + assert date_t.o.value == "2024-01-15" + + def test_with_creator(self): + triples = document_triples(self.DOC_URI, creator="Alice") + creator_t = find_triple(triples, DC_CREATOR) + assert creator_t is not None + assert creator_t.o.value == "Alice" + + def test_with_page_count(self): + triples = document_triples(self.DOC_URI, page_count=42) + pc_t = find_triple(triples, TG_PAGE_COUNT) + assert pc_t is not None + assert pc_t.o.value == "42" + + def test_with_page_count_zero(self): + triples = document_triples(self.DOC_URI, page_count=0) + pc_t = find_triple(triples, TG_PAGE_COUNT) + assert pc_t is not None + assert pc_t.o.value == "0" + + def test_with_mime_type(self): + triples = document_triples(self.DOC_URI, mime_type="application/pdf") + mt_t = find_triple(triples, TG_MIME_TYPE) + assert mt_t is not None + assert mt_t.o.value == "application/pdf" + + def test_all_metadata(self): + triples = document_triples( + self.DOC_URI, + title="Test", + source="https://s.com", + date="2024-01-01", + creator="Bob", + page_count=10, + mime_type="application/pdf", + ) + # 2 type triples + title + label + source + date + creator + page_count + mime_type + assert len(triples) == 9 + + def test_subject_is_doc_uri(self): + triples = document_triples(self.DOC_URI, title="T") + for t in triples: + assert t.s.iri == self.DOC_URI + + +# --------------------------------------------------------------------------- +# derived_entity_triples +# --------------------------------------------------------------------------- + +class TestDerivedEntityTriples: + + ENTITY_URI = "https://example.com/doc/abc/p1" + PARENT_URI = "https://example.com/doc/abc" + + def test_page_entity_has_page_type(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + page_number=1, + timestamp="2024-01-01T00:00:00Z", + ) + assert has_type(triples, self.ENTITY_URI, PROV_ENTITY) + assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE) + + def test_chunk_entity_has_chunk_type(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "chunker", "1.0", + chunk_index=0, + timestamp="2024-01-01T00:00:00Z", + ) + assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE) + + def test_no_specific_type_without_page_or_chunk(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "component", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + assert has_type(triples, self.ENTITY_URI, PROV_ENTITY) + assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE) + assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE) + + def test_was_derived_from_parent(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI) + assert derived is not None + assert derived.o.iri == self.PARENT_URI + + def test_activity_created(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + # Entity was generated by an activity + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI) + assert gen is not None + act_uri = gen.o.iri + + # Activity has correct type and metadata + assert has_type(triples, act_uri, PROV_ACTIVITY) + + # Activity used the parent + used = find_triple(triples, PROV_USED, act_uri) + assert used is not None + assert used.o.iri == self.PARENT_URI + + # Activity has component version + version = find_triple(triples, TG_COMPONENT_VERSION, act_uri) + assert version is not None + assert version.o.value == "1.0" + + def test_agent_created(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + # Find the agent URI via wasAssociatedWith + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI) + act_uri = gen.o.iri + assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri) + assert assoc is not None + agt_uri = assoc.o.iri + + assert has_type(triples, agt_uri, PROV_AGENT) + label = find_triple(triples, RDFS_LABEL, agt_uri) + assert label is not None + assert label.o.value == "pdf-extractor" + + def test_timestamp_recorded(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + timestamp="2024-06-15T12:30:00Z", + ) + ts = find_triple(triples, PROV_STARTED_AT_TIME) + assert ts is not None + assert ts.o.value == "2024-06-15T12:30:00Z" + + def test_default_timestamp_generated(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + ) + ts = find_triple(triples, PROV_STARTED_AT_TIME) + assert ts is not None + assert len(ts.o.value) > 0 + + def test_optional_label(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + label="Page 1", + timestamp="2024-01-01T00:00:00Z", + ) + label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI) + assert label is not None + assert label.o.value == "Page 1" + + def test_page_number_recorded(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "pdf-extractor", "1.0", + page_number=3, + timestamp="2024-01-01T00:00:00Z", + ) + pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI) + assert pn is not None + assert pn.o.value == "3" + + def test_chunk_metadata_recorded(self): + triples = derived_entity_triples( + self.ENTITY_URI, self.PARENT_URI, + "chunker", "2.0", + chunk_index=5, + char_offset=1000, + char_length=500, + chunk_size=512, + chunk_overlap=64, + timestamp="2024-01-01T00:00:00Z", + ) + ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI) + assert ci is not None and ci.o.value == "5" + + co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI) + assert co is not None and co.o.value == "1000" + + cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI) + assert cl is not None and cl.o.value == "500" + + # chunk_size and chunk_overlap are on the activity, not the entity + cs = find_triple(triples, TG_CHUNK_SIZE) + assert cs is not None and cs.o.value == "512" + + ov = find_triple(triples, TG_CHUNK_OVERLAP) + assert ov is not None and ov.o.value == "64" + + +# --------------------------------------------------------------------------- +# subgraph_provenance_triples +# --------------------------------------------------------------------------- + +class TestSubgraphProvenanceTriples: + + SG_URI = "https://trustgraph.ai/subgraph/test-sg" + CHUNK_URI = "https://example.com/doc/abc/p1/c0" + + def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"): + return Triple( + s=Term(type=IRI, iri=s), + p=Term(type=IRI, iri=p), + o=Term(type=IRI, iri=o), + ) + + def test_contains_quoted_triples(self): + extracted = [self._make_extracted_triple()] + triples = subgraph_provenance_triples( + self.SG_URI, extracted, self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + contains = find_triples(triples, TG_CONTAINS, self.SG_URI) + assert len(contains) == 1 + assert contains[0].o.type == TRIPLE + assert contains[0].o.triple.s.iri == "urn:e:Alice" + assert contains[0].o.triple.p.iri == "urn:r:knows" + assert contains[0].o.triple.o.iri == "urn:e:Bob" + + def test_multiple_extracted_triples(self): + extracted = [ + self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"), + self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"), + self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"), + ] + triples = subgraph_provenance_triples( + self.SG_URI, extracted, self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + contains = find_triples(triples, TG_CONTAINS, self.SG_URI) + assert len(contains) == 3 + + def test_empty_extracted_triples(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + contains = find_triples(triples, TG_CONTAINS, self.SG_URI) + assert len(contains) == 0 + # Should still have subgraph provenance metadata + assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE) + + def test_subgraph_has_correct_types(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + assert has_type(triples, self.SG_URI, PROV_ENTITY) + assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE) + + def test_derived_from_chunk(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI) + assert derived is not None + assert derived.o.iri == self.CHUNK_URI + + def test_activity_and_agent(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI) + assert gen is not None + act_uri = gen.o.iri + + assert has_type(triples, act_uri, PROV_ACTIVITY) + + used = find_triple(triples, PROV_USED, act_uri) + assert used is not None + assert used.o.iri == self.CHUNK_URI + + version = find_triple(triples, TG_COMPONENT_VERSION, act_uri) + assert version is not None + assert version.o.value == "1.0" + + def test_optional_llm_model(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + llm_model="claude-3-opus", + timestamp="2024-01-01T00:00:00Z", + ) + llm = find_triple(triples, TG_LLM_MODEL) + assert llm is not None + assert llm.o.value == "claude-3-opus" + + def test_no_llm_model_when_omitted(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + timestamp="2024-01-01T00:00:00Z", + ) + llm = find_triple(triples, TG_LLM_MODEL) + assert llm is None + + def test_optional_ontology(self): + triples = subgraph_provenance_triples( + self.SG_URI, [], self.CHUNK_URI, + "kg-extractor", "1.0", + ontology_uri="https://example.com/ontology/v1", + timestamp="2024-01-01T00:00:00Z", + ) + ont = find_triple(triples, TG_ONTOLOGY) + assert ont is not None + assert ont.o.type == IRI + assert ont.o.iri == "https://example.com/ontology/v1" + + +# --------------------------------------------------------------------------- +# GraphRAG query-time triples +# --------------------------------------------------------------------------- + +class TestQuestionTriples: + + Q_URI = "urn:trustgraph:question:test-session" + + def test_question_types(self): + triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z") + assert has_type(triples, self.Q_URI, PROV_ACTIVITY) + assert has_type(triples, self.Q_URI, TG_QUESTION) + assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION) + + def test_question_query_text(self): + triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z") + query = find_triple(triples, TG_QUERY, self.Q_URI) + assert query is not None + assert query.o.value == "What is AI?" + + def test_question_timestamp(self): + triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z") + ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI) + assert ts is not None + assert ts.o.value == "2024-06-15T10:00:00Z" + + def test_question_default_timestamp(self): + triples = question_triples(self.Q_URI, "Q") + ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI) + assert ts is not None + assert len(ts.o.value) > 0 + + def test_question_label(self): + triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z") + label = find_triple(triples, RDFS_LABEL, self.Q_URI) + assert label is not None + assert label.o.value == "GraphRAG Question" + + def test_question_triple_count(self): + triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z") + assert len(triples) == 6 + + +class TestGroundingTriples: + + GND_URI = "urn:trustgraph:prov:grounding:test-session" + Q_URI = "urn:trustgraph:question:test-session" + + def test_grounding_types(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"]) + assert has_type(triples, self.GND_URI, PROV_ENTITY) + assert has_type(triples, self.GND_URI, TG_GROUNDING) + + def test_grounding_generated_by_question(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"]) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI) + assert gen is not None + assert gen.o.iri == self.Q_URI + + def test_grounding_concepts(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"]) + concepts = find_triples(triples, TG_CONCEPT, self.GND_URI) + assert len(concepts) == 3 + values = {t.o.value for t in concepts} + assert values == {"AI", "ML", "robots"} + + def test_grounding_empty_concepts(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, []) + concepts = find_triples(triples, TG_CONCEPT, self.GND_URI) + assert len(concepts) == 0 + + def test_grounding_label(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, []) + label = find_triple(triples, RDFS_LABEL, self.GND_URI) + assert label is not None + assert label.o.value == "Grounding" + + +class TestExplorationTriples: + + EXP_URI = "urn:trustgraph:prov:exploration:test-session" + GND_URI = "urn:trustgraph:prov:grounding:test-session" + + def test_exploration_types(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) + assert has_type(triples, self.EXP_URI, PROV_ENTITY) + assert has_type(triples, self.EXP_URI, TG_EXPLORATION) + + def test_exploration_derived_from_grounding(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI) + assert derived is not None + assert derived.o.iri == self.GND_URI + + def test_exploration_edge_count(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) + ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) + assert ec is not None + assert ec.o.value == "15" + + def test_exploration_zero_edges(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 0) + ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) + assert ec is not None + assert ec.o.value == "0" + + def test_exploration_with_entities(self): + entities = ["urn:e:machine-learning", "urn:e:neural-networks"] + triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities) + ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI) + assert len(ent_triples) == 2 + + def test_exploration_triple_count(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 10) + assert len(triples) == 5 + + +class TestFocusTriples: + + FOC_URI = "urn:trustgraph:prov:focus:test-session" + EXP_URI = "urn:trustgraph:prov:exploration:test-session" + SESSION_ID = "test-session" + + def test_focus_types(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + assert has_type(triples, self.FOC_URI, PROV_ENTITY) + assert has_type(triples, self.FOC_URI, TG_FOCUS) + + def test_focus_derived_from_exploration(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI) + assert derived is not None + assert derived.o.iri == self.EXP_URI + + def test_focus_no_edges(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + selected = find_triples(triples, TG_SELECTED_EDGE) + assert len(selected) == 0 + + def test_focus_with_edges_and_reasoning(self): + edges = [ + { + "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"), + "reasoning": "Alice is connected to Bob", + }, + { + "edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"), + "reasoning": "Bob works at Acme", + }, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + + # Two selectedEdge links + selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI) + assert len(selected) == 2 + + # Each edge selection has a quoted triple + edge_triples = find_triples(triples, TG_EDGE) + assert len(edge_triples) == 2 + for et in edge_triples: + assert et.o.type == TRIPLE + + # Each edge selection has reasoning + reasoning_triples = find_triples(triples, TG_REASONING) + assert len(reasoning_triples) == 2 + + def test_focus_edge_without_reasoning(self): + edges = [ + {"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""}, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + reasoning = find_triples(triples, TG_REASONING) + assert len(reasoning) == 0 + + def test_focus_edge_without_edge_data(self): + edges = [ + {"edge": None, "reasoning": "some reasoning"}, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + selected = find_triples(triples, TG_SELECTED_EDGE) + assert len(selected) == 0 + + def test_focus_quoted_triple_content(self): + edges = [ + { + "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"), + "reasoning": "test", + }, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + edge_t = find_triple(triples, TG_EDGE) + qt = edge_t.o.triple + assert qt.s.iri == "urn:e:Alice" + assert qt.p.iri == "urn:r:knows" + assert qt.o.iri == "urn:e:Bob" + + +class TestSynthesisTriples: + + SYN_URI = "urn:trustgraph:prov:synthesis:test-session" + FOC_URI = "urn:trustgraph:prov:focus:test-session" + + def test_synthesis_types(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + assert has_type(triples, self.SYN_URI, PROV_ENTITY) + assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) + assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE) + + def test_synthesis_derived_from_focus(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) + assert derived is not None + assert derived.o.iri == self.FOC_URI + + def test_synthesis_with_document_reference(self): + triples = synthesis_triples( + self.SYN_URI, self.FOC_URI, + document_id="urn:trustgraph:question:abc/answer", + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is not None + assert doc.o.type == IRI + assert doc.o.iri == "urn:trustgraph:question:abc/answer" + + def test_synthesis_no_document(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is None + + +# --------------------------------------------------------------------------- +# DocumentRAG query-time triples +# --------------------------------------------------------------------------- + +class TestDocRagQuestionTriples: + + Q_URI = "urn:trustgraph:docrag:test-session" + + def test_docrag_question_types(self): + triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z") + assert has_type(triples, self.Q_URI, PROV_ACTIVITY) + assert has_type(triples, self.Q_URI, TG_QUESTION) + assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION) + + def test_docrag_question_label(self): + triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z") + label = find_triple(triples, RDFS_LABEL, self.Q_URI) + assert label.o.value == "DocumentRAG Question" + + def test_docrag_question_query_text(self): + triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z") + query = find_triple(triples, TG_QUERY, self.Q_URI) + assert query.o.value == "search query" + + +class TestDocRagExplorationTriples: + + EXP_URI = "urn:trustgraph:docrag:test/exploration" + GND_URI = "urn:trustgraph:docrag:test/grounding" + + def test_docrag_exploration_types(self): + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5) + assert has_type(triples, self.EXP_URI, PROV_ENTITY) + assert has_type(triples, self.EXP_URI, TG_EXPLORATION) + + def test_docrag_exploration_derived_from_grounding(self): + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI) + assert derived.o.iri == self.GND_URI + + def test_docrag_exploration_chunk_count(self): + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7) + cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI) + assert cc.o.value == "7" + + def test_docrag_exploration_without_chunk_ids(self): + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3) + chunks = find_triples(triples, TG_SELECTED_CHUNK) + assert len(chunks) == 0 + + def test_docrag_exploration_with_chunk_ids(self): + chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"] + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids) + chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI) + assert len(chunks) == 3 + chunk_uris = {t.o.iri for t in chunks} + assert chunk_uris == set(chunk_ids) + + +class TestDocRagSynthesisTriples: + + SYN_URI = "urn:trustgraph:docrag:test/synthesis" + EXP_URI = "urn:trustgraph:docrag:test/exploration" + + def test_docrag_synthesis_types(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + assert has_type(triples, self.SYN_URI, PROV_ENTITY) + assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) + + def test_docrag_synthesis_derived_from_exploration(self): + """DocRAG skips the focus step — synthesis derives from exploration.""" + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) + assert derived.o.iri == self.EXP_URI + + def test_docrag_synthesis_has_answer_type(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE) + + def test_docrag_synthesis_with_document(self): + triples = docrag_synthesis_triples( + self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans" + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc.o.iri == "urn:doc:ans" + + def test_docrag_synthesis_no_document(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is None diff --git a/tests/unit/test_provenance/test_uris.py b/tests/unit/test_provenance/test_uris.py new file mode 100644 index 00000000..0e69734c --- /dev/null +++ b/tests/unit/test_provenance/test_uris.py @@ -0,0 +1,292 @@ +""" +Tests for provenance URI generation functions. +""" + +import pytest +from unittest.mock import patch + +from trustgraph.provenance.uris import ( + TRUSTGRAPH_BASE, + _encode_id, + document_uri, + page_uri, + chunk_uri_from_page, + chunk_uri_from_doc, + activity_uri, + subgraph_uri, + agent_uri, + question_uri, + exploration_uri, + focus_uri, + synthesis_uri, + edge_selection_uri, + agent_session_uri, + agent_iteration_uri, + agent_final_uri, + docrag_question_uri, + docrag_exploration_uri, + docrag_synthesis_uri, +) + + +class TestEncodeId: + """Tests for the _encode_id helper.""" + + def test_plain_string(self): + assert _encode_id("abc123") == "abc123" + + def test_string_with_spaces(self): + assert _encode_id("hello world") == "hello%20world" + + def test_string_with_slashes(self): + assert _encode_id("a/b/c") == "a%2Fb%2Fc" + + def test_integer_input(self): + assert _encode_id(42) == "42" + + def test_empty_string(self): + assert _encode_id("") == "" + + def test_special_characters(self): + result = _encode_id("name@domain.com") + assert "@" not in result or result == "name%40domain.com" + + +class TestDocumentUris: + """Tests for document, page, and chunk URI generation.""" + + def test_document_uri_passthrough(self): + iri = "https://example.com/doc/123" + assert document_uri(iri) == iri + + def test_page_uri_format(self): + result = page_uri("https://example.com/doc/123", 5) + assert result == "https://example.com/doc/123/p5" + + def test_page_uri_page_zero(self): + result = page_uri("https://example.com/doc/123", 0) + assert result == "https://example.com/doc/123/p0" + + def test_chunk_uri_from_page_format(self): + result = chunk_uri_from_page("https://example.com/doc/123", 2, 3) + assert result == "https://example.com/doc/123/p2/c3" + + def test_chunk_uri_from_doc_format(self): + result = chunk_uri_from_doc("https://example.com/doc/123", 7) + assert result == "https://example.com/doc/123/c7" + + def test_page_uri_preserves_doc_iri(self): + doc = "urn:isbn:978-3-16-148410-0" + result = page_uri(doc, 1) + assert result.startswith(doc) + + def test_chunk_from_page_hierarchy(self): + """Chunk URI should contain both page and chunk identifiers.""" + result = chunk_uri_from_page("https://example.com/doc", 3, 5) + assert "/p3/" in result + assert result.endswith("/c5") + + +class TestActivityAndSubgraphUris: + """Tests for activity_uri, subgraph_uri, and agent_uri.""" + + def test_activity_uri_with_id(self): + result = activity_uri("my-activity-id") + assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id" + + def test_activity_uri_auto_generates_uuid(self): + result = activity_uri() + assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/") + # UUID part should be non-empty + uuid_part = result.split("/activity/")[1] + assert len(uuid_part) > 0 + + def test_activity_uri_unique_uuids(self): + r1 = activity_uri() + r2 = activity_uri() + assert r1 != r2 + + def test_activity_uri_encodes_special_chars(self): + result = activity_uri("id with spaces") + assert "id%20with%20spaces" in result + + def test_subgraph_uri_with_id(self): + result = subgraph_uri("sg-123") + assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123" + + def test_subgraph_uri_auto_generates_uuid(self): + result = subgraph_uri() + assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/") + uuid_part = result.split("/subgraph/")[1] + assert len(uuid_part) > 0 + + def test_subgraph_uri_unique_uuids(self): + r1 = subgraph_uri() + r2 = subgraph_uri() + assert r1 != r2 + + def test_agent_uri_format(self): + result = agent_uri("pdf-extractor") + assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor" + + def test_agent_uri_encodes_special_chars(self): + result = agent_uri("my component") + assert "my%20component" in result + + +class TestGraphRagQueryUris: + """Tests for GraphRAG query-time provenance URIs.""" + + FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000" + + def test_question_uri_with_session_id(self): + result = question_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:question:{self.FIXED_UUID}" + + def test_question_uri_auto_generates(self): + result = question_uri() + assert result.startswith("urn:trustgraph:question:") + uuid_part = result.split("urn:trustgraph:question:")[1] + assert len(uuid_part) > 0 + + def test_question_uri_unique(self): + r1 = question_uri() + r2 = question_uri() + assert r1 != r2 + + def test_exploration_uri_format(self): + result = exploration_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}" + + def test_focus_uri_format(self): + result = focus_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}" + + def test_synthesis_uri_format(self): + result = synthesis_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}" + + def test_edge_selection_uri_format(self): + result = edge_selection_uri(self.FIXED_UUID, 3) + assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3" + + def test_edge_selection_uri_zero_index(self): + result = edge_selection_uri(self.FIXED_UUID, 0) + assert result.endswith(":0") + + def test_session_uris_share_session_id(self): + """All URIs for a session should contain the same session ID.""" + sid = self.FIXED_UUID + q = question_uri(sid) + e = exploration_uri(sid) + f = focus_uri(sid) + s = synthesis_uri(sid) + for uri in [q, e, f, s]: + assert sid in uri + + +class TestAgentProvenanceUris: + """Tests for agent provenance URIs.""" + + FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000" + + def test_agent_session_uri_with_id(self): + result = agent_session_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}" + + def test_agent_session_uri_auto_generates(self): + result = agent_session_uri() + assert result.startswith("urn:trustgraph:agent:") + + def test_agent_session_uri_unique(self): + r1 = agent_session_uri() + r2 = agent_session_uri() + assert r1 != r2 + + def test_agent_iteration_uri_format(self): + result = agent_iteration_uri(self.FIXED_UUID, 1) + assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1" + + def test_agent_iteration_uri_numbering(self): + r1 = agent_iteration_uri(self.FIXED_UUID, 1) + r2 = agent_iteration_uri(self.FIXED_UUID, 2) + assert r1 != r2 + assert r1.endswith("/i1") + assert r2.endswith("/i2") + + def test_agent_final_uri_format(self): + result = agent_final_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final" + + def test_agent_uris_share_session_id(self): + sid = self.FIXED_UUID + session = agent_session_uri(sid) + iteration = agent_iteration_uri(sid, 1) + final = agent_final_uri(sid) + for uri in [session, iteration, final]: + assert sid in uri + + +class TestDocRagProvenanceUris: + """Tests for Document RAG provenance URIs.""" + + FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000" + + def test_docrag_question_uri_with_id(self): + result = docrag_question_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}" + + def test_docrag_question_uri_auto_generates(self): + result = docrag_question_uri() + assert result.startswith("urn:trustgraph:docrag:") + + def test_docrag_question_uri_unique(self): + r1 = docrag_question_uri() + r2 = docrag_question_uri() + assert r1 != r2 + + def test_docrag_exploration_uri_format(self): + result = docrag_exploration_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration" + + def test_docrag_synthesis_uri_format(self): + result = docrag_synthesis_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis" + + def test_docrag_uris_share_session_id(self): + sid = self.FIXED_UUID + q = docrag_question_uri(sid) + e = docrag_exploration_uri(sid) + s = docrag_synthesis_uri(sid) + for uri in [q, e, s]: + assert sid in uri + + +class TestUriNamespaceIsolation: + """Verify that different provenance types use distinct URI namespaces.""" + + FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000" + + def test_graphrag_vs_agent_namespace(self): + graphrag = question_uri(self.FIXED_UUID) + agent = agent_session_uri(self.FIXED_UUID) + assert graphrag != agent + assert "question" in graphrag + assert "agent" in agent + + def test_graphrag_vs_docrag_namespace(self): + graphrag = question_uri(self.FIXED_UUID) + docrag = docrag_question_uri(self.FIXED_UUID) + assert graphrag != docrag + + def test_agent_vs_docrag_namespace(self): + agent = agent_session_uri(self.FIXED_UUID) + docrag = docrag_question_uri(self.FIXED_UUID) + assert agent != docrag + + def test_extraction_vs_query_namespace(self): + """Extraction URIs use https://, query URIs use urn:.""" + ext = activity_uri(self.FIXED_UUID) + query = question_uri(self.FIXED_UUID) + assert ext.startswith("https://") + assert query.startswith("urn:") diff --git a/tests/unit/test_provenance/test_vocabulary.py b/tests/unit/test_provenance/test_vocabulary.py new file mode 100644 index 00000000..a3c644e8 --- /dev/null +++ b/tests/unit/test_provenance/test_vocabulary.py @@ -0,0 +1,124 @@ +""" +Tests for provenance vocabulary bootstrap. +""" + +import pytest + +from trustgraph.schema import Triple, Term, IRI, LITERAL + +from trustgraph.provenance.vocabulary import ( + get_vocabulary_triples, + PROV_CLASS_LABELS, + PROV_PREDICATE_LABELS, + DC_PREDICATE_LABELS, + SCHEMA_LABELS, + SKOS_LABELS, + TG_CLASS_LABELS, + TG_PREDICATE_LABELS, +) + +from trustgraph.provenance.namespaces import ( + RDFS_LABEL, + PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, +) + + +class TestVocabularyTriples: + """Tests for the vocabulary bootstrap function.""" + + def test_returns_list_of_triples(self): + result = get_vocabulary_triples() + assert isinstance(result, list) + assert len(result) > 0 + for t in result: + assert isinstance(t, Triple) + + def test_all_triples_are_label_triples(self): + """Every vocabulary triple should use rdfs:label as predicate.""" + for t in get_vocabulary_triples(): + assert t.p.type == IRI + assert t.p.iri == RDFS_LABEL + + def test_all_subjects_are_iris(self): + for t in get_vocabulary_triples(): + assert t.s.type == IRI + assert len(t.s.iri) > 0 + + def test_all_objects_are_literals(self): + for t in get_vocabulary_triples(): + assert t.o.type == LITERAL + assert len(t.o.value) > 0 + + def test_no_duplicate_subjects(self): + subjects = [t.s.iri for t in get_vocabulary_triples()] + assert len(subjects) == len(set(subjects)) + + def test_includes_prov_classes(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert PROV_ENTITY in subjects + assert PROV_ACTIVITY in subjects + assert PROV_AGENT in subjects + + def test_includes_prov_predicates(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert PROV_WAS_DERIVED_FROM in subjects + assert PROV_WAS_GENERATED_BY in subjects + assert PROV_USED in subjects + assert PROV_WAS_ASSOCIATED_WITH in subjects + assert PROV_STARTED_AT_TIME in subjects + + def test_includes_dc_predicates(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert DC_TITLE in subjects + assert DC_SOURCE in subjects + assert DC_DATE in subjects + assert DC_CREATOR in subjects + + def test_includes_tg_classes(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert TG_DOCUMENT_TYPE in subjects + assert TG_PAGE_TYPE in subjects + assert TG_CHUNK_TYPE in subjects + assert TG_SUBGRAPH_TYPE in subjects + + def test_component_lists_sum_to_total(self): + total = get_vocabulary_triples() + components = ( + PROV_CLASS_LABELS + + PROV_PREDICATE_LABELS + + DC_PREDICATE_LABELS + + SCHEMA_LABELS + + SKOS_LABELS + + TG_CLASS_LABELS + + TG_PREDICATE_LABELS + ) + assert len(total) == len(components) + + def test_idempotent(self): + """Calling twice should return equivalent triples.""" + r1 = get_vocabulary_triples() + r2 = get_vocabulary_triples() + assert len(r1) == len(r2) + for t1, t2 in zip(r1, r2): + assert t1.s.iri == t2.s.iri + assert t1.o.value == t2.o.value + + +class TestNamespaceConstants: + """Verify namespace constants are well-formed IRIs.""" + + def test_prov_namespace_prefix(self): + assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#") + + def test_dc_namespace_prefix(self): + assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/") + + def test_tg_namespace_prefix(self): + assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/") + + def test_rdfs_label_iri(self): + assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label" diff --git a/tests/unit/test_query/conftest.py b/tests/unit/test_query/conftest.py index af707d88..8467b7d1 100644 --- a/tests/unit/test_query/conftest.py +++ b/tests/unit/test_query/conftest.py @@ -37,7 +37,7 @@ def mock_qdrant_client(): def mock_graph_embeddings_request(): """Mock graph embeddings request message""" mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -46,9 +46,9 @@ def mock_graph_embeddings_request(): @pytest.fixture def mock_graph_embeddings_multiple_vectors(): - """Mock graph embeddings request with multiple vectors""" + """Mock graph embeddings request with multiple vectors (legacy name, now single vector)""" mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2, 0.3, 0.4] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' @@ -82,7 +82,7 @@ def mock_graph_embeddings_uri_response(): def mock_document_embeddings_request(): """Mock document embeddings request message""" mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -91,9 +91,9 @@ def mock_document_embeddings_request(): @pytest.fixture def mock_document_embeddings_multiple_vectors(): - """Mock document embeddings request with multiple vectors""" + """Mock document embeddings request with multiple vectors (legacy name, now single vector)""" mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2, 0.3, 0.4] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' @@ -139,9 +139,9 @@ def mock_large_query_response(): @pytest.fixture def mock_mixed_dimension_vectors(): - """Mock request with vectors of different dimensions""" + """Mock request with vector (legacy name suggested mixed dimensions, now single vector)""" mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 622529e5..1cddce97 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.doc_embeddings.milvus.service import Processor -from trustgraph.schema import DocumentEmbeddingsRequest +from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch class TestMilvusDocEmbeddingsQueryProcessor: @@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -71,15 +71,15 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) # Mock search results mock_results = [ - {"entity": {"doc": "First document chunk"}}, - {"entity": {"doc": "Second document chunk"}}, - {"entity": {"doc": "Third document chunk"}}, + {"entity": {"chunk_id": "First document chunk"}}, + {"entity": {"chunk_id": "Second document chunk"}}, + {"entity": {"chunk_id": "Third document chunk"}}, ] processor.vecstore.search.return_value = mock_results @@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 ) - # Verify results are document chunks + # Verify results are ChunkMatch objects assert len(result) == 3 - assert result[0] == "First document chunk" - assert result[1] == "Second document chunk" - assert result[2] == "Third document chunk" + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == "First document chunk" + assert result[1].chunk_id == "Second document chunk" + assert result[2].chunk_id == "Third document chunk" @pytest.mark.asyncio - async def test_query_document_embeddings_multiple_vectors(self, processor): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_longer_vector(self, processor): + """Test querying document embeddings with a longer vector""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 ) - - # Mock search results - different results for each vector - mock_results_1 = [ - {"entity": {"doc": "Document from first vector"}}, - {"entity": {"doc": "Another doc from first vector"}}, + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "First document"}}, + {"entity": {"chunk_id": "Second document"}}, + {"entity": {"chunk_id": "Third document"}}, ] - mock_results_2 = [ - {"entity": {"doc": "Document from second vector"}}, - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results from all vectors are combined + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3 + ) + + # Verify results are ChunkMatch objects assert len(result) == 3 - assert "Document from first vector" in result - assert "Another doc from first vector" in result - assert "Document from second vector" in result + chunk_ids = [r.chunk_id for r in result] + assert "First document" in chunk_ids + assert "Second document" in chunk_ids + assert "Third document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_with_limit(self, processor): @@ -141,16 +135,16 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) # Mock search results - more results than limit mock_results = [ - {"entity": {"doc": "Document 1"}}, - {"entity": {"doc": "Document 2"}}, - {"entity": {"doc": "Document 3"}}, - {"entity": {"doc": "Document 4"}}, + {"entity": {"chunk_id": "Document 1"}}, + {"entity": {"chunk_id": "Document 2"}}, + {"entity": {"chunk_id": "Document 3"}}, + {"entity": {"chunk_id": "Document 4"}}, ] processor.vecstore.search.return_value = mock_results @@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -211,25 +205,26 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) # Mock search results with Unicode content mock_results = [ - {"entity": {"doc": "Document with Unicode: éñ中文🚀"}}, - {"entity": {"doc": "Regular ASCII document"}}, - {"entity": {"doc": "Document with émojis: 😀🎉"}}, + {"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}}, + {"entity": {"chunk_id": "Regular ASCII document"}}, + {"entity": {"chunk_id": "Document with émojis: 😀🎉"}}, ] processor.vecstore.search.return_value = mock_results result = await processor.query_document_embeddings(query) - # Verify Unicode content is preserved + # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with Unicode: éñ中文🚀" in result - assert "Regular ASCII document" in result - assert "Document with émojis: 😀🎉" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with Unicode: éñ中文🚀" in chunk_ids + assert "Regular ASCII document" in chunk_ids + assert "Document with émojis: 😀🎉" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_large_documents(self, processor): @@ -237,24 +232,25 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) # Mock search results with large content large_doc = "A" * 10000 # 10KB of content mock_results = [ - {"entity": {"doc": large_doc}}, - {"entity": {"doc": "Small document"}}, + {"entity": {"chunk_id": large_doc}}, + {"entity": {"chunk_id": "Small document"}}, ] processor.vecstore.search.return_value = mock_results result = await processor.query_document_embeddings(query) - # Verify large content is preserved + # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 - assert large_doc in result - assert "Small document" in result + chunk_ids = [r.chunk_id for r in result] + assert large_doc in chunk_ids + assert "Small document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_special_characters(self, processor): @@ -262,25 +258,26 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) # Mock search results with special characters mock_results = [ - {"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}}, - {"entity": {"doc": "Document with\nnewlines\tand\ttabs"}}, - {"entity": {"doc": "Document with special chars: @#$%^&*()"}}, + {"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}}, + {"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}}, + {"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}}, ] processor.vecstore.search.return_value = mock_results result = await processor.query_document_embeddings(query) - # Verify special characters are preserved + # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with \"quotes\" and 'apostrophes'" in result - assert "Document with\nnewlines\tand\ttabs" in result - assert "Document with special chars: @#$%^&*()" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids + assert "Document with\nnewlines\tand\ttabs" in chunk_ids + assert "Document with special chars: @#$%^&*()" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_zero_limit(self, processor): @@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=-1 ) @@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}] - mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}] - mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "Document 1"}}, + {"entity": {"chunk_id": "Document 2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 - assert "Document from 2D vector" in result - assert "Document from 4D vector" in result - assert "Document from 3D vector" in result + + # Verify search was called with the vector + processor.vecstore.search.assert_called_once() + + # Verify results are ChunkMatch objects + assert len(result) == 2 + chunk_ids = [r.chunk_id for r in result] + assert "Document 1" in chunk_ids + assert "Document 2" in chunk_ids @pytest.mark.asyncio - async def test_query_document_embeddings_duplicate_documents(self, processor): - """Test querying document embeddings with duplicate documents in results""" + async def test_query_document_embeddings_multiple_results(self, processor): + """Test querying document embeddings with multiple results""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates across vectors - mock_results_1 = [ - {"entity": {"doc": "Document A"}}, - {"entity": {"doc": "Document B"}}, + + # Mock search results with multiple documents + mock_results = [ + {"entity": {"chunk_id": "Document A"}}, + {"entity": {"chunk_id": "Document B"}}, + {"entity": {"chunk_id": "Document C"}}, ] - mock_results_2 = [ - {"entity": {"doc": "Document B"}}, # Duplicate - {"entity": {"doc": "Document C"}}, - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Note: Unlike graph embeddings, doc embeddings don't deduplicate - # This preserves ranking and allows multiple occurrences - assert len(result) == 4 - assert result.count("Document B") == 2 # Should appear twice - assert "Document A" in result - assert "Document C" in result + + # Verify results are ChunkMatch objects + assert len(result) == 3 + chunk_ids = [r.chunk_id for r in result] + assert "Document A" in chunk_ids + assert "Document B" in chunk_ids + assert "Document C" in chunk_ids def test_add_args_method(self): """Test that add_args properly configures argument parser""" @@ -458,5 +449,5 @@ class TestMilvusDocEmbeddingsQueryProcessor: mock_launch.assert_called_once_with( default_ident, - "\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n" + "\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n" ) \ No newline at end of file diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 4b067743..397bdf1b 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -18,10 +18,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: def mock_query_message(self): """Create a mock query message for testing""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -103,7 +100,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_single_vector(self, processor): """Test querying document embeddings with a single vector""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' @@ -179,7 +176,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_limit_handling(self, processor): """Test that query respects the limit parameter""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -208,7 +205,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_zero_limit(self, processor): """Test querying with zero limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 0 message.user = 'test_user' message.collection = 'test_collection' @@ -226,7 +223,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_negative_limit(self, processor): """Test querying with negative limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = -1 message.user = 'test_user' message.collection = 'test_collection' @@ -242,12 +239,9 @@ class TestPineconeDocEmbeddingsQueryProcessor: @pytest.mark.asyncio async def test_query_document_embeddings_different_vector_dimensions(self, processor): - """Test querying with vectors of different dimensions using same index""" + """Test querying with single vector (legacy test name, schema now uses single vector)""" message = MagicMock() - message.vectors = [ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6] # 4D vector - ] + message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -285,7 +279,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_vectors_list(self, processor): """Test querying with empty vectors list""" message = MagicMock() - message.vectors = [] + message.vector = [] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -304,7 +298,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_no_results(self, processor): """Test querying when index returns no results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -325,7 +319,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_unicode_content(self, processor): """Test querying document embeddings with Unicode content results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -351,7 +345,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_large_content(self, processor): """Test querying document embeddings with large content results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 1 message.user = 'test_user' message.collection = 'test_collection' @@ -377,7 +371,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_mixed_content_types(self, processor): """Test querying document embeddings with mixed content types""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -409,7 +403,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_exception_handling(self, processor): """Test that exceptions are properly raised""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -425,7 +419,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_index_access_failure(self, processor): """Test handling of index access failure""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -437,13 +431,9 @@ class TestPineconeDocEmbeddingsQueryProcessor: @pytest.mark.asyncio async def test_query_document_embeddings_vector_accumulation(self, processor): - """Test that results from multiple vectors are properly accumulated""" + """Test that results from single vector query are returned (legacy multi-vector test)""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] + message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index ad337f73..1d2f0e6d 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.doc_embeddings.qdrant.service import Processor +from trustgraph.schema import ChunkMatch class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -77,9 +78,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'first document chunk'} + mock_point1.payload = {'chunk_id': 'first document chunk'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'second document chunk'} + mock_point2.payload = {'chunk_id': 'second document chunk'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected documents + # Verify result contains expected ChunkMatch objects assert len(result) == 2 - # Results should be strings (document chunks) - assert isinstance(result[0], str) - assert isinstance(result[1], str) + # Results should be ChunkMatch objects + assert isinstance(result[0], ChunkMatch) + assert isinstance(result[1], ChunkMatch) # Verify content - assert result[0] == 'first document chunk' - assert result[1] == 'second document chunk' + assert result[0].chunk_id == 'first document chunk' + assert result[1].chunk_id == 'second document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') - async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client): + """Test querying document embeddings returns multiple results""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses for different vectors + + # Mock query response with multiple results mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document from vector 1'} + mock_point1.payload = {'chunk_id': 'document chunk 1'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document from vector 2'} + mock_point2.payload = {'chunk_id': 'document chunk 2'} mock_point3 = MagicMock() - mock_point3.payload = {'doc': 'another document from vector 2'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2, mock_point3] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point3.payload = {'chunk_id': 'document chunk 3'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2, mock_point3] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried correctly expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify results from both vectors are combined + + # Verify results are ChunkMatch objects assert len(result) == 3 - assert 'document from vector 1' in result - assert 'document from vector 2' in result - assert 'another document from vector 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document chunk 1' in chunk_ids + assert 'document chunk 2' in chunk_ids + assert 'document chunk 3' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -192,7 +190,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_points = [] for i in range(10): mock_point = MagicMock() - mock_point.payload = {'doc': f'document chunk {i}'} + mock_point.payload = {'chunk_id': f'document chunk {i}'} mock_points.append(mock_point) mock_response = MagicMock() @@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 3 # Should only return 3 results mock_message.user = 'limit_user' mock_message.collection = 'limit_collection' @@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'empty_user' mock_message.collection = 'empty_collection' @@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with different vector dimensions""" + """Test querying document embeddings with a higher dimension vector""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses + + # Mock query response mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document from 2D vector'} + mock_point1.payload = {'chunk_id': 'document from 5D vector'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document from 3D vector'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point2.payload = {'chunk_id': 'another 5D document'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with 5D vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once with correct collection + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection - assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions - assert calls[0][1]['query'] == [0.1, 0.2] + # Call should use 5D collection + assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions + assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 2 - assert 'document from 2D vector' in result - assert 'document from 3D vector' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document from 5D vector' in chunk_ids + assert 'another 5D document' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -326,9 +319,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response with UTF-8 content mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'} + mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'Chinese text: 你好世界'} + mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'utf8_user' mock_message.collection = 'utf8_collection' @@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 2 - - # Verify UTF-8 content works correctly - assert 'Document with UTF-8: café, naïve, résumé' in result - assert 'Chinese text: 你好世界' in result + + # Verify UTF-8 content works correctly in ChunkMatch objects + chunk_ids = [r.chunk_id for r in result] + assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids + assert 'Chinese text: 你好世界' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'error_user' mock_message.collection = 'error_collection' @@ -399,7 +393,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response mock_point = MagicMock() - mock_point.payload = {'doc': 'document chunk'} + mock_point.payload = {'chunk_id': 'document chunk'} mock_response = MagicMock() mock_response.points = [mock_point] mock_qdrant_instance.query_points.return_value = mock_response @@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with zero limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 0 mock_message.user = 'zero_user' mock_message.collection = 'zero_collection' @@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 0 - - # Result should contain all returned documents + + # Result should contain all returned documents as ChunkMatch objects assert len(result) == 1 - assert result[0] == 'document chunk' + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == 'document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -442,9 +437,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response with fewer results than limit mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document 1'} + mock_point1.payload = {'chunk_id': 'document 1'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document 2'} + mock_point2.payload = {'chunk_id': 'document 2'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with large limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 1000 # Large limit mock_message.user = 'large_user' mock_message.collection = 'large_collection' @@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 1000 - - # Result should contain all available documents + + # Result should contain all available documents as ChunkMatch objects assert len(result) == 2 - assert 'document 1' in result - assert 'document 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document 1' in chunk_ids + assert 'document 2' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -487,11 +483,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - # Mock query response with missing 'doc' key + # Mock query response with missing 'chunk_id' key mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'valid document'} + mock_point1.payload = {'chunk_id': 'valid document'} mock_point2 = MagicMock() - mock_point2.payload = {} # Missing 'doc' key + mock_point2.payload = {} # Missing 'chunk_id' key mock_point3 = MagicMock() mock_point3.payload = {'other_key': 'invalid'} # Wrong key @@ -508,13 +504,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'payload_user' mock_message.collection = 'payload_collection' # Act & Assert - # This should raise a KeyError when trying to access payload['doc'] + # This should raise a KeyError when trying to access payload['chunk_id'] with pytest.raises(KeyError): await processor.query_document_embeddings(mock_message) diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index 21b6e1bf..f2b8be7e 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.graph_embeddings.milvus.service import Processor -from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL +from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch class TestMilvusGraphEmbeddingsQueryProcessor: @@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 ) - # Verify results are converted to Term objects + # Verify results are converted to EntityMatch objects assert len(result) == 3 - assert isinstance(result[0], Term) - assert result[0].iri == "http://example.com/entity1" - assert result[0].type == IRI - assert isinstance(result[1], Term) - assert result[1].iri == "http://example.com/entity2" - assert result[1].type == IRI - assert isinstance(result[2], Term) - assert result[2].value == "literal entity" - assert result[2].type == LITERAL + assert isinstance(result[0], EntityMatch) + assert result[0].entity.iri == "http://example.com/entity1" + assert result[0].entity.type == IRI + assert isinstance(result[1], EntityMatch) + assert result[1].entity.iri == "http://example.com/entity2" + assert result[1].entity.type == IRI + assert isinstance(result[2], EntityMatch) + assert result[2].entity.value == "literal entity" + assert result[2].entity.type == LITERAL @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_multiple_results(self, processor): + """Test querying graph embeddings returns multiple results""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - limit=3 + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + limit=5 ) - - # Mock search results - different results for each vector - mock_results_1 = [ + + # Mock search results with multiple entities + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, - ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results are deduplicated and limited + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10 + ) + + # Verify results are EntityMatch objects assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values assert "http://example.com/entity3" in entity_values @@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) @@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 2 @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication(self, processor): - """Test that duplicate entities are properly deduplicated""" + async def test_query_graph_embeddings_preserves_order(self, processor): + """Test that query results preserve order from the vector store""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates - mock_results_1 = [ - {"entity": {"entity": "http://example.com/entity1"}}, - {"entity": {"entity": "http://example.com/entity2"}}, - ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity1"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity3"}}, # New - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - - result = await processor.query_graph_embeddings(query) - - # Verify duplicates are removed - assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] - assert len(set(entity_values)) == 3 # All unique - assert "http://example.com/entity1" in entity_values - assert "http://example.com/entity2" in entity_values - assert "http://example.com/entity3" in entity_values - @pytest.mark.asyncio - async def test_query_graph_embeddings_early_termination_on_limit(self, processor): - """Test that querying stops early when limit is reached""" - query = GraphEmbeddingsRequest( - user='test_user', - collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - limit=2 - ) - - # Mock search results - first vector returns enough results - mock_results_1 = [ + # Mock search results in specific order + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.return_value = mock_results_1 - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify only first vector was searched (limit reached) - processor.vecstore.search.assert_called_once_with( - [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 + + # Verify results are in the same order as returned by the store + assert len(result) == 3 + assert result[0].entity.iri == "http://example.com/entity1" + assert result[1].entity.iri == "http://example.com/entity2" + assert result[2].entity.iri == "http://example.com/entity3" + + @pytest.mark.asyncio + async def test_query_graph_embeddings_results_limited(self, processor): + """Test that results are properly limited when store returns more than requested""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + limit=2 ) - - # Verify results are limited + + # Mock search results - returns more results than limit + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + {"entity": {"entity": "http://example.com/entity3"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_graph_embeddings(query) + + # Verify search was called with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4 + ) + + # Verify results are limited to requested amount assert len(result) == 2 @pytest.mark.asyncio @@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify all results are properly typed assert len(result) == 4 - + # Check URI entities - uri_results = [r for r in result if r.type == IRI] + uri_results = [r for r in result if r.entity.type == IRI] assert len(uri_results) == 2 - uri_values = [r.iri for r in uri_results] + uri_values = [r.entity.iri for r in uri_results] assert "http://example.com/uri_entity" in uri_values assert "https://example.com/another_uri" in uri_values - + # Check literal entities - literal_results = [r for r in result if not r.type == IRI] + literal_results = [r for r in result if not r.entity.type == IRI] assert len(literal_results) == 2 - literal_values = [r.value for r in literal_results] + literal_values = [r.entity.value for r in literal_results] assert "literal entity text" in literal_values assert "another literal" in literal_values @@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 0 @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying graph embeddings with different vector dimensions""" + async def test_query_graph_embeddings_longer_vector(self, processor): + """Test querying graph embeddings with a longer vector""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"entity": "entity_2d"}}] - mock_results_2 = [{"entity": {"entity": "entity_4d"}}] - mock_results_3 = [{"entity": {"entity": "entity_3d"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] - assert "entity_2d" in entity_values - assert "entity_4d" in entity_values - assert "entity_3d" in entity_values \ No newline at end of file + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once() + + # Verify results + assert len(result) == 2 + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] + assert "http://example.com/entity1" in entity_values + assert "http://example.com/entity2" in entity_values \ No newline at end of file diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 1b243113..2c1a673a 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) from trustgraph.query.graph_embeddings.pinecone.service import Processor -from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.schema import Term, IRI, LITERAL, EntityMatch class TestPineconeGraphEmbeddingsQueryProcessor: @@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: def mock_query_message(self): """Create a mock query message for testing""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_single_vector(self, processor): """Test querying graph embeddings with a single vector""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' @@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor: include_metadata=True ) - # Verify results + # Verify results use EntityMatch structure assert len(entities) == 3 - assert entities[0].value == 'http://example.org/entity1' - assert entities[0].type == IRI - assert entities[1].value == 'entity2' - assert entities[1].type == LITERAL - assert entities[2].value == 'http://example.org/entity3' - assert entities[2].type == IRI + assert entities[0].entity.iri == 'http://example.org/entity1' + assert entities[0].entity.type == IRI + assert entities[1].entity.value == 'entity2' + assert entities[1].entity.type == LITERAL + assert entities[2].entity.iri == 'http://example.org/entity3' + assert entities[2].entity.type == IRI @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_basic(self, processor, mock_query_message): + """Test basic graph embeddings query""" # Mock index and query results mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # First query results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query results with distinct entities + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), - MagicMock(metadata={'entity': 'entity2'}) - ] - - # Second query results - mock_results2 = MagicMock() - mock_results2.matches = [ - MagicMock(metadata={'entity': 'entity2'}), # Duplicate + MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity3'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(mock_query_message) - - # Verify both queries were made - assert mock_index.query.call_count == 2 - - # Verify deduplication occurred - entity_values = [e.value for e in entities] + + # Verify query was made once + assert mock_index.query.call_count == 1 + + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert len(entity_values) == 3 assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_limit_handling(self, processor): """Test that query respects the limit parameter""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_zero_limit(self, processor): """Test querying with zero limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 0 message.user = 'test_user' message.collection = 'test_collection' @@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_negative_limit(self, processor): """Test querying with negative limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = -1 message.user = 'test_user' message.collection = 'test_collection' @@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying with vectors of different dimensions using same index""" + async def test_query_graph_embeddings_2d_vector(self, processor): + """Test querying with a 2D vector""" message = MagicMock() - message.vectors = [ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6] # 4D vector - ] + message.vector = [0.1, 0.2] # 2D vector message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' - # Mock single index that handles all dimensions + # Mock index mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - # Mock results for different vector queries - mock_results_2d = MagicMock() - mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] + # Mock results for 2D vector query + mock_results = MagicMock() + mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})] - mock_results_4d = MagicMock() - mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})] - - mock_index.query.side_effect = [mock_results_2d, mock_results_4d] + mock_index.query.return_value = mock_results entities = await processor.query_graph_embeddings(message) - # Verify different indexes used for different dimensions - assert processor.pinecone.Index.call_count == 2 - index_calls = processor.pinecone.Index.call_args_list - index_names = [call[0][0] for call in index_calls] - assert "t-test_user-test_collection-2" in index_names # 2D vector - assert "t-test_user-test_collection-4" in index_names # 4D vector + # Verify correct index used for 2D vector + processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") - # Verify both queries were made - assert mock_index.query.call_count == 2 + # Verify query was made + assert mock_index.query.call_count == 1 - # Verify results from both dimensions - entity_values = [e.value for e in entities] + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert 'entity_2d' in entity_values - assert 'entity_4d' in entity_values @pytest.mark.asyncio async def test_query_graph_embeddings_empty_vectors_list(self, processor): """Test querying with empty vectors list""" message = MagicMock() - message.vectors = [] + message.vector = [] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_no_results(self, processor): """Test querying when index returns no results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): - """Test that deduplication works correctly across multiple vector queries""" + async def test_query_graph_embeddings_deduplication_in_results(self, processor): + """Test that deduplication works correctly within query results""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' - + mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # Both queries return overlapping results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query returns results with some duplicates + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'entity1'}), # Duplicate MagicMock(metadata={'entity': 'entity3'}), - MagicMock(metadata={'entity': 'entity4'}) - ] - - mock_results2 = MagicMock() - mock_results2.matches = [ MagicMock(metadata={'entity': 'entity2'}), # Duplicate - MagicMock(metadata={'entity': 'entity3'}), # Duplicate - MagicMock(metadata={'entity': 'entity5'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(message) - + # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 - entity_values = [e.value for e in entities] + entity_values = [e.entity.value for e in entities] assert len(set(entity_values)) == 3 # All unique @pytest.mark.asyncio - async def test_query_graph_embeddings_early_termination_on_limit(self, processor): - """Test that querying stops early when limit is reached""" + async def test_query_graph_embeddings_respects_limit(self, processor): + """Test that query respects limit parameter""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' - + mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # First query returns enough results to meet limit - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query returns more results than limit + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity3'}) ] - mock_index.query.return_value = mock_results1 - + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(message) - - # Should only make one query since limit was reached + + # Should only return 2 entities (respecting limit) mock_index.query.assert_called_once() assert len(entities) == 2 @@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_exception_handling(self, processor): """Test that exceptions are properly raised""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 1760c4c1..9362a8dd 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.graph_embeddings.qdrant.service import Processor -from trustgraph.schema import IRI, LITERAL +from trustgraph.schema import IRI, LITERAL, EntityMatch class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected entities + # Verify result contains expected EntityMatch objects assert len(result) == 2 - assert all(hasattr(entity, 'value') for entity in result) - entity_values = [entity.value for entity in result] + assert all(isinstance(entity, EntityMatch) for entity in result) + entity_values = [entity.entity.value for entity in result] assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify deduplication - entity2 appears in both results but should only appear once - entity_values = [entity.value for entity in result] + + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert len(set(entity_values)) == len(entity_values) # All unique assert 'entity1' in entity_values assert 'entity2' in entity_values - assert 'entity3' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 3 # Should only return 3 results mock_message.user = 'limit_user' mock_message.collection = 'limit_collection' @@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'empty_user' mock_message.collection = 'empty_collection' @@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2] # 2D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection + # Call should use 2D collection assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['query'] == [0.1, 0.2] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results - entity_values = [entity.value for entity in result] + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert 'entity2d' in entity_values - assert 'entity3d' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'uri_user' mock_message.collection = 'uri_collection' @@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 3 - + # Check URI entities - uri_entities = [entity for entity in result if entity.type == IRI] + uri_entities = [entity for entity in result if entity.entity.type == IRI] assert len(uri_entities) == 2 - uri_values = [entity.iri for entity in uri_entities] + uri_values = [entity.entity.iri for entity in uri_entities] assert 'http://example.com/entity1' in uri_values assert 'https://secure.example.com/entity2' in uri_values # Check regular entities - regular_entities = [entity for entity in result if entity.type == LITERAL] + regular_entities = [entity for entity in result if entity.entity.type == LITERAL] assert len(regular_entities) == 1 - assert regular_entities[0].value == 'regular entity' + assert regular_entities[0].entity.value == 'regular entity' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'error_user' mock_message.collection = 'error_collection' @@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with zero limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 0 mock_message.user = 'zero_user' mock_message.collection = 'zero_collection' @@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # With zero limit, the logic still adds one entity before checking the limit # So it returns one result (current behavior, not ideal but actual) assert len(result) == 1 - assert result[0].value == 'entity1' + assert result[0].entity.value == 'entity1' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 480f2ee1..b620df7e 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -118,8 +118,8 @@ class TestCassandraQueryProcessor: # Verify result contains the queried triple assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'test_object' def test_processor_initialization_with_defaults(self): @@ -182,8 +182,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -219,8 +219,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -256,8 +256,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio @@ -293,8 +293,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -331,8 +331,8 @@ class TestCassandraQueryProcessor: mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 - assert result[0].s.value == 'all_subject' - assert result[0].p.value == 'all_predicate' + assert result[0].s.iri == 'all_subject' + assert result[0].p.iri == 'all_predicate' assert result[0].o.value == 'all_object' def test_add_args_method(self): @@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations: ) assert len(result) == 1 - assert result[0].s.value == 'result_subject' - assert result[0].p.value == 'test_predicate' + assert result[0].s.iri == 'result_subject' + assert result[0].p.iri == 'test_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations: ) assert len(result) == 1 - assert result[0].s.value == 'test_subject' - assert result[0].p.value == 'result_predicate' + assert result[0].s.iri == 'test_subject' + assert result[0].p.iri == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio @@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations: # Verify all results were returned assert len(result) == 5 for i, triple in enumerate(result): - assert triple.s.value == f'subject_{i}' # Mock returns literal values + assert triple.s.iri == f'subject_{i}' # Mock returns literal values assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' assert triple.p.type == IRI assert triple.o.iri == 'http://example.com/Person' # URIs use .iri diff --git a/tests/unit/test_rdf/__init__.py b/tests/unit/test_rdf/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_rdf/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_rdf/test_rdf_primitives.py b/tests/unit/test_rdf/test_rdf_primitives.py new file mode 100644 index 00000000..2498677b --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_primitives.py @@ -0,0 +1,309 @@ +""" +Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node, +typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass +with named graph support, and the knowledge/defs helper types. +""" + +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE + + +# --------------------------------------------------------------------------- +# Type constants +# --------------------------------------------------------------------------- + +class TestTypeConstants: + + def test_iri_constant(self): + assert IRI == "i" + + def test_blank_constant(self): + assert BLANK == "b" + + def test_literal_constant(self): + assert LITERAL == "l" + + def test_triple_constant(self): + assert TRIPLE == "t" + + def test_constants_are_distinct(self): + vals = {IRI, BLANK, LITERAL, TRIPLE} + assert len(vals) == 4 + + +# --------------------------------------------------------------------------- +# IRI terms +# --------------------------------------------------------------------------- + +class TestIriTerm: + + def test_create_iri(self): + t = Term(type=IRI, iri="http://example.org/Alice") + assert t.type == IRI + assert t.iri == "http://example.org/Alice" + + def test_iri_defaults_empty(self): + t = Term(type=IRI) + assert t.iri == "" + + def test_iri_with_fragment(self): + t = Term(type=IRI, iri="http://example.org/ontology#Person") + assert "#Person" in t.iri + + def test_iri_with_unicode(self): + t = Term(type=IRI, iri="http://example.org/概念") + assert "概念" in t.iri + + def test_iri_other_fields_default(self): + t = Term(type=IRI, iri="http://example.org/x") + assert t.id == "" + assert t.value == "" + assert t.datatype == "" + assert t.language == "" + assert t.triple is None + + +# --------------------------------------------------------------------------- +# Blank node terms +# --------------------------------------------------------------------------- + +class TestBlankNodeTerm: + + def test_create_blank_node(self): + t = Term(type=BLANK, id="_:b0") + assert t.type == BLANK + assert t.id == "_:b0" + + def test_blank_node_defaults_empty(self): + t = Term(type=BLANK) + assert t.id == "" + + def test_blank_node_arbitrary_id(self): + t = Term(type=BLANK, id="node-abc-123") + assert t.id == "node-abc-123" + + +# --------------------------------------------------------------------------- +# Typed literals (XSD datatypes) +# --------------------------------------------------------------------------- + +class TestTypedLiteral: + + def test_plain_literal(self): + t = Term(type=LITERAL, value="hello") + assert t.type == LITERAL + assert t.value == "hello" + assert t.datatype == "" + assert t.language == "" + + def test_xsd_integer(self): + t = Term( + type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer", + ) + assert t.value == "42" + assert "integer" in t.datatype + + def test_xsd_boolean(self): + t = Term( + type=LITERAL, value="true", + datatype="http://www.w3.org/2001/XMLSchema#boolean", + ) + assert t.datatype.endswith("#boolean") + + def test_xsd_date(self): + t = Term( + type=LITERAL, value="2026-03-13", + datatype="http://www.w3.org/2001/XMLSchema#date", + ) + assert t.value == "2026-03-13" + assert t.datatype.endswith("#date") + + def test_xsd_double(self): + t = Term( + type=LITERAL, value="3.14", + datatype="http://www.w3.org/2001/XMLSchema#double", + ) + assert t.datatype.endswith("#double") + + def test_empty_value_literal(self): + t = Term(type=LITERAL, value="") + assert t.value == "" + + +# --------------------------------------------------------------------------- +# Language-tagged literals +# --------------------------------------------------------------------------- + +class TestLanguageTaggedLiteral: + + def test_english_tag(self): + t = Term(type=LITERAL, value="hello", language="en") + assert t.language == "en" + assert t.datatype == "" + + def test_french_tag(self): + t = Term(type=LITERAL, value="bonjour", language="fr") + assert t.language == "fr" + + def test_bcp47_subtag(self): + t = Term(type=LITERAL, value="colour", language="en-GB") + assert t.language == "en-GB" + + def test_language_and_datatype_mutually_exclusive(self): + """Both can be set on the dataclass, but semantically only one should be used.""" + t = Term(type=LITERAL, value="x", language="en", + datatype="http://www.w3.org/2001/XMLSchema#string") + # Dataclass allows both — translators should respect mutual exclusivity + assert t.language == "en" + assert t.datatype != "" + + +# --------------------------------------------------------------------------- +# Quoted triples (RDF-star) +# --------------------------------------------------------------------------- + +class TestQuotedTriple: + + def test_term_with_nested_triple(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/Alice"), + p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"), + o=Term(type=IRI, iri="http://example.org/Bob"), + ) + qt = Term(type=TRIPLE, triple=inner) + assert qt.type == TRIPLE + assert qt.triple is inner + assert qt.triple.s.iri == "http://example.org/Alice" + + def test_quoted_triple_as_object(self): + """A triple whose object is a quoted triple (RDF-star).""" + inner = Triple( + s=Term(type=IRI, iri="http://example.org/Hope"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="A feeling of expectation"), + ) + outer = Triple( + s=Term(type=IRI, iri="urn:subgraph:123"), + p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"), + o=Term(type=TRIPLE, triple=inner), + ) + assert outer.o.type == TRIPLE + assert outer.o.triple.o.value == "A feeling of expectation" + + def test_quoted_triple_none(self): + t = Term(type=TRIPLE, triple=None) + assert t.triple is None + + +# --------------------------------------------------------------------------- +# Triple / Quad (named graph) +# --------------------------------------------------------------------------- + +class TestTripleQuad: + + def test_default_graph_is_none(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + assert t.g is None + + def test_named_graph(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + g="urn:graph:source", + ) + assert t.g == "urn:graph:source" + + def test_empty_string_graph(self): + t = Triple(g="") + assert t.g == "" + + def test_triple_with_all_none_terms(self): + t = Triple() + assert t.s is None + assert t.p is None + assert t.o is None + assert t.g is None + + def test_triple_equality(self): + """Dataclass equality based on field values.""" + t1 = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + t2 = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + assert t1 == t2 + + +# --------------------------------------------------------------------------- +# knowledge/defs helper types +# --------------------------------------------------------------------------- + +class TestKnowledgeDefs: + + def test_uri_type(self): + from trustgraph.knowledge.defs import Uri + u = Uri("http://example.org/x") + assert u.is_uri() is True + assert u.is_literal() is False + assert u.is_triple() is False + assert str(u) == "http://example.org/x" + + def test_literal_type(self): + from trustgraph.knowledge.defs import Literal + l = Literal("hello world") + assert l.is_uri() is False + assert l.is_literal() is True + assert l.is_triple() is False + assert str(l) == "hello world" + + def test_quoted_triple_type(self): + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + qt = QuotedTriple( + s=Uri("http://example.org/s"), + p=Uri("http://example.org/p"), + o=Literal("val"), + ) + assert qt.is_uri() is False + assert qt.is_literal() is False + assert qt.is_triple() is True + assert qt.s == "http://example.org/s" + assert qt.o == "val" + + def test_quoted_triple_repr(self): + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + qt = QuotedTriple( + s=Uri("http://example.org/A"), + p=Uri("http://example.org/B"), + o=Literal("C"), + ) + r = repr(qt) + assert "<<" in r + assert ">>" in r + assert "http://example.org/A" in r + + def test_quoted_triple_nested(self): + """QuotedTriple can contain another QuotedTriple as object.""" + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + inner = QuotedTriple( + s=Uri("http://example.org/s"), + p=Uri("http://example.org/p"), + o=Literal("v"), + ) + outer = QuotedTriple( + s=Uri("http://example.org/s2"), + p=Uri("http://example.org/p2"), + o=inner, + ) + assert outer.o.is_triple() is True diff --git a/tests/unit/test_rdf/test_rdf_storage_helpers.py b/tests/unit/test_rdf/test_rdf_storage_helpers.py new file mode 100644 index 00000000..7bd1807a --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_storage_helpers.py @@ -0,0 +1,217 @@ +""" +Tests for RDF storage helper functions used by the Cassandra triple writer: +serialize_triple, get_term_value, get_term_otype, get_term_dtype, get_term_lang. +""" + +import json +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE +from trustgraph.storage.triples.cassandra.write import ( + serialize_triple, + get_term_value, + get_term_otype, + get_term_dtype, + get_term_lang, +) + + +# --------------------------------------------------------------------------- +# get_term_otype — maps Term.type to storage object type code +# --------------------------------------------------------------------------- + +class TestGetTermOtype: + + def test_iri_maps_to_u(self): + assert get_term_otype(Term(type=IRI, iri="http://x")) == "u" + + def test_blank_maps_to_u(self): + assert get_term_otype(Term(type=BLANK, id="_:b0")) == "u" + + def test_literal_maps_to_l(self): + assert get_term_otype(Term(type=LITERAL, value="x")) == "l" + + def test_triple_maps_to_t(self): + assert get_term_otype(Term(type=TRIPLE)) == "t" + + def test_none_defaults_to_u(self): + assert get_term_otype(None) == "u" + + def test_unknown_type_defaults_to_u(self): + assert get_term_otype(Term(type="z")) == "u" + + +# --------------------------------------------------------------------------- +# get_term_dtype — extracts XSD datatype from literals +# --------------------------------------------------------------------------- + +class TestGetTermDtype: + + def test_literal_with_datatype(self): + t = Term(type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer") + assert get_term_dtype(t) == "http://www.w3.org/2001/XMLSchema#integer" + + def test_literal_without_datatype(self): + t = Term(type=LITERAL, value="hello") + assert get_term_dtype(t) == "" + + def test_iri_returns_empty(self): + assert get_term_dtype(Term(type=IRI, iri="http://x")) == "" + + def test_none_returns_empty(self): + assert get_term_dtype(None) == "" + + +# --------------------------------------------------------------------------- +# get_term_lang — extracts language tag from literals +# --------------------------------------------------------------------------- + +class TestGetTermLang: + + def test_literal_with_language(self): + t = Term(type=LITERAL, value="bonjour", language="fr") + assert get_term_lang(t) == "fr" + + def test_literal_without_language(self): + t = Term(type=LITERAL, value="hello") + assert get_term_lang(t) == "" + + def test_iri_returns_empty(self): + assert get_term_lang(Term(type=IRI, iri="http://x")) == "" + + def test_none_returns_empty(self): + assert get_term_lang(None) == "" + + def test_bcp47_subtag_preserved(self): + t = Term(type=LITERAL, value="colour", language="en-GB") + assert get_term_lang(t) == "en-GB" + + +# --------------------------------------------------------------------------- +# get_term_value — extracts string value from any Term +# --------------------------------------------------------------------------- + +class TestGetTermValue: + + def test_iri_returns_iri(self): + t = Term(type=IRI, iri="http://example.org/Alice") + assert get_term_value(t) == "http://example.org/Alice" + + def test_literal_returns_value(self): + t = Term(type=LITERAL, value="hello") + assert get_term_value(t) == "hello" + + def test_blank_returns_id(self): + t = Term(type=BLANK, id="_:b0") + assert get_term_value(t) == "_:b0" + + def test_none_returns_none(self): + assert get_term_value(None) is None + + def test_triple_returns_serialized_json(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + t = Term(type=TRIPLE, triple=inner) + result = get_term_value(t) + parsed = json.loads(result) + assert parsed["s"]["type"] == "i" + assert parsed["s"]["iri"] == "http://example.org/s" + assert parsed["o"]["value"] == "val" + + +# --------------------------------------------------------------------------- +# serialize_triple — full Triple → JSON serialization +# --------------------------------------------------------------------------- + +class TestSerializeTriple: + + def test_serialize_iri_triple(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=IRI, iri="http://example.org/B"), + ) + result = json.loads(serialize_triple(t)) + assert result["s"]["type"] == "i" + assert result["s"]["iri"] == "http://example.org/A" + assert result["p"]["iri"] == "http://example.org/rel" + assert result["o"]["type"] == "i" + + def test_serialize_literal_object(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="hello"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["type"] == "l" + assert result["o"]["value"] == "hello" + + def test_serialize_typed_literal(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["datatype"] == "http://www.w3.org/2001/XMLSchema#integer" + + def test_serialize_language_tagged_literal(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="bonjour", language="fr"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["language"] == "fr" + + def test_serialize_blank_node(self): + t = Triple( + s=Term(type=BLANK, id="_:b0"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + ) + result = json.loads(serialize_triple(t)) + assert result["s"]["type"] == "b" + assert result["s"]["id"] == "_:b0" + + def test_serialize_nested_quoted_triple(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/inner-s"), + p=Term(type=IRI, iri="http://example.org/inner-p"), + o=Term(type=LITERAL, value="inner-val"), + ) + outer = Triple( + s=Term(type=IRI, iri="http://example.org/outer-s"), + p=Term(type=IRI, iri="http://example.org/outer-p"), + o=Term(type=TRIPLE, triple=inner), + ) + result = json.loads(serialize_triple(outer)) + nested = json.loads(result["o"]["triple"]) + assert nested["s"]["iri"] == "http://example.org/inner-s" + assert nested["o"]["value"] == "inner-val" + + def test_serialize_none_returns_none(self): + assert serialize_triple(None) is None + + def test_serialize_none_terms(self): + t = Triple(s=None, p=None, o=None) + result = json.loads(serialize_triple(t)) + assert result["s"] is None + assert result["p"] is None + assert result["o"] is None + + def test_serialize_plain_literal_omits_datatype_and_language(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="plain"), + ) + result = json.loads(serialize_triple(t)) + assert "datatype" not in result["o"] + assert "language" not in result["o"] diff --git a/tests/unit/test_rdf/test_rdf_wire_format.py b/tests/unit/test_rdf/test_rdf_wire_format.py new file mode 100644 index 00000000..a0bbd27a --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_wire_format.py @@ -0,0 +1,357 @@ +""" +Tests for RDF wire format translators: TermTranslator and TripleTranslator +round-trip encoding for all RDF 1.2 term types (IRI, blank node, typed literal, +language-tagged literal, quoted triple) and named graph quads. +""" + +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE +from trustgraph.messaging.translators.primitives import ( + TermTranslator, TripleTranslator, SubgraphTranslator, +) + + +@pytest.fixture +def term_tx(): + return TermTranslator() + + +@pytest.fixture +def triple_tx(): + return TripleTranslator() + + +# --------------------------------------------------------------------------- +# TermTranslator — IRI +# --------------------------------------------------------------------------- + +class TestTermTranslatorIri: + + def test_iri_to_pulsar(self, term_tx): + data = {"t": "i", "i": "http://example.org/Alice"} + term = term_tx.to_pulsar(data) + assert term.type == IRI + assert term.iri == "http://example.org/Alice" + + def test_iri_from_pulsar(self, term_tx): + term = Term(type=IRI, iri="http://example.org/Bob") + wire = term_tx.from_pulsar(term) + assert wire == {"t": "i", "i": "http://example.org/Bob"} + + def test_iri_round_trip(self, term_tx): + original = Term(type=IRI, iri="http://example.org/round") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# --------------------------------------------------------------------------- +# TermTranslator — Blank node +# --------------------------------------------------------------------------- + +class TestTermTranslatorBlank: + + def test_blank_to_pulsar(self, term_tx): + data = {"t": "b", "d": "_:b42"} + term = term_tx.to_pulsar(data) + assert term.type == BLANK + assert term.id == "_:b42" + + def test_blank_from_pulsar(self, term_tx): + term = Term(type=BLANK, id="_:node1") + wire = term_tx.from_pulsar(term) + assert wire == {"t": "b", "d": "_:node1"} + + def test_blank_round_trip(self, term_tx): + original = Term(type=BLANK, id="_:x") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# --------------------------------------------------------------------------- +# TermTranslator — Typed literal (XSD) +# --------------------------------------------------------------------------- + +class TestTermTranslatorTypedLiteral: + + def test_plain_literal_to_pulsar(self, term_tx): + data = {"t": "l", "v": "hello"} + term = term_tx.to_pulsar(data) + assert term.type == LITERAL + assert term.value == "hello" + assert term.datatype == "" + assert term.language == "" + + def test_xsd_integer_to_pulsar(self, term_tx): + data = { + "t": "l", "v": "42", + "dt": "http://www.w3.org/2001/XMLSchema#integer", + } + term = term_tx.to_pulsar(data) + assert term.value == "42" + assert term.datatype.endswith("#integer") + + def test_typed_literal_from_pulsar(self, term_tx): + term = Term( + type=LITERAL, value="3.14", + datatype="http://www.w3.org/2001/XMLSchema#double", + ) + wire = term_tx.from_pulsar(term) + assert wire["t"] == "l" + assert wire["v"] == "3.14" + assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double" + assert "ln" not in wire # No language tag + + def test_typed_literal_round_trip(self, term_tx): + original = Term( + type=LITERAL, value="true", + datatype="http://www.w3.org/2001/XMLSchema#boolean", + ) + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + def test_plain_literal_omits_dt_and_ln(self, term_tx): + term = Term(type=LITERAL, value="x") + wire = term_tx.from_pulsar(term) + assert "dt" not in wire + assert "ln" not in wire + + +# --------------------------------------------------------------------------- +# TermTranslator — Language-tagged literal +# --------------------------------------------------------------------------- + +class TestTermTranslatorLangLiteral: + + def test_language_tag_to_pulsar(self, term_tx): + data = {"t": "l", "v": "bonjour", "ln": "fr"} + term = term_tx.to_pulsar(data) + assert term.value == "bonjour" + assert term.language == "fr" + + def test_language_tag_from_pulsar(self, term_tx): + term = Term(type=LITERAL, value="colour", language="en-GB") + wire = term_tx.from_pulsar(term) + assert wire["ln"] == "en-GB" + assert "dt" not in wire # No datatype + + def test_language_tag_round_trip(self, term_tx): + original = Term(type=LITERAL, value="hola", language="es") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# --------------------------------------------------------------------------- +# TermTranslator — Quoted triple (RDF-star) +# --------------------------------------------------------------------------- + +class TestTermTranslatorQuotedTriple: + + def test_quoted_triple_to_pulsar(self, term_tx): + data = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "http://example.org/Alice"}, + "p": {"t": "i", "i": "http://xmlns.com/foaf/0.1/knows"}, + "o": {"t": "i", "i": "http://example.org/Bob"}, + }, + } + term = term_tx.to_pulsar(data) + assert term.type == TRIPLE + assert term.triple is not None + assert term.triple.s.iri == "http://example.org/Alice" + assert term.triple.o.iri == "http://example.org/Bob" + + def test_quoted_triple_from_pulsar(self, term_tx): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + term = Term(type=TRIPLE, triple=inner) + wire = term_tx.from_pulsar(term) + assert wire["t"] == "t" + assert "tr" in wire + assert wire["tr"]["s"]["i"] == "http://example.org/s" + assert wire["tr"]["o"]["v"] == "val" + + def test_quoted_triple_round_trip(self, term_tx): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C", language="en"), + ) + original = Term(type=TRIPLE, triple=inner) + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored.type == TRIPLE + assert restored.triple.s == original.triple.s + assert restored.triple.o == original.triple.o + + def test_quoted_triple_none_triple(self, term_tx): + term = Term(type=TRIPLE, triple=None) + wire = term_tx.from_pulsar(term) + assert wire == {"t": "t"} + # And back + restored = term_tx.to_pulsar(wire) + assert restored.type == TRIPLE + assert restored.triple is None + + def test_quoted_triple_with_literal_object(self, term_tx): + data = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "http://example.org/Hope"}, + "p": {"t": "i", "i": "http://www.w3.org/2004/02/skos/core#definition"}, + "o": {"t": "l", "v": "A feeling of expectation"}, + }, + } + term = term_tx.to_pulsar(data) + assert term.triple.o.type == LITERAL + assert term.triple.o.value == "A feeling of expectation" + + +# --------------------------------------------------------------------------- +# TermTranslator — Edge cases +# --------------------------------------------------------------------------- + +class TestTermTranslatorEdgeCases: + + def test_unknown_type(self, term_tx): + data = {"t": "z"} + term = term_tx.to_pulsar(data) + assert term.type == "z" + + def test_empty_type(self, term_tx): + data = {} + term = term_tx.to_pulsar(data) + assert term.type == "" + + def test_missing_iri_field(self, term_tx): + data = {"t": "i"} + term = term_tx.to_pulsar(data) + assert term.iri == "" + + def test_missing_literal_fields(self, term_tx): + data = {"t": "l"} + term = term_tx.to_pulsar(data) + assert term.value == "" + assert term.datatype == "" + assert term.language == "" + + +# --------------------------------------------------------------------------- +# TripleTranslator +# --------------------------------------------------------------------------- + +class TestTripleTranslator: + + def test_triple_to_pulsar(self, triple_tx): + data = { + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "l", "v": "object"}, + } + triple = triple_tx.to_pulsar(data) + assert triple.s.iri == "http://example.org/s" + assert triple.o.value == "object" + assert triple.g is None + + def test_triple_from_pulsar(self, triple_tx): + triple = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + wire = triple_tx.from_pulsar(triple) + assert wire["s"]["t"] == "i" + assert wire["o"]["v"] == "C" + assert "g" not in wire + + def test_quad_with_named_graph(self, triple_tx): + data = { + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "l", "v": "val"}, + "g": "urn:graph:source", + } + quad = triple_tx.to_pulsar(data) + assert quad.g == "urn:graph:source" + + def test_quad_from_pulsar_includes_graph(self, triple_tx): + quad = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g="urn:graph:retrieval", + ) + wire = triple_tx.from_pulsar(quad) + assert wire["g"] == "urn:graph:retrieval" + + def test_quad_round_trip(self, triple_tx): + original = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g="urn:graph:source", + ) + wire = triple_tx.from_pulsar(original) + restored = triple_tx.to_pulsar(wire) + assert restored == original + + def test_none_graph_omitted_from_wire(self, triple_tx): + triple = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g=None, + ) + wire = triple_tx.from_pulsar(triple) + assert "g" not in wire + + def test_missing_terms_handled(self, triple_tx): + data = {} + triple = triple_tx.to_pulsar(data) + assert triple.s is None + assert triple.p is None + assert triple.o is None + + +# --------------------------------------------------------------------------- +# SubgraphTranslator +# --------------------------------------------------------------------------- + +class TestSubgraphTranslator: + + def test_subgraph_round_trip(self): + tx = SubgraphTranslator() + triples = [ + Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=LITERAL, value="v1"), + ), + Triple( + s=Term(type=IRI, iri="http://example.org/B"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=IRI, iri="http://example.org/C"), + g="urn:graph:source", + ), + ] + wire_list = tx.from_pulsar(triples) + assert len(wire_list) == 2 + assert wire_list[1]["g"] == "urn:graph:source" + + restored = tx.to_pulsar(wire_list) + assert len(restored) == 2 + assert restored[0] == triples[0] + assert restored[1] == triples[1] + + def test_empty_subgraph(self): + tx = SubgraphTranslator() + assert tx.to_pulsar([]) == [] + assert tx.from_pulsar([]) == [] diff --git a/tests/unit/test_reliability/__init__.py b/tests/unit/test_reliability/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_reliability/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py new file mode 100644 index 00000000..2fabed58 --- /dev/null +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -0,0 +1,144 @@ +""" +Tests for pipeline metadata preservation: DocumentMetadata and +ProcessingMetadata round-trip through translators, field preservation, +and default handling. +""" + +import pytest + +from trustgraph.schema import DocumentMetadata, ProcessingMetadata, Triple, Term, IRI +from trustgraph.messaging.translators.metadata import ( + DocumentMetadataTranslator, + ProcessingMetadataTranslator, +) + + +# --------------------------------------------------------------------------- +# DocumentMetadata translator +# --------------------------------------------------------------------------- + +class TestDocumentMetadataTranslator: + + def setup_method(self): + self.tx = DocumentMetadataTranslator() + + def test_full_round_trip(self): + data = { + "id": "doc-123", + "time": 1710000000, + "kind": "application/pdf", + "title": "Test Document", + "comments": "No comments", + "metadata": [], + "user": "alice", + "tags": ["finance", "q4"], + "parent-id": "doc-100", + "document-type": "page", + } + obj = self.tx.to_pulsar(data) + assert obj.id == "doc-123" + assert obj.time == 1710000000 + assert obj.kind == "application/pdf" + assert obj.title == "Test Document" + assert obj.user == "alice" + assert obj.tags == ["finance", "q4"] + assert obj.parent_id == "doc-100" + assert obj.document_type == "page" + + wire = self.tx.from_pulsar(obj) + assert wire["id"] == "doc-123" + assert wire["user"] == "alice" + assert wire["parent-id"] == "doc-100" + assert wire["document-type"] == "page" + + def test_defaults_for_missing_fields(self): + obj = self.tx.to_pulsar({}) + assert obj.parent_id == "" + assert obj.document_type == "source" + + def test_metadata_triples_preserved(self): + triple_wire = [{ + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "i", "i": "http://example.org/o"}, + }] + data = {"metadata": triple_wire} + obj = self.tx.to_pulsar(data) + assert len(obj.metadata) == 1 + assert obj.metadata[0].s.iri == "http://example.org/s" + + def test_none_metadata_handled(self): + data = {"metadata": None} + obj = self.tx.to_pulsar(data) + assert obj.metadata == [] + + def test_empty_tags_preserved(self): + data = {"tags": []} + obj = self.tx.to_pulsar(data) + wire = self.tx.from_pulsar(obj) + assert wire["tags"] == [] + + def test_falsy_fields_omitted_from_wire(self): + """Empty string fields should be omitted from wire format.""" + obj = DocumentMetadata(id="", time=0, user="") + wire = self.tx.from_pulsar(obj) + assert "id" not in wire + assert "user" not in wire + + +# --------------------------------------------------------------------------- +# ProcessingMetadata translator +# --------------------------------------------------------------------------- + +class TestProcessingMetadataTranslator: + + def setup_method(self): + self.tx = ProcessingMetadataTranslator() + + def test_full_round_trip(self): + data = { + "id": "proc-1", + "document-id": "doc-123", + "time": 1710000000, + "flow": "default", + "user": "alice", + "collection": "my-collection", + "tags": ["tag1"], + } + obj = self.tx.to_pulsar(data) + assert obj.id == "proc-1" + assert obj.document_id == "doc-123" + assert obj.flow == "default" + assert obj.user == "alice" + assert obj.collection == "my-collection" + assert obj.tags == ["tag1"] + + wire = self.tx.from_pulsar(obj) + assert wire["id"] == "proc-1" + assert wire["document-id"] == "doc-123" + assert wire["user"] == "alice" + assert wire["collection"] == "my-collection" + + def test_missing_fields_use_defaults(self): + obj = self.tx.to_pulsar({}) + assert obj.id is None + assert obj.user is None + assert obj.collection is None + + def test_tags_none_omitted(self): + obj = ProcessingMetadata(tags=None) + wire = self.tx.from_pulsar(obj) + assert "tags" not in wire + + def test_tags_empty_list_preserved(self): + obj = ProcessingMetadata(tags=[]) + wire = self.tx.from_pulsar(obj) + assert wire["tags"] == [] + + def test_user_and_collection_preserved(self): + """Core pipeline routing fields must survive round-trip.""" + data = {"user": "bob", "collection": "research"} + obj = self.tx.to_pulsar(data) + wire = self.tx.from_pulsar(obj) + assert wire["user"] == "bob" + assert wire["collection"] == "research" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py new file mode 100644 index 00000000..41a5c621 --- /dev/null +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -0,0 +1,314 @@ +""" +Tests for null embedding protection: empty/None vector skipping, entity +validation, dimension-aware collection creation, and query-time empty +vector handling. + +Tests the pure functions and logic without Qdrant connections. +""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from trustgraph.schema import Term, IRI, LITERAL, BLANK + + +# --------------------------------------------------------------------------- +# Graph embeddings: get_term_value +# --------------------------------------------------------------------------- + +class TestGraphEmbeddingsGetTermValue: + + def test_iri_returns_iri(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=IRI, iri="http://example.org/x") + assert get_term_value(t) == "http://example.org/x" + + def test_literal_returns_value(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=LITERAL, value="hello") + assert get_term_value(t) == "hello" + + def test_blank_returns_id(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=BLANK, id="_:b0") + assert get_term_value(t) == "_:b0" + + def test_none_returns_none(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + assert get_term_value(None) is None + + def test_blank_with_value_fallback(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=BLANK, id="", value="fallback") + assert get_term_value(t) == "fallback" + + +# --------------------------------------------------------------------------- +# Document embeddings: null vector protection +# --------------------------------------------------------------------------- + +class TestDocEmbeddingsNullProtection: + + @pytest.mark.asyncio + async def test_empty_vector_skipped(self): + """Embeddings with empty vectors should be silently skipped.""" + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + + # Mock collection_exists for config check + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = [] # Empty vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + + # No upsert should be called + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_none_vector_skipped(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = None # None vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_chunk_id_skipped(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "" # Empty chunk ID + emb.vector = [0.1, 0.2, 0.3] + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_valid_embedding_upserted(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = [0.1, 0.2, 0.3] + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_dimension_in_collection_name(self): + """Collection name should include vector dimension.""" + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "alice" + msg.metadata.collection = "docs" + + emb = MagicMock() + emb.chunk_id = "c1" + emb.vector = [0.0] * 384 # 384-dim vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + + call_args = proc.qdrant.upsert.call_args + assert "d_alice_docs_384" in call_args[1]["collection_name"] + + +# --------------------------------------------------------------------------- +# Graph embeddings: null entity and vector protection +# --------------------------------------------------------------------------- + +class TestGraphEmbeddingsNullProtection: + + @pytest.mark.asyncio + async def test_empty_entity_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="") # Empty IRI + entity.vector = [0.1, 0.2, 0.3] + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_none_entity_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = None # Null entity + entity.vector = [0.1, 0.2, 0.3] + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_vector_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/x") + entity.vector = [] # Empty vector + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_valid_entity_and_vector_upserted(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/Alice") + entity.vector = [0.1, 0.2, 0.3] + entity.chunk_id = "c1" + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_lazy_collection_creation_on_new_dimension(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = False + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "alice" + msg.metadata.collection = "graphs" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/x") + entity.vector = [0.0] * 768 + entity.chunk_id = "" + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + + # Collection should be created with correct dimension + proc.qdrant.create_collection.assert_called_once() + create_args = proc.qdrant.create_collection.call_args + assert create_args[1]["collection_name"] == "t_alice_graphs_768" + + +# --------------------------------------------------------------------------- +# Collection validation — deleted-while-in-flight protection +# --------------------------------------------------------------------------- + +class TestCollectionValidation: + + @pytest.mark.asyncio + async def test_doc_embeddings_dropped_for_deleted_collection(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=False) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "deleted-col" + msg.chunks = [MagicMock()] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_graph_embeddings_dropped_for_deleted_collection(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=False) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "deleted-col" + msg.entities = [MagicMock()] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() diff --git a/tests/unit/test_reliability/test_retry_backoff.py b/tests/unit/test_reliability/test_retry_backoff.py new file mode 100644 index 00000000..94a3e806 --- /dev/null +++ b/tests/unit/test_reliability/test_retry_backoff.py @@ -0,0 +1,153 @@ +""" +Tests for retry and backoff strategies: Consumer rate-limit retry loop, +timeout expiry, TooManyRequests exception propagation, and configurable +retry parameters. +""" + +import asyncio +import time +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.exceptions import TooManyRequests +from trustgraph.base.consumer import Consumer + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_consumer(rate_limit_retry_time=10, rate_limit_timeout=7200): + """Create a Consumer with minimal mocking.""" + consumer = Consumer.__new__(Consumer) + consumer.rate_limit_retry_time = rate_limit_retry_time + consumer.rate_limit_timeout = rate_limit_timeout + consumer.metrics = None + consumer.consumer = MagicMock() + return consumer + + +# --------------------------------------------------------------------------- +# TooManyRequests exception +# --------------------------------------------------------------------------- + +class TestTooManyRequestsException: + + def test_is_exception(self): + assert issubclass(TooManyRequests, Exception) + + def test_with_message(self): + err = TooManyRequests("rate limited") + assert "rate limited" in str(err) + + def test_without_message(self): + err = TooManyRequests() + assert isinstance(err, TooManyRequests) + + +# --------------------------------------------------------------------------- +# Consumer retry configuration +# --------------------------------------------------------------------------- + +class TestConsumerRetryConfig: + + def test_default_retry_time(self): + consumer = _make_consumer() + assert consumer.rate_limit_retry_time == 10 + + def test_default_timeout(self): + consumer = _make_consumer() + assert consumer.rate_limit_timeout == 7200 + + def test_custom_retry_time(self): + consumer = _make_consumer(rate_limit_retry_time=5) + assert consumer.rate_limit_retry_time == 5 + + def test_custom_timeout(self): + consumer = _make_consumer(rate_limit_timeout=300) + assert consumer.rate_limit_timeout == 300 + + +# --------------------------------------------------------------------------- +# Rate limit metrics +# --------------------------------------------------------------------------- + +class TestRateLimitMetrics: + + def test_metrics_rate_limit_called(self): + """Metrics should record rate limit events when available.""" + consumer = _make_consumer() + consumer.metrics = MagicMock() + + # Simulate what the consumer does on rate limit + consumer.metrics.rate_limit() + + consumer.metrics.rate_limit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message acknowledgment on error +# --------------------------------------------------------------------------- + +class TestMessageAckOnError: + + def test_consumer_has_negative_acknowledge(self): + """Consumer backend should support negative acknowledgment.""" + consumer = _make_consumer() + msg = MagicMock() + + # Simulate negative ack (what happens on timeout expiry) + consumer.consumer.negative_acknowledge(msg) + consumer.consumer.negative_acknowledge.assert_called_once_with(msg) + + +# --------------------------------------------------------------------------- +# TooManyRequests propagation across services +# --------------------------------------------------------------------------- + +class TestTooManyRequestsPropagation: + + def test_llm_service_propagates(self): + """LLM services should re-raise TooManyRequests for consumer retry.""" + with pytest.raises(TooManyRequests): + raise TooManyRequests() + + def test_embeddings_service_propagates(self): + """Embeddings services should re-raise TooManyRequests for consumer retry.""" + with pytest.raises(TooManyRequests): + try: + raise TooManyRequests("rate limited") + except TooManyRequests as e: + # Re-raise pattern used in services + assert isinstance(e, TooManyRequests) + raise + + def test_too_many_requests_not_caught_by_generic(self): + """TooManyRequests should be distinguishable from generic exceptions.""" + caught_specific = False + try: + raise TooManyRequests("rate limited") + except TooManyRequests: + caught_specific = True + except Exception: + pass + assert caught_specific + + +# --------------------------------------------------------------------------- +# Client-side error type mapping +# --------------------------------------------------------------------------- + +class TestClientErrorTypeMapping: + + def test_too_many_requests_wire_type(self): + """The wire format error type for rate limiting is 'too-many-requests'.""" + from trustgraph.schema import Error + err = Error(type="too-many-requests", message="slow down") + assert err.type == "too-many-requests" + + def test_generic_error_wire_type(self): + from trustgraph.schema import Error + err = Error(type="internal-error", message="something broke") + assert err.type == "internal-error" + assert err.type != "too-many-requests" diff --git a/tests/unit/test_reliability/test_subscriber_resilience.py b/tests/unit/test_reliability/test_subscriber_resilience.py new file mode 100644 index 00000000..4aac1161 --- /dev/null +++ b/tests/unit/test_reliability/test_subscriber_resilience.py @@ -0,0 +1,233 @@ +""" +Tests for message queue subscriber resilience: unexpected message handling, +orphaned message detection, backpressure strategies, graceful draining, +and timeout recovery. +""" + +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.base.subscriber import Subscriber + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_subscriber(max_size=10, backpressure_strategy="block", + drain_timeout=5.0): + """Create a Subscriber without connecting to any backend.""" + backend = MagicMock() + sub = Subscriber( + backend=backend, + topic="test-topic", + subscription="test-sub", + consumer_name="test", + max_size=max_size, + backpressure_strategy=backpressure_strategy, + drain_timeout=drain_timeout, + ) + sub.consumer = MagicMock() + return sub + + +def _make_msg(id=None, value="test-value"): + """Create a mock message with optional properties.""" + msg = MagicMock() + if id is not None: + msg.properties.return_value = {"id": id} + else: + msg.properties.side_effect = KeyError("id") + msg.value.return_value = value + return msg + + +# --------------------------------------------------------------------------- +# Message property extraction resilience +# --------------------------------------------------------------------------- + +class TestMessagePropertyResilience: + + @pytest.mark.asyncio + async def test_missing_id_property_handled(self): + """Messages without 'id' property should not crash.""" + sub = _make_subscriber() + msg = MagicMock() + msg.properties.side_effect = Exception("no properties") + msg.value.return_value = "some-value" + + # Should not raise + await sub._process_message(msg) + + # Message should still be acknowledged + sub.consumer.acknowledge.assert_called_once_with(msg) + + @pytest.mark.asyncio + async def test_message_with_valid_id_delivered(self): + """Messages with matching subscriber ID should be delivered.""" + sub = _make_subscriber() + q = await sub.subscribe("req-1") + + msg = _make_msg(id="req-1", value="response-data") + await sub._process_message(msg) + + assert not q.empty() + assert q.get_nowait() == "response-data" + sub.consumer.acknowledge.assert_called_once() + + +# --------------------------------------------------------------------------- +# Orphaned message handling +# --------------------------------------------------------------------------- + +class TestOrphanedMessages: + + @pytest.mark.asyncio + async def test_orphaned_message_acknowledged(self): + """Messages with no matching waiter should still be acknowledged.""" + sub = _make_subscriber() + msg = _make_msg(id="unknown-id", value="orphan") + + await sub._process_message(msg) + + # Orphaned message is acknowledged (not negative-acknowledged) + sub.consumer.acknowledge.assert_called_once_with(msg) + + @pytest.mark.asyncio + async def test_orphaned_message_not_queued(self): + """Orphaned messages should not appear in any subscriber queue.""" + sub = _make_subscriber() + q = await sub.subscribe("req-1") + + msg = _make_msg(id="different-id", value="orphan") + await sub._process_message(msg) + + assert q.empty() + + +# --------------------------------------------------------------------------- +# Backpressure strategies +# --------------------------------------------------------------------------- + +class TestBackpressureStrategies: + + @pytest.mark.asyncio + async def test_drop_new_rejects_when_full(self): + """drop_new strategy should reject new messages when queue is full.""" + sub = _make_subscriber(max_size=1, backpressure_strategy="drop_new") + q = await sub.subscribe("req-1") + + # Fill the queue + msg1 = _make_msg(id="req-1", value="first") + await sub._process_message(msg1) + assert q.qsize() == 1 + + # Second message should be dropped + msg2 = _make_msg(id="req-1", value="second") + await sub._process_message(msg2) + + # Queue still has only the first message + assert q.qsize() == 1 + assert q.get_nowait() == "first" + + @pytest.mark.asyncio + async def test_drop_oldest_evicts_when_full(self): + """drop_oldest strategy should evict oldest message when full.""" + sub = _make_subscriber(max_size=1, backpressure_strategy="drop_oldest") + q = await sub.subscribe("req-1") + + msg1 = _make_msg(id="req-1", value="first") + await sub._process_message(msg1) + + msg2 = _make_msg(id="req-1", value="second") + await sub._process_message(msg2) + + # Queue should have the newer message + assert q.qsize() == 1 + assert q.get_nowait() == "second" + + @pytest.mark.asyncio + async def test_block_strategy_delivers(self): + """block strategy should deliver messages normally.""" + sub = _make_subscriber(max_size=10, backpressure_strategy="block") + q = await sub.subscribe("req-1") + + msg = _make_msg(id="req-1", value="data") + await sub._process_message(msg) + + assert q.get_nowait() == "data" + + +# --------------------------------------------------------------------------- +# Full subscribers (subscribe_all) +# --------------------------------------------------------------------------- + +class TestFullSubscribers: + + @pytest.mark.asyncio + async def test_subscribe_all_receives_all_messages(self): + sub = _make_subscriber() + q = await sub.subscribe_all("listener-1") + + msg = _make_msg(id="any-id", value="broadcast") + await sub._process_message(msg) + + assert q.get_nowait() == "broadcast" + + @pytest.mark.asyncio + async def test_multiple_full_subscribers_all_receive(self): + sub = _make_subscriber() + q1 = await sub.subscribe_all("l1") + q2 = await sub.subscribe_all("l2") + + msg = _make_msg(id="any", value="data") + await sub._process_message(msg) + + assert q1.get_nowait() == "data" + assert q2.get_nowait() == "data" + + +# --------------------------------------------------------------------------- +# Subscribe / unsubscribe lifecycle +# --------------------------------------------------------------------------- + +class TestSubscribeLifecycle: + + @pytest.mark.asyncio + async def test_unsubscribe_removes_queue(self): + sub = _make_subscriber() + await sub.subscribe("req-1") + await sub.unsubscribe("req-1") + + assert "req-1" not in sub.q + + @pytest.mark.asyncio + async def test_unsubscribe_nonexistent_is_noop(self): + sub = _make_subscriber() + await sub.unsubscribe("nonexistent") # Should not raise + + @pytest.mark.asyncio + async def test_unsubscribe_all_removes_queue(self): + sub = _make_subscriber() + await sub.subscribe_all("l1") + await sub.unsubscribe_all("l1") + + assert "l1" not in sub.full + + +# --------------------------------------------------------------------------- +# Pending ack tracking +# --------------------------------------------------------------------------- + +class TestPendingAckTracking: + + @pytest.mark.asyncio + async def test_processed_message_cleared_from_pending(self): + sub = _make_subscriber() + msg = _make_msg(id="req-1", value="data") + + await sub._process_message(msg) + + # After processing, pending_acks should be empty + assert len(sub.pending_acks) == 0 diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 590572bc..27508ba4 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -8,48 +8,75 @@ from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query +# Sample chunk content mapping for tests +CHUNK_CONTENT = { + "doc/c1": "Document 1 content", + "doc/c2": "Document 2 content", + "doc/c3": "Relevant document content", + "doc/c4": "Another document", + "doc/c5": "Default doc", + "doc/c6": "Verbose test doc", + "doc/c7": "Verbose doc content", + "doc/ml1": "Machine learning is a subset of artificial intelligence...", + "doc/ml2": "ML algorithms learn patterns from data to make predictions...", + "doc/ml3": "Common ML techniques include supervised and unsupervised learning...", +} + + +@pytest.fixture +def mock_fetch_chunk(): + """Create a mock fetch_chunk function""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + + class TestDocumentRag: """Test cases for DocumentRag class""" - def test_document_rag_initialization_with_defaults(self): + def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk): """Test DocumentRag initialization with default verbose setting""" # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_doc_embeddings_client = MagicMock() - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, - doc_embeddings_client=mock_doc_embeddings_client + doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk ) - + # Verify initialization assert document_rag.prompt_client == mock_prompt_client assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client + assert document_rag.fetch_chunk == mock_fetch_chunk assert document_rag.verbose is False # Default value - def test_document_rag_initialization_with_verbose(self): + def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk): """Test DocumentRag initialization with verbose enabled""" # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_doc_embeddings_client = MagicMock() - + # Initialize DocumentRag with verbose=True document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) - + # Verify initialization assert document_rag.prompt_client == mock_prompt_client assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client + assert document_rag.fetch_chunk == mock_fetch_chunk assert document_rag.verbose is True @@ -60,7 +87,7 @@ class TestQuery: """Test Query initialization with default parameters""" # Create mock DocumentRag mock_rag = MagicMock() - + # Initialize Query with defaults query = Query( rag=mock_rag, @@ -68,7 +95,7 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "test_user" @@ -80,7 +107,7 @@ class TestQuery: """Test Query initialization with custom doc_limit""" # Create mock DocumentRag mock_rag = MagicMock() - + # Initialize Query with custom doc_limit query = Query( rag=mock_rag, @@ -89,7 +116,7 @@ class TestQuery: verbose=True, doc_limit=50 ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "custom_user" @@ -98,54 +125,101 @@ class TestQuery: assert query.doc_limit == 50 @pytest.mark.asyncio - async def test_get_vector_method(self): - """Test Query.get_vector method calls embeddings client correctly""" - # Create mock DocumentRag with embeddings client + async def test_extract_concepts(self): + """Test Query.extract_concepts extracts concepts from query""" mock_rag = MagicMock() - mock_embeddings_client = AsyncMock() - mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method to return test vectors - expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - mock_embeddings_client.embed.return_value = expected_vectors - - # Initialize Query + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + # Mock the prompt response with concept lines + mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns" + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call get_vector - test_query = "What documents are relevant?" - result = await query.get_vector(test_query) - - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify result matches expected vectors + + result = await query.extract_concepts("What is machine learning?") + + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "What is machine learning?"} + ) + assert result == ["machine learning", "artificial intelligence", "data patterns"] + + @pytest.mark.asyncio + async def test_extract_concepts_fallback_to_raw_query(self): + """Test Query.extract_concepts falls back to raw query when no concepts extracted""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + # Mock empty response + mock_prompt_client.prompt.return_value = "" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("What is ML?") + + assert result == ["What is ML?"] + + @pytest.mark.asyncio + async def test_get_vectors_method(self): + """Test Query.get_vectors method calls embeddings client correctly""" + mock_rag = MagicMock() + mock_embeddings_client = AsyncMock() + mock_rag.embeddings_client = mock_embeddings_client + + # Mock the embed method - returns vectors for each concept + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + mock_embeddings_client.embed.return_value = expected_vectors + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + concepts = ["machine learning", "data patterns"] + result = await query.get_vectors(concepts) + + mock_embeddings_client.embed.assert_called_once_with(concepts) assert result == expected_vectors @pytest.mark.asyncio async def test_get_docs_method(self): """Test Query.get_docs method retrieves documents correctly""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - - # Mock the embedding and document query responses - test_vectors = [[0.1, 0.2, 0.3]] - mock_embeddings_client.embed.return_value = test_vectors - - # Mock document results - test_docs = ["Document 1 content", "Document 2 content"] - mock_doc_embeddings_client.query.return_value = test_docs - - # Initialize Query + + # Mock fetch_chunk function + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + + # Mock embeddings - one vector per concept + mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]] + + # Mock document embeddings returns ChunkMatch objects + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c1" + mock_match1.score = 0.95 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c2" + mock_match2.score = 0.85 + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] + query = Query( rag=mock_rag, user="test_user", @@ -153,126 +227,154 @@ class TestQuery: verbose=False, doc_limit=15 ) - - # Call get_docs - test_query = "Find relevant documents" - result = await query.get_docs(test_query) - - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify doc embeddings client was called correctly + + # Call get_docs with concepts list + concepts = ["test concept"] + result = await query.get_docs(concepts) + + # Verify embeddings client was called with concepts + mock_embeddings_client.embed.assert_called_once_with(concepts) + + # Verify doc embeddings client was called mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=[0.1, 0.2, 0.3], limit=15, user="test_user", collection="test_collection" ) - - # Verify result is list of documents - assert result == test_docs + + # Verify result is tuple of (docs, chunk_ids) + docs, chunk_ids = result + assert "Document 1 content" in docs + assert "Document 2 content" in docs + assert "doc/c1" in chunk_ids + assert "doc/c2" in chunk_ids @pytest.mark.asyncio - async def test_document_rag_query_method(self): + async def test_document_rag_query_method(self, mock_fetch_chunk): """Test DocumentRag.query method orchestrates full document RAG pipeline""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - - # Mock embeddings and document responses + + # Mock concept extraction + mock_prompt_client.prompt.return_value = "test concept" + + # Mock embeddings - one vector per concept test_vectors = [[0.1, 0.2, 0.3]] - test_docs = ["Relevant document content", "Another document"] - expected_response = "This is the document RAG response" - mock_embeddings_client.embed.return_value = test_vectors - mock_doc_embeddings_client.query.return_value = test_docs + + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c3" + mock_match1.score = 0.9 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c4" + mock_match2.score = 0.8 + expected_response = "This is the document RAG response" + + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_prompt_client.document_prompt.return_value = expected_response - - # Initialize DocumentRag + document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - - # Call DocumentRag.query + result = await document_rag.query( query="test query", user="test_user", collection="test_collection", doc_limit=10 ) - - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with("test query") - + + # Verify concept extraction was called + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "test query"} + ) + + # Verify embeddings called with extracted concepts + mock_embeddings_client.embed.assert_called_once_with(["test concept"]) + # Verify doc embeddings client was called mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=[0.1, 0.2, 0.3], limit=10, user="test_user", collection="test_collection" ) - - # Verify prompt client was called with documents and query - mock_prompt_client.document_prompt.assert_called_once_with( - query="test query", - documents=test_docs - ) - - # Verify result + + # Verify prompt client was called with fetched documents and query + mock_prompt_client.document_prompt.assert_called_once() + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == "test query" + docs = call_args.kwargs["documents"] + assert "Relevant document content" in docs + assert "Another document" in docs + assert result == expected_response @pytest.mark.asyncio - async def test_document_rag_query_with_defaults(self): + async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): """Test DocumentRag.query method with default parameters""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - + + # Mock concept extraction fallback (empty → raw query) + mock_prompt_client.prompt.return_value = "" + # Mock responses - mock_embeddings_client.embed.return_value = [[0.1, 0.2]] - mock_doc_embeddings_client.query.return_value = ["Default doc"] + mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c5" + mock_match.score = 0.9 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Default response" - - # Initialize DocumentRag + document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, - doc_embeddings_client=mock_doc_embeddings_client + doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk ) - - # Call DocumentRag.query with minimal parameters + result = await document_rag.query("simple query") - + # Verify default parameters were used mock_doc_embeddings_client.query.assert_called_once_with( - [[0.1, 0.2]], + vector=[[0.1, 0.2]], limit=20, # Default doc_limit user="trustgraph", # Default user collection="default" # Default collection ) - + assert result == "Default response" @pytest.mark.asyncio async def test_get_docs_with_verbose_output(self): """Test Query.get_docs method with verbose logging""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - - # Mock responses - mock_embeddings_client.embed.return_value = [[0.7, 0.8]] - mock_doc_embeddings_client.query.return_value = ["Verbose test doc"] - - # Initialize Query with verbose=True + + # Mock fetch_chunk + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + + # Mock responses - one vector per concept + mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c6" + mock_match.score = 0.88 + mock_doc_embeddings_client.query.return_value = [mock_match] + query = Query( rag=mock_rag, user="test_user", @@ -280,196 +382,245 @@ class TestQuery: verbose=True, doc_limit=5 ) - - # Call get_docs - result = await query.get_docs("verbose test") - - # Verify calls were made - mock_embeddings_client.embed.assert_called_once_with("verbose test") + + # Call get_docs with concepts + result = await query.get_docs(["verbose test"]) + + mock_embeddings_client.embed.assert_called_once_with(["verbose test"]) mock_doc_embeddings_client.query.assert_called_once() - - # Verify result - assert result == ["Verbose test doc"] + + docs, chunk_ids = result + assert "Verbose test doc" in docs + assert "doc/c6" in chunk_ids @pytest.mark.asyncio - async def test_document_rag_query_with_verbose(self): + async def test_document_rag_query_with_verbose(self, mock_fetch_chunk): """Test DocumentRag.query method with verbose logging enabled""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - + + # Mock concept extraction + mock_prompt_client.prompt.return_value = "verbose query test" + # Mock responses - mock_embeddings_client.embed.return_value = [[0.3, 0.4]] - mock_doc_embeddings_client.query.return_value = ["Verbose doc content"] + mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c7" + mock_match.score = 0.92 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" - - # Initialize DocumentRag with verbose=True + document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) - - # Call DocumentRag.query + result = await document_rag.query("verbose query test") - - # Verify all clients were called - mock_embeddings_client.embed.assert_called_once_with("verbose query test") + + mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() - mock_prompt_client.document_prompt.assert_called_once_with( - query="verbose query test", - documents=["Verbose doc content"] - ) - + + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == "verbose query test" + assert "Verbose doc content" in call_args.kwargs["documents"] + assert result == "Verbose RAG response" @pytest.mark.asyncio async def test_get_docs_with_empty_results(self): """Test Query.get_docs method when no documents are found""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - - # Mock responses - empty document list - mock_embeddings_client.embed.return_value = [[0.1, 0.2]] - mock_doc_embeddings_client.query.return_value = [] # No documents found - - # Initialize Query + + async def mock_fetch(chunk_id, user): + return f"Content for {chunk_id}" + mock_rag.fetch_chunk = mock_fetch + + # Mock responses - empty results + mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] + mock_doc_embeddings_client.query.return_value = [] + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call get_docs - result = await query.get_docs("query with no results") - - # Verify calls were made - mock_embeddings_client.embed.assert_called_once_with("query with no results") + + result = await query.get_docs(["query with no results"]) + + mock_embeddings_client.embed.assert_called_once_with(["query with no results"]) mock_doc_embeddings_client.query.assert_called_once() - - # Verify empty result is returned - assert result == [] + + assert result == ([], []) @pytest.mark.asyncio - async def test_document_rag_query_with_empty_documents(self): + async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk): """Test DocumentRag.query method when no documents are retrieved""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - - # Mock responses - no documents found - mock_embeddings_client.embed.return_value = [[0.5, 0.6]] - mock_doc_embeddings_client.query.return_value = [] # Empty document list + + # Mock concept extraction + mock_prompt_client.prompt.return_value = "query with no matching docs" + + mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] + mock_doc_embeddings_client.query.return_value = [] mock_prompt_client.document_prompt.return_value = "No documents found response" - - # Initialize DocumentRag + document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - - # Call DocumentRag.query + result = await document_rag.query("query with no matching docs") - - # Verify prompt client was called with empty document list + mock_prompt_client.document_prompt.assert_called_once_with( query="query with no matching docs", documents=[] ) - + assert result == "No documents found response" @pytest.mark.asyncio - async def test_get_vector_with_verbose(self): - """Test Query.get_vector method with verbose logging""" - # Create mock DocumentRag with embeddings client + async def test_get_vectors_with_verbose(self): + """Test Query.get_vectors method with verbose logging""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method + expected_vectors = [[0.9, 1.0, 1.1]] mock_embeddings_client.embed.return_value = expected_vectors - - # Initialize Query with verbose=True + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=True ) - - # Call get_vector - result = await query.get_vector("verbose vector test") - - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with("verbose vector test") - - # Verify result + + result = await query.get_vectors(["verbose vector test"]) + + mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"]) assert result == expected_vectors @pytest.mark.asyncio - async def test_document_rag_integration_flow(self): + async def test_document_rag_integration_flow(self, mock_fetch_chunk): """Test complete DocumentRag integration with realistic data flow""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - - # Mock realistic responses + query_text = "What is machine learning?" - query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] - retrieved_docs = [ - "Machine learning is a subset of artificial intelligence...", - "ML algorithms learn patterns from data to make predictions...", - "Common ML techniques include supervised and unsupervised learning..." - ] final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." - + + # Mock concept extraction + mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence" + + # Mock embeddings - one vector per concept + query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]] mock_embeddings_client.embed.return_value = query_vectors - mock_doc_embeddings_client.query.return_value = retrieved_docs + + # Each concept query returns some matches + mock_matches_1 = [ + MagicMock(chunk_id="doc/ml1", score=0.9), + MagicMock(chunk_id="doc/ml2", score=0.85), + ] + mock_matches_2 = [ + MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate + MagicMock(chunk_id="doc/ml3", score=0.82), + ] + mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2] mock_prompt_client.document_prompt.return_value = final_response - - # Initialize DocumentRag + document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - - # Execute full pipeline + result = await document_rag.query( query=query_text, - user="research_user", + user="research_user", collection="ml_knowledge", doc_limit=25 ) - - # Verify complete pipeline execution - mock_embeddings_client.embed.assert_called_once_with(query_text) - - mock_doc_embeddings_client.query.assert_called_once_with( - query_vectors, - limit=25, - user="research_user", - collection="ml_knowledge" + + # Verify concept extraction + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": query_text} ) - - mock_prompt_client.document_prompt.assert_called_once_with( - query=query_text, - documents=retrieved_docs + + # Verify embeddings called with concepts + mock_embeddings_client.embed.assert_called_once_with( + ["machine learning", "artificial intelligence"] ) - - # Verify final result - assert result == final_response \ No newline at end of file + + # Verify two per-concept queries were made (25 // 2 = 12 per concept) + assert mock_doc_embeddings_client.query.call_count == 2 + + # Verify prompt client was called with fetched document content + mock_prompt_client.document_prompt.assert_called_once() + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == query_text + + # Verify documents were fetched and deduplicated + docs = call_args.kwargs["documents"] + assert "Machine learning is a subset of artificial intelligence..." in docs + assert "ML algorithms learn patterns from data to make predictions..." in docs + assert "Common ML techniques include supervised and unsupervised learning..." in docs + assert len(docs) == 3 # doc/ml2 deduplicated + + assert result == final_response + + @pytest.mark.asyncio + async def test_get_docs_deduplicates_across_concepts(self): + """Test that get_docs deduplicates chunks across multiple concepts""" + mock_rag = MagicMock() + mock_embeddings_client = AsyncMock() + mock_doc_embeddings_client = AsyncMock() + mock_rag.embeddings_client = mock_embeddings_client + mock_rag.doc_embeddings_client = mock_doc_embeddings_client + + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + + # Two concepts → two vectors + mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # Both queries return overlapping chunks + match_a = MagicMock(chunk_id="doc/c1", score=0.9) + match_b = MagicMock(chunk_id="doc/c2", score=0.8) + match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate + mock_doc_embeddings_client.query.side_effect = [ + [match_a, match_b], + [match_c], + ] + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False, + doc_limit=10 + ) + + docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) + + assert len(chunk_ids) == 2 # doc/c1 only counted once + assert "doc/c1" in chunk_ids + assert "doc/c2" in chunk_ids diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index 041d29df..05e1bb60 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -5,7 +5,7 @@ passed to the DocumentRag.query() method. """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock, patch, ANY from trustgraph.retrieval.document_rag.rag import Processor from trustgraph.schema import DocumentRagQuery, DocumentRagResponse @@ -65,8 +65,10 @@ class TestDocumentRagService: mock_rag_instance.query.assert_called_once_with( "test query", user="my_user", # Must be from message, not hardcoded default - collection="test_coll_1", # Must be from message, not hardcoded default - doc_limit=5 + collection="test_coll_1", # Must be from message, not hardcoded default + doc_limit=5, + explain_callback=ANY, # Explainability callback is always passed + save_answer_callback=ANY, # Librarian save callback is always passed ) # Verify response was sent diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 5f54e28a..597d3366 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -19,7 +19,7 @@ class TestGraphRag: mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() - + # Initialize GraphRag graph_rag = GraphRag( prompt_client=mock_prompt_client, @@ -27,7 +27,7 @@ class TestGraphRag: graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client ) - + # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == mock_embeddings_client @@ -45,7 +45,7 @@ class TestGraphRag: mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() - + # Initialize GraphRag with verbose=True graph_rag = GraphRag( prompt_client=mock_prompt_client, @@ -54,7 +54,7 @@ class TestGraphRag: triples_client=mock_triples_client, verbose=True ) - + # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == mock_embeddings_client @@ -73,7 +73,7 @@ class TestQuery: """Test Query initialization with default parameters""" # Create mock GraphRag mock_rag = MagicMock() - + # Initialize Query with defaults query = Query( rag=mock_rag, @@ -81,7 +81,7 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "test_user" @@ -96,7 +96,7 @@ class TestQuery: """Test Query initialization with custom parameters""" # Create mock GraphRag mock_rag = MagicMock() - + # Initialize Query with custom parameters query = Query( rag=mock_rag, @@ -108,7 +108,7 @@ class TestQuery: max_subgraph_size=2000, max_path_length=3 ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "custom_user" @@ -120,87 +120,127 @@ class TestQuery: assert query.max_path_length == 3 @pytest.mark.asyncio - async def test_get_vector_method(self): - """Test Query.get_vector method calls embeddings client correctly""" - # Create mock GraphRag with embeddings client + async def test_get_vectors_method(self): + """Test Query.get_vectors method calls embeddings client correctly""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method to return test vectors + + # Mock embed to return vectors for a list of concepts expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] mock_embeddings_client.embed.return_value = expected_vectors - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call get_vector - test_query = "What is the capital of France?" - result = await query.get_vector(test_query) - - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify result matches expected vectors + + concepts = ["machine learning", "neural networks"] + result = await query.get_vectors(concepts) + + mock_embeddings_client.embed.assert_called_once_with(concepts) assert result == expected_vectors @pytest.mark.asyncio - async def test_get_vector_method_with_verbose(self): - """Test Query.get_vector method with verbose output""" - # Create mock GraphRag with embeddings client + async def test_get_vectors_method_with_verbose(self): + """Test Query.get_vectors method with verbose output""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method + expected_vectors = [[0.7, 0.8, 0.9]] mock_embeddings_client.embed.return_value = expected_vectors - - # Initialize Query with verbose=True + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=True ) - - # Call get_vector - test_query = "Test query for embeddings" - result = await query.get_vector(test_query) - - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify result matches expected vectors + + result = await query.get_vectors(["test concept"]) + + mock_embeddings_client.embed.assert_called_once_with(["test concept"]) assert result == expected_vectors @pytest.mark.asyncio - async def test_get_entities_method(self): - """Test Query.get_entities method retrieves entities correctly""" - # Create mock GraphRag with clients + async def test_extract_concepts(self): + """Test Query.extract_concepts parses LLM response into concept list""" mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("What is machine learning?") + + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "What is machine learning?"} + ) + assert result == ["machine learning", "neural networks"] + + @pytest.mark.asyncio + async def test_extract_concepts_fallback_to_raw_query(self): + """Test extract_concepts falls back to raw query when LLM returns empty""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + mock_prompt_client.prompt.return_value = "" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("test query") + assert result == ["test query"] + + @pytest.mark.asyncio + async def test_get_entities_method(self): + """Test Query.get_entities extracts concepts, embeds, and retrieves entities""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client mock_rag.embeddings_client = mock_embeddings_client mock_rag.graph_embeddings_client = mock_graph_embeddings_client - - # Mock the embedding and entity query responses + + # extract_concepts returns empty -> falls back to [query] + mock_prompt_client.prompt.return_value = "" + + # embed returns one vector set for the single concept test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = test_vectors - - # Mock entity objects that have string representation + + # Mock entity matches mock_entity1 = MagicMock() - mock_entity1.__str__ = MagicMock(return_value="entity1") + mock_entity1.type = "i" + mock_entity1.iri = "entity1" + mock_match1 = MagicMock() + mock_match1.entity = mock_entity1 + mock_entity2 = MagicMock() - mock_entity2.__str__ = MagicMock(return_value="entity2") - mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] - - # Initialize Query + mock_entity2.type = "i" + mock_entity2.iri = "entity2" + mock_match2 = MagicMock() + mock_match2.entity = mock_entity2 + + mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2] + query = Query( rag=mock_rag, user="test_user", @@ -208,36 +248,24 @@ class TestQuery: verbose=False, entity_limit=25 ) - - # Call get_entities - test_query = "Find related entities" - result = await query.get_entities(test_query) - - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify graph embeddings client was called correctly - mock_graph_embeddings_client.query.assert_called_once_with( - vectors=test_vectors, - limit=25, - user="test_user", - collection="test_collection" - ) - - # Verify result is list of entity strings - assert result == ["entity1", "entity2"] + + entities, concepts = await query.get_entities("Find related entities") + + # Verify embeddings client was called with the fallback concept + mock_embeddings_client.embed.assert_called_once_with(["Find related entities"]) + + # Verify result + assert entities == ["entity1", "entity2"] + assert concepts == ["Find related entities"] @pytest.mark.asyncio async def test_maybe_label_with_cached_label(self): """Test Query.maybe_label method with cached label""" - # Create mock GraphRag with label cache mock_rag = MagicMock() - # Create mock LRUCacheWithTTL mock_cache = MagicMock() mock_cache.get.return_value = "Entity One Label" mock_rag.label_cache = mock_cache - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -245,32 +273,25 @@ class TestQuery: verbose=False ) - # Call maybe_label with cached entity result = await query.maybe_label("entity1") - # Verify cached label is returned assert result == "Entity One Label" - # Verify cache was checked with proper key format (user:collection:entity) mock_cache.get.assert_called_once_with("test_user:test_collection:entity1") @pytest.mark.asyncio async def test_maybe_label_with_label_lookup(self): """Test Query.maybe_label method with database label lookup""" - # Create mock GraphRag with triples client mock_rag = MagicMock() - # Create mock LRUCacheWithTTL that returns None (cache miss) mock_cache = MagicMock() mock_cache.get.return_value = None mock_rag.label_cache = mock_cache mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - # Mock triple result with label mock_triple = MagicMock() mock_triple.o = "Human Readable Label" mock_triples_client.query.return_value = [mock_triple] - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -278,20 +299,18 @@ class TestQuery: verbose=False ) - # Call maybe_label result = await query.maybe_label("http://example.com/entity") - # Verify triples client was called correctly mock_triples_client.query.assert_called_once_with( s="http://example.com/entity", p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, user="test_user", - collection="test_collection" + collection="test_collection", + g="" ) - # Verify result and cache update with proper key assert result == "Human Readable Label" cache_key = "test_user:test_collection:http://example.com/entity" mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label") @@ -299,40 +318,34 @@ class TestQuery: @pytest.mark.asyncio async def test_maybe_label_with_no_label_found(self): """Test Query.maybe_label method when no label is found""" - # Create mock GraphRag with triples client mock_rag = MagicMock() - # Create mock LRUCacheWithTTL that returns None (cache miss) mock_cache = MagicMock() mock_cache.get.return_value = None mock_rag.label_cache = mock_cache mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Mock empty result (no label found) + mock_triples_client.query.return_value = [] - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call maybe_label + result = await query.maybe_label("unlabeled_entity") - - # Verify triples client was called + mock_triples_client.query.assert_called_once_with( s="unlabeled_entity", p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, user="test_user", - collection="test_collection" + collection="test_collection", + g="" ) - - # Verify result is entity itself and cache is updated + assert result == "unlabeled_entity" cache_key = "test_user:test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") @@ -340,29 +353,25 @@ class TestQuery: @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" - # Create mock GraphRag with triples client mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Mock triple results for different query patterns + mock_triple1 = MagicMock() mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" - + mock_triple2 = MagicMock() mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" - + mock_triple3 = MagicMock() mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - - # Setup query responses for s=ent, p=ent, o=ent patterns - mock_triples_client.query.side_effect = [ - [mock_triple1], # s=ent, p=None, o=None - [mock_triple2], # s=None, p=ent, o=None - [mock_triple3], # s=None, p=None, o=ent + + mock_triples_client.query_stream.side_effect = [ + [mock_triple1], # s=ent + [mock_triple2], # p=ent + [mock_triple3], # o=ent ] - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", @@ -370,29 +379,25 @@ class TestQuery: verbose=False, triple_limit=10 ) - - # Call follow_edges + subgraph = set() await query.follow_edges("entity1", subgraph, path_length=1) - - # Verify all three query patterns were called - assert mock_triples_client.query.call_count == 3 - - # Verify query calls - mock_triples_client.query.assert_any_call( + + assert mock_triples_client.query_stream.call_count == 3 + + mock_triples_client.query_stream.assert_any_call( s="entity1", p=None, o=None, limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20, g="" ) - mock_triples_client.query.assert_any_call( + mock_triples_client.query_stream.assert_any_call( s=None, p="entity1", o=None, limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20, g="" ) - mock_triples_client.query.assert_any_call( + mock_triples_client.query_stream.assert_any_call( s=None, p=None, o="entity1", limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20, g="" ) - - # Verify subgraph contains discovered triples + expected_subgraph = { ("entity1", "predicate1", "object1"), ("subject2", "entity1", "object2"), @@ -403,38 +408,30 @@ class TestQuery: @pytest.mark.asyncio async def test_follow_edges_with_path_length_zero(self): """Test Query.follow_edges method with path_length=0""" - # Create mock GraphRag mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call follow_edges with path_length=0 + subgraph = set() await query.follow_edges("entity1", subgraph, path_length=0) - - # Verify no queries were made - mock_triples_client.query.assert_not_called() - - # Verify subgraph remains empty + + mock_triples_client.query_stream.assert_not_called() assert subgraph == set() @pytest.mark.asyncio async def test_follow_edges_with_max_subgraph_size_limit(self): """Test Query.follow_edges method respects max_subgraph_size""" - # Create mock GraphRag mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Initialize Query with small max_subgraph_size + query = Query( rag=mock_rag, user="test_user", @@ -442,23 +439,17 @@ class TestQuery: verbose=False, max_subgraph_size=2 ) - - # Pre-populate subgraph to exceed limit + subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} - - # Call follow_edges + await query.follow_edges("entity1", subgraph, path_length=1) - - # Verify no queries were made due to size limit - mock_triples_client.query.assert_not_called() - - # Verify subgraph unchanged + + mock_triples_client.query_stream.assert_not_called() assert len(subgraph) == 3 @pytest.mark.asyncio async def test_get_subgraph_method(self): - """Test Query.get_subgraph method orchestrates entity and edge discovery""" - # Create mock Query that patches get_entities and follow_edges_batch + """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple""" mock_rag = MagicMock() query = Query( @@ -469,103 +460,119 @@ class TestQuery: max_path_length=1 ) - # Mock get_entities to return test entities - query.get_entities = AsyncMock(return_value=["entity1", "entity2"]) + # Mock get_entities to return (entities, concepts) tuple + query.get_entities = AsyncMock( + return_value=(["entity1", "entity2"], ["concept1"]) + ) - # Mock follow_edges_batch to return test triples query.follow_edges_batch = AsyncMock(return_value={ ("entity1", "predicate1", "object1"), ("entity2", "predicate2", "object2") }) - # Call get_subgraph - result = await query.get_subgraph("test query") + subgraph, entities, concepts = await query.get_subgraph("test query") - # Verify get_entities was called query.get_entities.assert_called_once_with("test query") - - # Verify follow_edges_batch was called with entities and max_path_length query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) - # Verify result is list format and contains expected triples - assert isinstance(result, list) - assert len(result) == 2 - assert ("entity1", "predicate1", "object1") in result - assert ("entity2", "predicate2", "object2") in result + assert isinstance(subgraph, list) + assert len(subgraph) == 2 + assert ("entity1", "predicate1", "object1") in subgraph + assert ("entity2", "predicate2", "object2") in subgraph + assert entities == ["entity1", "entity2"] + assert concepts == ["concept1"] @pytest.mark.asyncio async def test_get_labelgraph_method(self): - """Test Query.get_labelgraph method converts entities to labels""" - # Create mock Query + """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)""" mock_rag = MagicMock() - + query = Query( rag=mock_rag, user="test_user", - collection="test_collection", + collection="test_collection", verbose=False, max_subgraph_size=100 ) - - # Mock get_subgraph to return test triples + test_subgraph = [ ("entity1", "predicate1", "object1"), - ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered + ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), ("entity3", "predicate3", "object3") ] - query.get_subgraph = AsyncMock(return_value=test_subgraph) - - # Mock maybe_label to return human-readable labels + test_entities = ["entity1", "entity3"] + test_concepts = ["concept1"] + query.get_subgraph = AsyncMock( + return_value=(test_subgraph, test_entities, test_concepts) + ) + async def mock_maybe_label(entity): label_map = { "entity1": "Human Entity One", - "predicate1": "Human Predicate One", + "predicate1": "Human Predicate One", "object1": "Human Object One", "entity3": "Human Entity Three", "predicate3": "Human Predicate Three", "object3": "Human Object Three" } return label_map.get(entity, entity) - + query.maybe_label = AsyncMock(side_effect=mock_maybe_label) - - # Call get_labelgraph - result = await query.get_labelgraph("test query") - - # Verify get_subgraph was called + + labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query") + query.get_subgraph.assert_called_once_with("test query") - - # Verify label triples are filtered out - assert len(result) == 2 # Label triple should be excluded - - # Verify maybe_label was called for non-label triples - expected_calls = [ - (("entity1",), {}), (("predicate1",), {}), (("object1",), {}), - (("entity3",), {}), (("predicate3",), {}), (("object3",), {}) - ] + + # Label triples filtered out + assert len(labeled_edges) == 2 + + # maybe_label called for non-label triples assert query.maybe_label.call_count == 6 - - # Verify result contains human-readable labels - expected_result = [ + + expected_edges = [ ("Human Entity One", "Human Predicate One", "Human Object One"), ("Human Entity Three", "Human Predicate Three", "Human Object Three") ] - assert result == expected_result + assert labeled_edges == expected_edges + + assert len(uri_map) == 2 + assert entities == test_entities + assert concepts == test_concepts @pytest.mark.asyncio async def test_graph_rag_query_method(self): - """Test GraphRag.query method orchestrates full RAG pipeline""" - # Create mock clients + """Test GraphRag.query method orchestrates full RAG pipeline with provenance""" + import json + from trustgraph.retrieval.graph_rag.graph_rag import edge_id + mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() - - # Mock prompt client response + expected_response = "This is the RAG response" - mock_prompt_client.kg_prompt.return_value = expected_response - - # Initialize GraphRag + test_labelgraph = [("Subject", "Predicate", "Object")] + test_edge_id = edge_id("Subject", "Predicate", "Object") + test_uri_map = { + test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + } + test_entities = ["http://example.org/subject"] + test_concepts = ["test concept"] + + # Mock prompt responses for the multi-step process + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): + if prompt_name == "extract-concepts": + return "" # Falls back to raw query + elif prompt_name == "kg-edge-scoring": + return json.dumps({"id": test_edge_id, "score": 0.9}) + elif prompt_name == "kg-edge-reasoning": + return json.dumps({"id": test_edge_id, "reasoning": "relevant"}) + elif prompt_name == "kg-synthesis": + return expected_response + return "" + + mock_prompt_client.prompt = mock_prompt + graph_rag = GraphRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -573,40 +580,46 @@ class TestQuery: triples_client=mock_triples_client, verbose=False ) - - # Mock the Query class behavior by patching get_labelgraph - test_labelgraph = [("Subject", "Predicate", "Object")] - - # We need to patch the Query class's get_labelgraph method - original_query_init = Query.__init__ + + # Patch Query.get_labelgraph to return test data original_get_labelgraph = Query.get_labelgraph - - def mock_query_init(self, *args, **kwargs): - original_query_init(self, *args, **kwargs) - + async def mock_get_labelgraph(self, query_text): - return test_labelgraph - - Query.__init__ = mock_query_init + return test_labelgraph, test_uri_map, test_entities, test_concepts + Query.get_labelgraph = mock_get_labelgraph - + + provenance_events = [] + + async def collect_provenance(triples, prov_id): + provenance_events.append((triples, prov_id)) + try: - # Call GraphRag.query - result = await graph_rag.query( + response = await graph_rag.query( query="test query", user="test_user", collection="test_collection", entity_limit=25, - triple_limit=15 + triple_limit=15, + explain_callback=collect_provenance ) - - # Verify prompt client was called with knowledge graph and query - mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph) - - # Verify result - assert result == expected_response - + + assert response == expected_response + + # 5 events: question, grounding, exploration, focus, synthesis + assert len(provenance_events) == 5 + + for triples, prov_id in provenance_events: + assert isinstance(triples, list) + assert len(triples) > 0 + assert prov_id.startswith("urn:trustgraph:") + + # Verify order + assert "question" in provenance_events[0][1] + assert "grounding" in provenance_events[1][1] + assert "exploration" in provenance_events[2][1] + assert "focus" in provenance_events[3][1] + assert "synthesis" in provenance_events[4][1] + finally: - # Restore original methods - Query.__init__ = original_query_init - Query.get_labelgraph = original_get_labelgraph \ No newline at end of file + Query.get_labelgraph = original_get_labelgraph diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index ddfdfa75..2cd62286 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -1,6 +1,7 @@ """ -Unit tests for GraphRAG service non-streaming mode. -Tests that end_of_stream flag is correctly set in non-streaming responses. +Unit tests for GraphRAG service message format. +Tests the new message protocol with message_type, explain_id, and end_of_session. +Real-time explainability emission via callback. """ import pytest @@ -11,16 +12,14 @@ from trustgraph.schema import GraphRagQuery, GraphRagResponse class TestGraphRagService: - """Test GraphRAG service non-streaming behavior""" + """Test GraphRAG service message protocol""" @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') @pytest.mark.asyncio - async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class): + async def test_non_streaming_sends_chunk_then_provenance_messages(self, mock_graph_rag_class): """ - Test that non-streaming mode sets end_of_stream=True in response. - - This is a regression test for the bug where non-streaming responses - didn't set end_of_stream, causing clients to hang waiting for more data. + Test that non-streaming mode sends real-time provenance messages + followed by chunk message with response. """ # Setup processor processor = Processor( @@ -32,10 +31,22 @@ class TestGraphRagService: max_path_length=2 ) - # Setup mock GraphRag instance + # Setup mock GraphRag instance that calls explain_callback mock_rag_instance = AsyncMock() mock_graph_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "A small domesticated mammal." + + # Mock query() to call the explain_callback with each provenance event + async def mock_query(**kwargs): + explain_callback = kwargs.get('explain_callback') + if explain_callback: + # Simulate real-time provenance emission + await explain_callback([], "urn:trustgraph:session:test") + await explain_callback([], "urn:trustgraph:prov:retrieval:test") + await explain_callback([], "urn:trustgraph:prov:selection:test") + await explain_callback([], "urn:trustgraph:prov:answer:test") + return "A small domesticated mammal." + + mock_rag_instance.query.side_effect = mock_query # Setup message with non-streaming request msg = MagicMock() @@ -47,7 +58,7 @@ class TestGraphRagService: triple_limit=30, max_subgraph_size=150, max_path_length=2, - streaming=False # Non-streaming mode + streaming=False ) msg.properties.return_value = {"id": "test-id"} @@ -55,30 +66,48 @@ class TestGraphRagService: consumer = MagicMock() flow = MagicMock() - # Mock flow to return AsyncMock for clients and response producer - mock_producer = AsyncMock() + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() def flow_router(service_name): if service_name == "response": - return mock_producer - return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer + return AsyncMock() flow.side_effect = flow_router # Execute await processor.on_request(msg, consumer, flow) - # Verify: response was sent with end_of_stream=True - mock_producer.send.assert_called_once() - sent_response = mock_producer.send.call_args[0][0] - assert isinstance(sent_response, GraphRagResponse) - assert sent_response.response == "A small domesticated mammal." - assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" - assert sent_response.error is None + # Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session) + assert mock_response_producer.send.call_count == 6 + + # First 4 messages are explain (emitted in real-time during query) + for i in range(4): + prov_msg = mock_response_producer.send.call_args_list[i][0][0] + assert prov_msg.message_type == "explain" + assert prov_msg.explain_id is not None + + # 5th message is chunk with response + chunk_msg = mock_response_producer.send.call_args_list[4][0][0] + assert chunk_msg.message_type == "chunk" + assert chunk_msg.response == "A small domesticated mammal." + assert chunk_msg.end_of_stream is True + + # 6th message is empty chunk with end_of_session=True + close_msg = mock_response_producer.send.call_args_list[5][0][0] + assert close_msg.message_type == "chunk" + assert close_msg.response == "" + assert close_msg.end_of_session is True + + # Verify provenance triples were sent to provenance queue + assert mock_provenance_producer.send.call_count == 4 @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') @pytest.mark.asyncio - async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class): + async def test_error_response_closes_session(self, mock_graph_rag_class): """ - Test that error responses in non-streaming mode set end_of_stream=True. + Test that error responses set end_of_session=True. """ # Setup processor processor = Processor( @@ -105,7 +134,7 @@ class TestGraphRagService: triple_limit=30, max_subgraph_size=150, max_path_length=2, - streaming=False # Non-streaming mode + streaming=False ) msg.properties.return_value = {"id": "test-id"} @@ -113,22 +142,93 @@ class TestGraphRagService: consumer = MagicMock() flow = MagicMock() - mock_producer = AsyncMock() + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() def flow_router(service_name): if service_name == "response": - return mock_producer + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer return AsyncMock() flow.side_effect = flow_router # Execute await processor.on_request(msg, consumer, flow) - # Verify: error response was sent without end_of_stream (not streaming mode) - mock_producer.send.assert_called_once() - sent_response = mock_producer.send.call_args[0][0] + # Verify: error response was sent with session closed + mock_response_producer.send.assert_called_once() + sent_response = mock_response_producer.send.call_args[0][0] assert isinstance(sent_response, GraphRagResponse) - assert sent_response.response is None + assert sent_response.message_type == "chunk" assert sent_response.error is not None assert sent_response.error.message == "Test error" - # Note: error responses in non-streaming mode don't set end_of_stream - # because streaming was never started + assert sent_response.end_of_stream is True + assert sent_response.end_of_session is True + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_no_provenance_sends_empty_chunk_to_close(self, mock_graph_rag_class): + """ + Test that when no provenance callback is invoked, an empty chunk closes the session. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2 + ) + + # Setup mock GraphRag instance that doesn't call provenance callback + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + + async def mock_query(**kwargs): + # Don't call explain_callback + return "Response text" + + mock_rag_instance.query.side_effect = mock_query + + # Setup message + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="Test query", + user="trustgraph", + collection="default", + streaming=False + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_response_producer = AsyncMock() + mock_provenance_producer = AsyncMock() + def flow_router(service_name): + if service_name == "response": + return mock_response_producer + elif service_name == "explainability": + return mock_provenance_producer + return AsyncMock() + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: 2 messages (chunk + empty chunk to close) + assert mock_response_producer.send.call_count == 2 + + # First is the response chunk + chunk_msg = mock_response_producer.send.call_args_list[0][0][0] + assert chunk_msg.message_type == "chunk" + assert chunk_msg.response == "Response text" + assert chunk_msg.end_of_stream is True + + # Second is empty chunk to close session + close_msg = mock_response_producer.send.call_args_list[1][0][0] + assert close_msg.message_type == "chunk" + assert close_msg.response == "" + assert close_msg.end_of_session is True diff --git a/tests/unit/test_storage/conftest.py b/tests/unit/test_storage/conftest.py index 594e2b2f..32c210b2 100644 --- a/tests/unit/test_storage/conftest.py +++ b/tests/unit/test_storage/conftest.py @@ -53,7 +53,7 @@ def mock_document_embeddings_message(): mock_chunk = MagicMock() mock_chunk.chunk.decode.return_value = 'test document chunk' - mock_chunk.vectors = [[0.1, 0.2, 0.3]] + mock_chunk.vector = [0.1, 0.2, 0.3] mock_message.chunks = [mock_chunk] return mock_message @@ -68,11 +68,11 @@ def mock_document_embeddings_multiple_chunks(): mock_chunk1 = MagicMock() mock_chunk1.chunk.decode.return_value = 'first document chunk' - mock_chunk1.vectors = [[0.1, 0.2]] - + mock_chunk1.vector = [0.1, 0.2] + mock_chunk2 = MagicMock() mock_chunk2.chunk.decode.return_value = 'second document chunk' - mock_chunk2.vectors = [[0.3, 0.4]] + mock_chunk2.vector = [0.3, 0.4] mock_message.chunks = [mock_chunk1, mock_chunk2] return mock_message @@ -87,11 +87,7 @@ def mock_document_embeddings_multiple_vectors(): mock_chunk = MagicMock() mock_chunk.chunk.decode.return_value = 'multi-vector document chunk' - mock_chunk.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] + mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] mock_message.chunks = [mock_chunk] return mock_message @@ -106,7 +102,7 @@ def mock_document_embeddings_empty_chunk(): mock_chunk = MagicMock() mock_chunk.chunk.decode.return_value = "" # Empty string - mock_chunk.vectors = [[0.1, 0.2]] + mock_chunk.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk] return mock_message @@ -122,7 +118,7 @@ def mock_graph_embeddings_message(): mock_entity = MagicMock() mock_entity.entity.value = 'test_entity' - mock_entity.vectors = [[0.1, 0.2, 0.3]] + mock_entity.vector = [0.1, 0.2, 0.3] mock_message.entities = [mock_entity] return mock_message @@ -137,11 +133,11 @@ def mock_graph_embeddings_multiple_entities(): mock_entity1 = MagicMock() mock_entity1.entity.value = 'entity_one' - mock_entity1.vectors = [[0.1, 0.2]] - + mock_entity1.vector = [0.1, 0.2] + mock_entity2 = MagicMock() mock_entity2.entity.value = 'entity_two' - mock_entity2.vectors = [[0.3, 0.4]] + mock_entity2.vector = [0.3, 0.4] mock_message.entities = [mock_entity1, mock_entity2] return mock_message @@ -156,7 +152,7 @@ def mock_graph_embeddings_empty_entity(): mock_entity = MagicMock() mock_entity.entity.value = "" # Empty string - mock_entity.vectors = [[0.1, 0.2]] + mock_entity.vector = [0.1, 0.2] mock_message.entities = [mock_entity] return mock_message \ No newline at end of file diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index d957d711..f9d60541 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -22,12 +22,12 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Create test document embeddings chunk1 = ChunkEmbeddings( - chunk=b"This is the first document chunk", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + chunk_id="This is the first document chunk", + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) chunk2 = ChunkEmbeddings( - chunk=b"This is the second document chunk", - vectors=[[0.7, 0.8, 0.9]] + chunk_id="This is the second document chunk", + vector=[0.7, 0.8, 0.9] ) message.chunks = [chunk1, chunk2] @@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( - chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + chunk_id="Test document content", + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_doc - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once for the single chunk with its vector + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection' + ) @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" await processor.store_document_embeddings(mock_message) - - # Verify insert was called for each vector of each chunk with user/collection parameters + + # Verify insert was called once per chunk with user/collection parameters expected_calls = [ - # Chunk 1 vectors - ([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), - # Chunk 2 vectors + # Chunk 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), + # Chunk 2 - single vector ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec @@ -136,8 +126,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"", - vectors=[[0.1, 0.2, 0.3]] + chunk_id="", + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -148,51 +138,62 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_none_chunk(self, processor): - """Test storing document embeddings with None chunk (should be skipped)""" + """Test storing document embeddings with None chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( - chunk=None, - vectors=[[0.1, 0.2, 0.3]] + chunk_id=None, + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify no insert was called for None chunk - processor.vecstore.insert.assert_not_called() + + # Note: Implementation passes through None chunk_ids (only skips empty string "") + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], None, 'test_user', 'test_collection' + ) @pytest.mark.asyncio async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor): - """Test storing document embeddings with mix of valid and invalid chunks""" + """Test storing document embeddings with mix of valid and empty chunks""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + valid_chunk = ChunkEmbeddings( - chunk=b"Valid document content", - vectors=[[0.1, 0.2, 0.3]] + chunk_id="Valid document content", + vector=[0.1, 0.2, 0.3] ) empty_chunk = ChunkEmbeddings( - chunk=b"", - vectors=[[0.4, 0.5, 0.6]] + chunk_id="", + vector=[0.4, 0.5, 0.6] ) - none_chunk = ChunkEmbeddings( - chunk=None, - vectors=[[0.7, 0.8, 0.9]] + another_valid = ChunkEmbeddings( + chunk_id="Another valid chunk", + vector=[0.7, 0.8, 0.9] ) - message.chunks = [valid_chunk, empty_chunk, none_chunk] - + message.chunks = [valid_chunk, empty_chunk, another_valid] + await processor.store_document_embeddings(message) - - # Verify only valid chunk was inserted with user/collection parameters - processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection' - ) + + # Verify valid chunks were inserted, empty string chunk was skipped + expected_calls = [ + ([0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "Another valid chunk", 'test_user', 'test_collection'), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_chunk_id, expected_user, expected_collection) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_chunk_id + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunks_list(self, processor): @@ -217,8 +218,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"Document with no vectors", - vectors=[] + chunk_id="Document with no vectors", + vector=[] ) message.chunks = [chunk] @@ -234,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - chunk = ChunkEmbeddings( - chunk=b"Document with mixed dimensions", - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + + # Each chunk has a single vector of different dimensions + chunk1 = ChunkEmbeddings( + chunk_id="chunk/doc/2d", + vector=[0.1, 0.2] # 2D vector ) - message.chunks = [chunk] - + chunk2 = ChunkEmbeddings( + chunk_id="chunk/doc/4d", + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + chunk3 = ChunkEmbeddings( + chunk_id="chunk/doc/3d", + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.chunks = [chunk1, chunk2, chunk3] + await processor.store_document_embeddings(message) - + # Verify all vectors were inserted regardless of dimension with user/collection parameters expected_calls = [ - ([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'), - ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'), - ([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'), + ([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'), + ([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'), ] - + assert processor.vecstore.insert.call_count == 3 for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] @@ -264,46 +270,46 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_unicode_content(self, processor): - """Test storing document embeddings with Unicode content""" + """Test storing document embeddings with Unicode content in chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( - chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + chunk_id="chunk/doc/unicode-éñ中文🚀", + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify Unicode content was properly decoded and inserted with user/collection parameters + + # Verify Unicode chunk_id was stored correctly with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection' + [0.1, 0.2, 0.3], "chunk/doc/unicode-éñ中文🚀", 'test_user', 'test_collection' ) @pytest.mark.asyncio - async def test_store_document_embeddings_large_chunks(self, processor): - """Test storing document embeddings with large document chunks""" + async def test_store_document_embeddings_large_chunk_id(self, processor): + """Test storing document embeddings with long chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - # Create a large document chunk - large_content = "A" * 10000 # 10KB of content + + # Create a long chunk_id + long_chunk_id = "chunk/doc/" + "a" * 200 chunk = ChunkEmbeddings( - chunk=large_content.encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + chunk_id=long_chunk_id, + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify large content was inserted with user/collection parameters + + # Verify long chunk_id was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection' + [0.1, 0.2, 0.3], long_chunk_id, 'test_user', 'test_collection' ) @pytest.mark.asyncio @@ -315,8 +321,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b" \n\t ", - vectors=[[0.1, 0.2, 0.3]] + chunk_id=" \n\t ", + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -346,8 +352,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = collection chunk = ChunkEmbeddings( - chunk=b"Test content", - vectors=[[0.1, 0.2, 0.3]] + chunk_id="Test content", + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -367,8 +373,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message1.metadata.user = 'user1' message1.metadata.collection = 'collection1' chunk1 = ChunkEmbeddings( - chunk=b"User1 content", - vectors=[[0.1, 0.2, 0.3]] + chunk_id="User1 content", + vector=[0.1, 0.2, 0.3] ) message1.chunks = [chunk1] @@ -378,8 +384,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message2.metadata.user = 'user2' message2.metadata.collection = 'collection2' chunk2 = ChunkEmbeddings( - chunk=b"User2 content", - vectors=[[0.4, 0.5, 0.6]] + chunk_id="User2 content", + vector=[0.4, 0.5, 0.6] ) message2.chunks = [chunk2] @@ -409,8 +415,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test-collection.v1' # Collection with special chars chunk = ChunkEmbeddings( - chunk=b"Special chars test", - vectors=[[0.1, 0.2, 0.3]] + chunk_id="Special chars test", + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index fc7c0a79..fec4f87e 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor: # Create test document embeddings chunk1 = ChunkEmbeddings( chunk=b"This is the first document chunk", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) chunk2 = ChunkEmbeddings( chunk=b"This is the second document chunk", - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.chunks = [chunk1, chunk2] @@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.chunks = [chunk] @@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=None, - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"", # Empty bytes - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor: message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - chunk = ChunkEmbeddings( - chunk=b"Document with mixed dimensions", - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + # Each chunk has a single vector of different dimensions + chunk1 = ChunkEmbeddings( + chunk=b"Document chunk 1", + vector=[0.1, 0.2] # 2D vector ) - message.chunks = [chunk] - - mock_index_2d = MagicMock() - mock_index_4d = MagicMock() - mock_index_3d = MagicMock() - + chunk2 = ChunkEmbeddings( + chunk=b"Document chunk 2", + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + chunk3 = ChunkEmbeddings( + chunk=b"Document chunk 3", + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.chunks = [chunk1, chunk2, chunk3] + + mock_index = MagicMock() + def mock_index_side_effect(name): # All dimensions now use the same index name pattern - # Different dimensions will be handled within the same index if "test_user" in name and "test_collection" in name: - return mock_index_2d # Just return one mock for all + return mock_index return MagicMock() - + processor.pinecone.Index.side_effect = mock_index_side_effect processor.pinecone.has_index.return_value = True - + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): await processor.store_document_embeddings(message) # Verify all vectors are now stored in the same index - # (Pinecone can handle mixed dimensions in the same index) - assert processor.pinecone.Index.call_count == 3 # Called once per vector - mock_index_2d.upsert.call_count == 3 # All upserts go to same index + # (Each chunk has a single vector, called once per chunk) + assert processor.pinecone.Index.call_count == 3 # Called once per chunk + assert mock_index.upsert.call_count == 3 # All upserts go to same index @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunks_list(self, processor): @@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Document with no vectors", - vectors=[] + vector=[] ) message.chunks = [chunk] @@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: large_content = "A" * 10000 # 10KB of content chunk = ChunkEmbeddings( chunk=large_content.encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index fc839482..98d2dab2 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -20,7 +20,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -34,7 +34,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert # Verify QdrantClient was created with correct parameters mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key') - + # Verify processor attributes assert hasattr(processor, 'qdrant') assert processor.qdrant == mock_qdrant_instance @@ -45,7 +45,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - + config = { 'taskgroup': AsyncMock(), 'id': 'test-doc-qdrant-processor' @@ -69,7 +69,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -86,13 +86,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' - + mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test document chunk' - mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions - + mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes + mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) @@ -100,18 +100,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Verify collection existence was checked (with dimension suffix) expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3] mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) - + # Verify upsert was called mock_qdrant_instance.upsert.assert_called_once() - + # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3' assert len(upsert_call_args[1]['points']) == 1 - + point = upsert_call_args[1]['points'][0] assert point.vector == [0.1, 0.2, 0.3] - assert point.payload['doc'] == 'test document chunk' + assert point.payload['chunk_id'] == 'doc/c1' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @@ -123,7 +123,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -140,50 +140,50 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' - + mock_chunk1 = MagicMock() - mock_chunk1.chunk.decode.return_value = 'first document chunk' - mock_chunk1.vectors = [[0.1, 0.2]] - + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2] + mock_chunk2 = MagicMock() - mock_chunk2.chunk.decode.return_value = 'second document chunk' - mock_chunk2.vectors = [[0.3, 0.4]] - + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.3, 0.4] + mock_message.chunks = [mock_chunk1, mock_chunk2] - + # Act await processor.store_document_embeddings(mock_message) # Assert # Should be called twice (once per chunk) assert mock_qdrant_instance.upsert.call_count == 2 - + # Verify both chunks were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - + # First chunk first_call = upsert_calls[0] first_point = first_call[1]['points'][0] assert first_point.vector == [0.1, 0.2] - assert first_point.payload['doc'] == 'first document chunk' - + assert first_point.payload['chunk_id'] == 'doc/c1' + # Second chunk second_call = upsert_calls[1] second_point = second_call[1]['points'][0] assert second_point.vector == [0.3, 0.4] - assert second_point.payload['doc'] == 'second document chunk' + assert second_point.payload['chunk_id'] == 'doc/c2' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client): - """Test storing document embeddings with multiple vectors per chunk""" + async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client): + """Test storing document embeddings with multiple chunks""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -196,45 +196,49 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('vector_user', 'vector_collection')] = {} - # Create mock message with chunk having multiple vectors + # Create mock message with multiple chunks, each having a single vector mock_message = MagicMock() mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' - - mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'multi-vector document chunk' - mock_chunk.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] - - mock_message.chunks = [mock_chunk] - + + mock_chunk1 = MagicMock() + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2, 0.3] + + mock_chunk2 = MagicMock() + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.4, 0.5, 0.6] + + mock_chunk3 = MagicMock() + mock_chunk3.chunk_id = 'doc/c3' + mock_chunk3.vector = [0.7, 0.8, 0.9] + + mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3] + # Act await processor.store_document_embeddings(mock_message) # Assert - # Should be called 3 times (once per vector) + # Should be called 3 times (once per chunk) assert mock_qdrant_instance.upsert.call_count == 3 - + # Verify all vectors were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - - expected_vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] + + expected_data = [ + ([0.1, 0.2, 0.3], 'doc/c1'), + ([0.4, 0.5, 0.6], 'doc/c2'), + ([0.7, 0.8, 0.9], 'doc/c3') ] - + for i, call in enumerate(upsert_calls): point = call[1]['points'][0] - assert point.vector == expected_vectors[i] - assert point.payload['doc'] == 'multi-vector document chunk' + assert point.vector == expected_data[i][0] + assert point.payload['chunk_id'] == expected_data[i][1] @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client): - """Test storing document embeddings skips empty chunks""" + async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client): + """Test storing document embeddings skips empty chunk_ids""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True # Collection exists @@ -249,14 +253,14 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - # Create mock message with empty chunk + # Create mock message with empty chunk_id mock_message = MagicMock() mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_chunk_empty = MagicMock() - mock_chunk_empty.chunk.decode.return_value = "" # Empty string - mock_chunk_empty.vectors = [[0.1, 0.2]] + mock_chunk_empty.chunk_id = "" # Empty chunk_id + mock_chunk_empty.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk_empty] @@ -264,9 +268,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) # Assert - # Should not call upsert for empty chunks + # Should not call upsert for empty chunk_ids mock_qdrant_instance.upsert.assert_not_called() - # collection_exists should NOT be called since we return early for empty chunks + # collection_exists should NOT be called since we return early for empty chunk_ids mock_qdrant_instance.collection_exists.assert_not_called() @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @@ -298,8 +302,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'new_collection' mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test chunk' - mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions + mock_chunk.chunk_id = 'doc/test-chunk' + mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions mock_message.chunks = [mock_chunk] @@ -350,8 +354,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'error_collection' mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test chunk' - mock_chunk.vectors = [[0.1, 0.2]] + mock_chunk.chunk_id = 'doc/test-chunk' + mock_chunk.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk] @@ -388,8 +392,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message1.metadata.collection = 'cache_collection' mock_chunk1 = MagicMock() - mock_chunk1.chunk.decode.return_value = 'first chunk' - mock_chunk1.vectors = [[0.1, 0.2, 0.3]] + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2, 0.3] mock_message1.chunks = [mock_chunk1] @@ -406,8 +410,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message2.metadata.collection = 'cache_collection' mock_chunk2 = MagicMock() - mock_chunk2.chunk.decode.return_value = 'second chunk' - mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3) + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3) mock_message2.chunks = [mock_chunk2] @@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('dim_user', 'dim_collection')] = {} - # Create mock message with different dimension vectors + # Create mock message with chunks of different dimensions mock_message = MagicMock() mock_message.metadata.user = 'dim_user' mock_message.metadata.collection = 'dim_collection' - mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'dimension test chunk' - mock_chunk.vectors = [ - [0.1, 0.2], # 2 dimensions - [0.3, 0.4, 0.5] # 3 dimensions - ] + mock_chunk1 = MagicMock() + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2] # 2 dimensions - mock_message.chunks = [mock_chunk] + mock_chunk2 = MagicMock() + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions + + mock_message.chunks = [mock_chunk1, mock_chunk2] # Act await processor.store_document_embeddings(mock_message) @@ -485,28 +490,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_client.return_value = MagicMock() mock_parser = MagicMock() - + # Act with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args: Processor.add_args(mock_parser) # Assert mock_parent_add_args.assert_called_once_with(mock_parser) - + # Verify processor-specific arguments were added assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client): - """Test proper UTF-8 decoding of chunk text""" + async def test_chunk_id_with_special_characters(self, mock_uuid, mock_qdrant_client): + """Test storing chunk_id with special characters (URIs)""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -517,65 +522,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) # Add collection to known_collections (simulates config push) - processor.known_collections[('utf8_user', 'utf8_collection')] = {} + processor.known_collections[('uri_user', 'uri_collection')] = {} - # Create mock message with UTF-8 encoded text + # Create mock message with URI-style chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'utf8_user' - mock_message.metadata.collection = 'utf8_collection' - + mock_message.metadata.user = 'uri_user' + mock_message.metadata.collection = 'uri_collection' + mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé' - mock_chunk.vectors = [[0.1, 0.2]] - + mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3' + mock_chunk.vector = [0.1, 0.2] + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) # Assert - # Verify chunk.decode was called with 'utf-8' - mock_chunk.chunk.decode.assert_called_with('utf-8') - - # Verify the decoded text was stored in payload + # Verify the chunk_id was stored correctly upsert_call_args = mock_qdrant_instance.upsert.call_args point = upsert_call_args[1]['points'][0] - assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé' - - @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - async def test_chunk_decode_exception_handling(self, mock_qdrant_client): - """Test handling of chunk decode exceptions""" - # Arrange - mock_qdrant_instance = MagicMock() - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-doc-qdrant-processor' - } - - processor = Processor(**config) - - # Add collection to known_collections (simulates config push) - processor.known_collections[('decode_user', 'decode_collection')] = {} - - # Create mock message with decode error - mock_message = MagicMock() - mock_message.metadata.user = 'decode_user' - mock_message.metadata.collection = 'decode_collection' - - mock_chunk = MagicMock() - mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte') - mock_chunk.vectors = [[0.1, 0.2]] - - mock_message.chunks = [mock_chunk] - - # Act & Assert - with pytest.raises(UnicodeDecodeError): - await processor.store_document_embeddings(mock_message) + assert point.payload['chunk_id'] == 'https://trustgraph.ai/doc/my-document/p1/c3' if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index 8a8e1090..e4d60adf 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: # Create test entities with embeddings entity1 = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity1'), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) entity2 = EntityEmbeddings( entity=Term(type=LITERAL, value='literal entity'), - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.entities = [entity1, entity2] @@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.entities = [entity] - + await processor.store_graph_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_entity - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once with the full vector + processor.vecstore.insert.assert_called_once() + actual_call = processor.vecstore.insert.call_args_list[0] + assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + assert actual_call[0][1] == 'http://example.com/entity' + assert actual_call[0][2] == 'test_user' + assert actual_call[0][3] == 'test_collection' @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" await processor.store_graph_embeddings(mock_message) - - # Verify insert was called for each vector of each entity with user/collection parameters + + # Verify insert was called once per entity with user/collection parameters expected_calls = [ - # Entity 1 vectors - ([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), - # Entity 2 vectors + # Entity 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), + # Entity 2 - single vector ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec @@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=LITERAL, value=''), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=LITERAL, value=None), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -172,26 +165,30 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + valid_entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/valid'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3], + chunk_id='' ) empty_entity = EntityEmbeddings( entity=Term(type=LITERAL, value=''), - vectors=[[0.4, 0.5, 0.6]] + vector=[0.4, 0.5, 0.6], + chunk_id='' ) none_entity = EntityEmbeddings( entity=Term(type=LITERAL, value=None), - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9], + chunk_id='' ) message.entities = [valid_entity, empty_entity, none_entity] - + await processor.store_graph_embeddings(message) - - # Verify only valid entity was inserted with user/collection parameters + + # Verify only valid entity was inserted with user/collection/chunk_id parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], 'http://example.com/valid', 'test_user', 'test_collection' + [0.1, 0.2, 0.3], 'http://example.com/valid', 'test_user', 'test_collection', + chunk_id='' ) @pytest.mark.asyncio @@ -218,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[] + vector=[] ) message.entities = [entity] @@ -234,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - entity = EntityEmbeddings( - entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + + # Each entity has a single vector of different dimensions + entity1 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity1'), + vector=[0.1, 0.2] # 2D vector ) - message.entities = [entity] - + entity2 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity2'), + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + entity3 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity3'), + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.entities = [entity1, entity2, entity3] + await processor.store_graph_embeddings(message) - + # Verify all vectors were inserted regardless of dimension expected_calls = [ - ([0.1, 0.2], 'http://example.com/entity'), - ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'), - ([0.7, 0.8, 0.9], 'http://example.com/entity'), + ([0.1, 0.2], 'http://example.com/entity1'), + ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'), + ([0.7, 0.8, 0.9], 'http://example.com/entity3'), ] - + assert processor.vecstore.insert.call_count == 3 for i, (expected_vec, expected_entity) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] @@ -270,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: uri_entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/uri_entity'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) literal_entity = EntityEmbeddings( entity=Term(type=LITERAL, value='literal entity text'), - vectors=[[0.4, 0.5, 0.6]] + vector=[0.4, 0.5, 0.6] ) message.entities = [uri_entity, literal_entity] diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 0fd0fde3..9ff53f4e 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor: message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - # Create test entity embeddings + # Create test entity embeddings (each entity has a single vector) entity1 = EntityEmbeddings( entity=Value(value="http://example.org/entity1", is_uri=True), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3] ) entity2 = EntityEmbeddings( - entity=Value(value="entity2", is_uri=False), - vectors=[[0.7, 0.8, 0.9]] + entity=Value(value="http://example.org/entity2", is_uri=True), + vector=[0.4, 0.5, 0.6] ) - message.entities = [entity1, entity2] + entity3 = EntityEmbeddings( + entity=Value(value="entity3", is_uri=False), + vector=[0.7, 0.8, 0.9] + ) + message.entities = [entity1, entity2, entity3] return message @@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Value(value="http://example.org/entity1", is_uri=True), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] - + # Mock index operations mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index processor.pinecone.has_index.return_value = True - - with patch('uuid.uuid4', side_effect=['id1', 'id2']): + + with patch('uuid.uuid4', side_effect=['id1']): await processor.store_graph_embeddings(message) - + # Verify index name and operations (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_with(expected_index_name) - - # Verify upsert was called for each vector - assert mock_index.upsert.call_count == 2 + + # Verify upsert was called for the single vector + assert mock_index.upsert.call_count == 1 # Check first vector upsert first_call = mock_index.upsert.call_args_list[0] @@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value=None, is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_different_vector_dimensions(self, processor): - """Test storing graph embeddings with different vector dimensions to same index""" + """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - entity = EntityEmbeddings( - entity=Value(value="test_entity", is_uri=False), - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + # Each entity has a single vector of different dimensions + entity1 = EntityEmbeddings( + entity=Value(value="entity1", is_uri=False), + vector=[0.1, 0.2] # 2D vector ) - message.entities = [entity] + entity2 = EntityEmbeddings( + entity=Value(value="entity2", is_uri=False), + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + entity3 = EntityEmbeddings( + entity=Value(value="entity3", is_uri=False), + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.entities = [entity1, entity2, entity3] - # All vectors now use the same index (no dimension in name) mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index processor.pinecone.has_index.return_value = True @@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[] + vector=[] ) message.entities = [entity] @@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 8b1a710a..3541ccd4 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity = MagicMock() mock_entity.entity.type = IRI mock_entity.entity.iri = 'test_entity' - mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions + mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions mock_message.entities = [mock_entity] @@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity1 = MagicMock() mock_entity1.entity.type = IRI mock_entity1.entity.iri = 'entity_one' - mock_entity1.vectors = [[0.1, 0.2]] + mock_entity1.vector = [0.1, 0.2] mock_entity2 = MagicMock() mock_entity2.entity.type = IRI mock_entity2.entity.iri = 'entity_two' - mock_entity2.vectors = [[0.3, 0.4]] + mock_entity2.vector = [0.3, 0.4] mock_message.entities = [mock_entity1, mock_entity2] @@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') - async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client): - """Test storing graph embeddings with multiple vectors per entity""" + async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client): + """Test storing graph embeddings with three entities""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value.return_value = 'test-uuid' - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('vector_user', 'vector_collection')] = {} - # Create mock message with entity having multiple vectors + # Create mock message with three entities mock_message = MagicMock() mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' - - mock_entity = MagicMock() - mock_entity.entity.type = IRI - mock_entity.entity.iri = 'multi_vector_entity' - mock_entity.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] - - mock_message.entities = [mock_entity] - + + mock_entity1 = MagicMock() + mock_entity1.entity.type = IRI + mock_entity1.entity.iri = 'entity_one' + mock_entity1.vector = [0.1, 0.2, 0.3] + + mock_entity2 = MagicMock() + mock_entity2.entity.type = IRI + mock_entity2.entity.iri = 'entity_two' + mock_entity2.vector = [0.4, 0.5, 0.6] + + mock_entity3 = MagicMock() + mock_entity3.entity.type = IRI + mock_entity3.entity.iri = 'entity_three' + mock_entity3.vector = [0.7, 0.8, 0.9] + + mock_message.entities = [mock_entity1, mock_entity2, mock_entity3] + # Act await processor.store_graph_embeddings(mock_message) # Assert - # Should be called 3 times (once per vector) + # Should be called 3 times (once per entity) assert mock_qdrant_instance.upsert.call_count == 3 - - # Verify all vectors were processed + + # Verify all entities were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - - expected_vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] + + expected_data = [ + ([0.1, 0.2, 0.3], 'entity_one'), + ([0.4, 0.5, 0.6], 'entity_two'), + ([0.7, 0.8, 0.9], 'entity_three') ] - + for i, call in enumerate(upsert_calls): point = call[1]['points'][0] - assert point.vector == expected_vectors[i] - assert point.payload['entity'] == 'multi_vector_entity' + assert point.vector == expected_data[i][0] + assert point.payload['entity'] == expected_data[i][1] @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client): @@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity_empty = MagicMock() mock_entity_empty.entity.type = LITERAL mock_entity_empty.entity.value = "" # Empty string - mock_entity_empty.vectors = [[0.1, 0.2]] + mock_entity_empty.vector = [0.1, 0.2] mock_entity_none = MagicMock() mock_entity_none.entity = None # None entity - mock_entity_none.vectors = [[0.3, 0.4]] + mock_entity_none.vector = [0.3, 0.4] mock_message.entities = [mock_entity_empty, mock_entity_none] diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index b4c5a5b4..e1c8f3b1 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='customer_id', index_value=['CUST001'], text='CUST001', - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) embeddings_msg = RowEmbeddings( @@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') - async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client): - """Test processing embeddings with multiple vectors""" + async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client): + """Test processing embeddings with a single vector""" from trustgraph.storage.row_embeddings.qdrant.write import Processor from trustgraph.schema import RowEmbeddings, RowIndexEmbedding @@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): metadata.collection = 'test_collection' metadata.id = 'doc-123' - # Embedding with multiple vectors + # Embedding with a single 6D vector embedding = RowIndexEmbedding( index_name='name', index_value=['John Doe'], text='John Doe', - vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) embeddings_msg = RowEmbeddings( @@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) - # Should be called 3 times (once per vector) - assert mock_qdrant_instance.upsert.call_count == 3 + # Should be called once for the single embedding + assert mock_qdrant_instance.upsert.call_count == 1 @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client): @@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='id', index_value=['123'], text='123', - vectors=[] # Empty vectors + vector=[] # Empty vector ) embeddings_msg = RowEmbeddings( @@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='id', index_value=['123'], text='123', - vectors=[[0.1, 0.2]] + vector=[0.1, 0.2] ) embeddings_msg = RowEmbeddings( diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index c8b81447..1976b844 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -190,7 +190,6 @@ class TestRowsCassandraStorageLogic: id="test-001", user="test_user", collection="test_collection", - metadata=[] ), schema_name="test_schema", values=[{"id": "123", "value": "test_data"}], @@ -252,7 +251,6 @@ class TestRowsCassandraStorageLogic: id="test-001", user="test_user", collection="test_collection", - metadata=[] ), schema_name="multi_index_schema", values=[{"id": "123", "category": "electronics", "status": "active"}], @@ -310,7 +308,6 @@ class TestRowsCassandraStorageBatchLogic: id="batch-001", user="test_user", collection="batch_collection", - metadata=[] ), schema_name="batch_schema", values=[ @@ -365,7 +362,6 @@ class TestRowsCassandraStorageBatchLogic: id="empty-001", user="test_user", collection="empty_collection", - metadata=[] ), schema_name="empty_schema", values=[], # Empty batch diff --git a/tests/unit/test_structured_data/__init__.py b/tests/unit/test_structured_data/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_structured_data/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py new file mode 100644 index 00000000..3222ec83 --- /dev/null +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -0,0 +1,296 @@ +""" +Tests for row embeddings query service: collection naming, query execution, +index filtering, result conversion, and error handling. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import ( + RowEmbeddingsRequest, RowEmbeddingsResponse, + RowIndexMatch, Error, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(qdrant_client=None): + """Create a Processor without full FlowProcessor init.""" + from trustgraph.query.row_embeddings.qdrant.service import Processor + proc = Processor.__new__(Processor) + proc.qdrant = qdrant_client or MagicMock() + return proc + + +def _make_request(vector=None, user="test-user", collection="test-col", + schema_name="customers", limit=10, index_name=None): + return RowEmbeddingsRequest( + vector=vector or [0.1, 0.2, 0.3], + user=user, + collection=collection, + schema_name=schema_name, + limit=limit, + index_name=index_name or "", + ) + + +def _make_search_point(index_name, index_value, text, score): + point = MagicMock() + point.payload = { + "index_name": index_name, + "index_value": index_value, + "text": text, + } + point.score = score + return point + + +# --------------------------------------------------------------------------- +# sanitize_name +# --------------------------------------------------------------------------- + +class TestSanitizeName: + + def test_simple_name(self): + proc = _make_processor() + assert proc.sanitize_name("customers") == "customers" + + def test_special_chars_replaced(self): + proc = _make_processor() + assert proc.sanitize_name("my-schema.v2") == "my_schema_v2" + + def test_leading_digit_prefixed(self): + proc = _make_processor() + result = proc.sanitize_name("123schema") + assert result.startswith("r_") + assert "123schema" in result + + def test_uppercase_lowercased(self): + proc = _make_processor() + assert proc.sanitize_name("MySchema") == "myschema" + + def test_spaces_replaced(self): + proc = _make_processor() + assert proc.sanitize_name("my schema") == "my_schema" + + +# --------------------------------------------------------------------------- +# find_collection +# --------------------------------------------------------------------------- + +class TestFindCollection: + + def test_finds_matching_collection(self): + proc = _make_processor() + mock_coll = MagicMock() + mock_coll.name = "rows_test_user_test_col_customers_384" + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll] + proc.qdrant.get_collections.return_value = mock_collections + + result = proc.find_collection("test-user", "test-col", "customers") + + # Prefix: rows_test_user_test_col_customers_ + assert result == "rows_test_user_test_col_customers_384" + + def test_returns_none_when_no_match(self): + proc = _make_processor() + mock_coll = MagicMock() + mock_coll.name = "rows_other_user_other_col_schema_768" + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll] + proc.qdrant.get_collections.return_value = mock_collections + + result = proc.find_collection("test-user", "test-col", "customers") + assert result is None + + def test_returns_none_on_error(self): + proc = _make_processor() + proc.qdrant.get_collections.side_effect = Exception("connection error") + + result = proc.find_collection("user", "col", "schema") + assert result is None + + +# --------------------------------------------------------------------------- +# query_row_embeddings +# --------------------------------------------------------------------------- + +class TestQueryRowEmbeddings: + + @pytest.mark.asyncio + async def test_empty_vector_returns_empty(self): + proc = _make_processor() + request = _make_request(vector=[]) + + result = await proc.query_row_embeddings(request) + assert result == [] + + @pytest.mark.asyncio + async def test_no_collection_returns_empty(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value=None) + request = _make_request() + + result = await proc.query_row_embeddings(request) + assert result == [] + + @pytest.mark.asyncio + async def test_successful_query_returns_matches(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + points = [ + _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), + _make_search_point("address", ["123 Main St"], "123 Main St", 0.82), + ] + mock_result = MagicMock() + mock_result.points = points + proc.qdrant.query_points.return_value = mock_result + + request = _make_request() + result = await proc.query_row_embeddings(request) + + assert len(result) == 2 + assert isinstance(result[0], RowIndexMatch) + assert result[0].index_name == "name" + assert result[0].index_value == ["Alice Smith"] + assert result[0].score == 0.95 + assert result[1].index_name == "address" + + @pytest.mark.asyncio + async def test_index_name_filter_applied(self): + """When index_name is specified, a Qdrant filter should be used.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + mock_result = MagicMock() + mock_result.points = [] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request(index_name="address") + await proc.query_row_embeddings(request) + + call_kwargs = proc.qdrant.query_points.call_args[1] + assert call_kwargs["query_filter"] is not None + + @pytest.mark.asyncio + async def test_no_index_name_no_filter(self): + """When index_name is empty, no filter should be applied.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + mock_result = MagicMock() + mock_result.points = [] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request(index_name="") + await proc.query_row_embeddings(request) + + call_kwargs = proc.qdrant.query_points.call_args[1] + assert call_kwargs["query_filter"] is None + + @pytest.mark.asyncio + async def test_missing_payload_fields_default(self): + """Points with missing payload fields should use defaults.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + point = MagicMock() + point.payload = {} # Empty payload + point.score = 0.5 + + mock_result = MagicMock() + mock_result.points = [point] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request() + result = await proc.query_row_embeddings(request) + + assert len(result) == 1 + assert result[0].index_name == "" + assert result[0].index_value == [] + assert result[0].text == "" + + @pytest.mark.asyncio + async def test_qdrant_error_propagates(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.qdrant.query_points.side_effect = Exception("qdrant down") + + request = _make_request() + + with pytest.raises(Exception, match="qdrant down"): + await proc.query_row_embeddings(request) + + +# --------------------------------------------------------------------------- +# on_message handler +# --------------------------------------------------------------------------- + +class TestOnMessage: + + @pytest.mark.asyncio + async def test_successful_message_sends_response(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock(return_value=[ + RowIndexMatch(index_name="name", index_value=["Alice"], + text="Alice", score=0.9), + ]) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg = MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "req-1"} + + await proc.on_message(msg, MagicMock(), flow) + + sent = mock_pub.send.call_args[0][0] + assert isinstance(sent, RowEmbeddingsResponse) + assert sent.error is None + assert len(sent.matches) == 1 + + @pytest.mark.asyncio + async def test_error_sends_error_response(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock( + side_effect=Exception("query failed") + ) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg = MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "req-2"} + + await proc.on_message(msg, MagicMock(), flow) + + sent = mock_pub.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "row-embeddings-query-error" + assert "query failed" in sent.error.message + assert sent.matches == [] + + @pytest.mark.asyncio + async def test_message_id_preserved(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock(return_value=[]) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg = MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "unique-42"} + + await proc.on_message(msg, MagicMock(), flow) + + props = mock_pub.send.call_args[1]["properties"] + assert props["id"] == "unique-42" diff --git a/tests/unit/test_structured_data/test_type_detector.py b/tests/unit/test_structured_data/test_type_detector.py new file mode 100644 index 00000000..7ce7060a --- /dev/null +++ b/tests/unit/test_structured_data/test_type_detector.py @@ -0,0 +1,235 @@ +""" +Tests for structured data type detection: CSV, JSON, XML format detection, +CSV option detection (delimiter, header), and helper functions. +""" + +import pytest + +from trustgraph.retrieval.structured_diag.type_detector import ( + detect_data_type, + _check_json_format, + _check_xml_format, + _check_csv_format, + _check_csv_with_delimiter, + detect_csv_options, + _is_numeric, +) + + +# --------------------------------------------------------------------------- +# detect_data_type (top-level dispatcher) +# --------------------------------------------------------------------------- + +class TestDetectDataType: + + def test_empty_string_returns_none(self): + detected, confidence = detect_data_type("") + assert detected is None + assert confidence == 0.0 + + def test_whitespace_only_returns_none(self): + detected, confidence = detect_data_type(" \n \t ") + assert detected is None + assert confidence == 0.0 + + def test_none_returns_none(self): + detected, confidence = detect_data_type(None) + assert detected is None + assert confidence == 0.0 + + def test_json_object_detected(self): + detected, confidence = detect_data_type('{"name": "Alice"}') + assert detected == "json" + assert confidence > 0.5 + + def test_json_array_detected(self): + detected, confidence = detect_data_type('[{"id": 1}, {"id": 2}]') + assert detected == "json" + assert confidence > 0.5 + + def test_xml_with_declaration_detected(self): + detected, confidence = detect_data_type('') + assert detected == "xml" + assert confidence > 0.5 + + def test_xml_without_declaration_detected(self): + detected, confidence = detect_data_type('val') + assert detected == "xml" + assert confidence > 0.5 + + def test_csv_detected(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + detected, confidence = detect_data_type(data) + assert detected == "csv" + assert confidence > 0.5 + + def test_plain_text_falls_through_to_csv(self): + """Non-JSON/XML text defaults to CSV detection.""" + detected, confidence = detect_data_type("just some text") + assert detected == "csv" + + +# --------------------------------------------------------------------------- +# _check_json_format +# --------------------------------------------------------------------------- + +class TestCheckJsonFormat: + + def test_valid_json_object(self): + assert _check_json_format('{"key": "value"}') > 0.9 + + def test_valid_json_array_of_objects(self): + assert _check_json_format('[{"id": 1}, {"id": 2}]') >= 0.9 + + def test_valid_json_array_of_primitives(self): + score = _check_json_format('[1, 2, 3]') + assert score > 0.5 + assert score < 0.9 # Lower confidence for non-object arrays + + def test_empty_json_object(self): + assert _check_json_format('{}') > 0.5 + + def test_invalid_json(self): + assert _check_json_format('{invalid json}') == 0.0 + + def test_non_json_starting_char(self): + assert _check_json_format('hello world') == 0.0 + + def test_empty_array(self): + score = _check_json_format('[]') + assert score > 0.0 # Parsed successfully but empty + + +# --------------------------------------------------------------------------- +# _check_xml_format +# --------------------------------------------------------------------------- + +class TestCheckXmlFormat: + + def test_valid_xml(self): + assert _check_xml_format('val') == 0.9 + + def test_xml_with_declaration(self): + xml = 'test' + assert _check_xml_format(xml) == 0.9 + + def test_malformed_xml(self): + score = _check_xml_format('') + # Has < and ') + # Starts with < but no closing tag + assert score <= 0.1 + + +# --------------------------------------------------------------------------- +# _check_csv_format and _check_csv_with_delimiter +# --------------------------------------------------------------------------- + +class TestCheckCsvFormat: + + def test_valid_csv_comma(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + assert _check_csv_format(data) > 0.7 + + def test_valid_csv_semicolon(self): + data = "name;age;city\nAlice;30;NYC\nBob;25;LA" + assert _check_csv_format(data) > 0.7 + + def test_valid_csv_tab(self): + data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA" + assert _check_csv_format(data) > 0.7 + + def test_valid_csv_pipe(self): + data = "name|age|city\nAlice|30|NYC\nBob|25|LA" + assert _check_csv_format(data) > 0.7 + + def test_single_line_not_csv(self): + assert _check_csv_format("just one line") == 0.0 + + def test_single_column_not_csv(self): + data = "a\nb\nc" + assert _check_csv_with_delimiter(data, ",") == 0.0 + + def test_inconsistent_columns_low_score(self): + data = "a,b,c\n1,2\n3,4,5,6" + score = _check_csv_with_delimiter(data, ",") + assert score < 0.7 + + def test_many_rows_higher_score(self): + rows = ["name,age,city"] + [f"person{i},{20+i},city{i}" for i in range(20)] + data = "\n".join(rows) + score = _check_csv_format(data) + assert score > 0.8 + + +# --------------------------------------------------------------------------- +# detect_csv_options +# --------------------------------------------------------------------------- + +class TestDetectCsvOptions: + + def test_comma_delimiter_detected(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + options = detect_csv_options(data) + assert options["delimiter"] == "," + + def test_semicolon_delimiter_detected(self): + data = "name;age;city\nAlice;30;NYC\nBob;25;LA" + options = detect_csv_options(data) + assert options["delimiter"] == ";" + + def test_tab_delimiter_detected(self): + data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA" + options = detect_csv_options(data) + assert options["delimiter"] == "\t" + + def test_header_detected_when_first_row_text(self): + data = "name,age,salary\nAlice,30,50000\nBob,25,45000" + options = detect_csv_options(data) + assert options["has_header"] is True + + def test_no_header_when_all_numeric(self): + data = "1,2,3\n4,5,6\n7,8,9" + options = detect_csv_options(data) + assert options["has_header"] is False + + def test_single_line_returns_defaults(self): + options = detect_csv_options("just one line") + assert options["delimiter"] == "," + assert options["has_header"] is True + + def test_encoding_default(self): + data = "a,b\n1,2" + options = detect_csv_options(data) + assert options["encoding"] == "utf-8" + + +# --------------------------------------------------------------------------- +# _is_numeric helper +# --------------------------------------------------------------------------- + +class TestIsNumeric: + + def test_integer(self): + assert _is_numeric("42") is True + + def test_float(self): + assert _is_numeric("3.14") is True + + def test_negative(self): + assert _is_numeric("-10") is True + + def test_text(self): + assert _is_numeric("hello") is False + + def test_empty(self): + assert _is_numeric("") is False + + def test_whitespace_padded(self): + assert _is_numeric(" 42 ") is True diff --git a/tests/unit/test_text_completion/test_azure_openai_streaming.py b/tests/unit/test_text_completion/test_azure_openai_streaming.py new file mode 100644 index 00000000..b2f5a003 --- /dev/null +++ b/tests/unit/test_text_completion/test_azure_openai_streaming.py @@ -0,0 +1,182 @@ +""" +Tests for Azure OpenAI streaming: model/temperature override during streaming, +RateLimitError → TooManyRequests conversion, chunk iteration, and final token +count emission. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.model.text_completion.azure_openai.llm import Processor +from trustgraph.base import LlmChunk +from trustgraph.exceptions import TooManyRequests + + +def _make_processor(mock_azure_openai_class, model="gpt-4"): + """Create a Processor with mocked base classes.""" + with patch('trustgraph.base.async_processor.AsyncProcessor.__init__', + return_value=None), \ + patch('trustgraph.base.llm_service.LlmService.__init__', + return_value=None): + proc = Processor( + endpoint="https://test.openai.azure.com/", + token="test-token", + api_version="2024-12-01-preview", + model=model, + temperature=0.0, + max_output=4192, + concurrency=1, + taskgroup=AsyncMock(), + id="test-processor", + ) + return proc + + +def _make_stream_chunk(content=None, usage=None): + """Create a mock streaming chunk.""" + chunk = MagicMock() + if content: + chunk.choices = [MagicMock()] + chunk.choices[0].delta.content = content + else: + chunk.choices = [] + chunk.usage = usage + return chunk + + +class TestAzureOpenAIStreaming(IsolatedAsyncioTestCase): + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_yields_chunks(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + + stream_data = [ + _make_stream_chunk(content="Hello"), + _make_stream_chunk(content=" world"), + _make_stream_chunk(usage=usage), + ] + mock_client.chat.completions.create.return_value = iter(stream_data) + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + assert len(results) == 3 # 2 content + 1 final + assert results[0].text == "Hello" + assert results[0].is_final is False + assert results[1].text == " world" + assert results[2].is_final is True + assert results[2].in_token == 10 + assert results[2].out_token == 5 + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_model_override(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class, model="gpt-4") + + usage = MagicMock() + usage.prompt_tokens = 5 + usage.completion_tokens = 2 + + stream_data = [ + _make_stream_chunk(content="ok"), + _make_stream_chunk(usage=usage), + ] + mock_client.chat.completions.create.return_value = iter(stream_data) + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", model="gpt-4o" + ): + results.append(chunk) + + # All chunks carry overridden model + for r in results: + assert r.model == "gpt-4o" + + # Verify API call used overridden model + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "gpt-4o" + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_temperature_override(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 5 + usage.completion_tokens = 2 + + stream_data = [_make_stream_chunk(usage=usage)] + mock_client.chat.completions.create.return_value = iter(stream_data) + + async for _ in proc.generate_content_stream( + "sys", "user", temperature=0.7 + ): + pass + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["temperature"] == 0.7 + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_rate_limit_raises_too_many_requests(self, mock_azure_class): + from openai import RateLimitError + + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + mock_client.chat.completions.create.side_effect = RateLimitError( + "Rate limit exceeded", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_generic_exception_propagates(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + mock_client.chat.completions.create.side_effect = Exception("API down") + + with pytest.raises(Exception, match="API down"): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_passes_stream_options(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 0 + usage.completion_tokens = 0 + stream_data = [_make_stream_chunk(usage=usage)] + mock_client.chat.completions.create.return_value = iter(stream_data) + + async for _ in proc.generate_content_stream("sys", "user"): + pass + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["stream"] is True + assert call_kwargs["stream_options"] == {"include_usage": True} + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_supports_streaming(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + assert proc.supports_streaming() is True diff --git a/tests/unit/test_text_completion/test_azure_streaming.py b/tests/unit/test_text_completion/test_azure_streaming.py new file mode 100644 index 00000000..ff32e59d --- /dev/null +++ b/tests/unit/test_text_completion/test_azure_streaming.py @@ -0,0 +1,199 @@ +""" +Tests for Azure serverless endpoint streaming: model override during streaming, +HTTP 429 during streaming, SSE chunk parsing, and final token count emission. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.model.text_completion.azure.llm import Processor +from trustgraph.base import LlmChunk +from trustgraph.exceptions import TooManyRequests + + +def _make_processor(mock_requests, model="AzureAI", temperature=0.0): + """Create a Processor with mocked base classes.""" + with patch('trustgraph.base.async_processor.AsyncProcessor.__init__', + return_value=None), \ + patch('trustgraph.base.llm_service.LlmService.__init__', + return_value=None): + proc = Processor( + endpoint="https://test.azure.com/v1/chat/completions", + token="test-token", + temperature=temperature, + max_output=4192, + model=model, + concurrency=1, + taskgroup=AsyncMock(), + id="test-processor", + ) + return proc + + +def _sse_lines(*data_items): + """Build SSE byte lines from data items. '[DONE]' is appended.""" + lines = [] + for item in data_items: + if isinstance(item, dict): + lines.append(f"data: {json.dumps(item)}".encode()) + else: + lines.append(f"data: {item}".encode()) + lines.append(b"data: [DONE]") + return lines + + +class TestAzureServerlessStreaming(IsolatedAsyncioTestCase): + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_yields_chunks(self, mock_requests): + proc = _make_processor(mock_requests) + + chunks = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + {"choices": [{"delta": {"content": " world"}}]}, + {"usage": {"prompt_tokens": 10, "completion_tokens": 5}}, + ] + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines(*chunks) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + # Content chunks + final chunk + assert len(results) == 3 + assert results[0].text == "Hello" + assert results[0].is_final is False + assert results[1].text == " world" + assert results[1].is_final is False + assert results[2].is_final is True + assert results[2].in_token == 10 + assert results[2].out_token == 5 + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_model_override(self, mock_requests): + proc = _make_processor(mock_requests, model="default-model") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"choices": [{"delta": {"content": "ok"}}]}, + {"usage": {"prompt_tokens": 5, "completion_tokens": 2}}, + ) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", model="override-model" + ): + results.append(chunk) + + # All chunks should carry the overridden model name + for r in results: + assert r.model == "override-model" + + # Verify the request body used the overridden model + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["model"] == "override-model" + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_temperature_override(self, mock_requests): + proc = _make_processor(mock_requests, temperature=0.0) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"choices": [{"delta": {"content": "ok"}}]}, + {"usage": {"prompt_tokens": 5, "completion_tokens": 2}}, + ) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", temperature=0.9 + ): + results.append(chunk) + + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["temperature"] == 0.9 + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_429_raises_too_many_requests(self, mock_requests): + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 429 + mock_requests.post.return_value = mock_response + + with pytest.raises(TooManyRequests): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_http_error_raises_runtime(self, mock_requests): + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "Service Unavailable" + mock_requests.post.return_value = mock_response + + with pytest.raises(RuntimeError, match="HTTP 503"): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_includes_stream_options(self, mock_requests): + """Verify stream=True and stream_options in request body.""" + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"usage": {"prompt_tokens": 0, "completion_tokens": 0}}, + ) + mock_requests.post.return_value = mock_response + + async for _ in proc.generate_content_stream("sys", "user"): + pass + + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["stream"] is True + assert body["stream_options"]["include_usage"] is True + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_malformed_json_skipped(self, mock_requests): + """Malformed JSON chunks should be skipped, not crash the stream.""" + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 200 + lines = [ + b"data: {not valid json}", + f'data: {json.dumps({"choices": [{"delta": {"content": "ok"}}]})}'.encode(), + f'data: {json.dumps({"usage": {"prompt_tokens": 1, "completion_tokens": 1}})}'.encode(), + b"data: [DONE]", + ] + mock_response.iter_lines.return_value = lines + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + # Should get the valid content chunk + final chunk + assert any(r.text == "ok" for r in results) + assert results[-1].is_final is True + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_supports_streaming_flag(self, mock_requests): + proc = _make_processor(mock_requests) + assert proc.supports_streaming() is True diff --git a/tests/unit/test_text_completion/test_rate_limit_contract.py b/tests/unit/test_text_completion/test_rate_limit_contract.py new file mode 100644 index 00000000..c9df217b --- /dev/null +++ b/tests/unit/test_text_completion/test_rate_limit_contract.py @@ -0,0 +1,140 @@ +""" +Cross-provider rate limit contract tests: verify that every LLM provider +that handles rate limits converts its provider-specific exception to +TooManyRequests consistently. + +Also tests the client-side error translation in the base client. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.exceptions import TooManyRequests + + +class TestAzureServerless429(IsolatedAsyncioTestCase): + """Azure serverless endpoint: HTTP 429 → TooManyRequests""" + + @patch('trustgraph.model.text_completion.azure.llm.requests') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_http_429_raises_too_many_requests(self, _llm, _async, mock_requests): + from trustgraph.model.text_completion.azure.llm import Processor + proc = Processor( + endpoint="https://test.azure.com/v1/chat", + token="t", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_response = MagicMock() + mock_response.status_code = 429 + mock_requests.post.return_value = mock_response + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestAzureOpenAIRateLimit(IsolatedAsyncioTestCase): + """Azure OpenAI: openai.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls): + from openai import RateLimitError + from trustgraph.model.text_completion.azure_openai.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + endpoint="https://test.openai.azure.com/", token="t", + model="gpt-4", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_client.chat.completions.create.side_effect = RateLimitError( + "rate limited", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestOpenAIRateLimit(IsolatedAsyncioTestCase): + """OpenAI: openai.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls): + from openai import RateLimitError + from trustgraph.model.text_completion.openai.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_client.chat.completions.create.side_effect = RateLimitError( + "rate limited", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestClaudeRateLimit(IsolatedAsyncioTestCase): + """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.claude.llm.anthropic') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_anthropic): + from trustgraph.model.text_completion.claude.llm import Processor + + mock_client = MagicMock() + mock_anthropic.Anthropic.return_value = mock_client + + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + mock_anthropic.RateLimitError = type("RateLimitError", (Exception,), {}) + mock_client.messages.create.side_effect = mock_anthropic.RateLimitError( + "rate limited" + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestCohereRateLimit(IsolatedAsyncioTestCase): + """Cohere: cohere.TooManyRequestsError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.cohere.llm.cohere') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere): + from trustgraph.model.text_completion.cohere.llm import Processor + + mock_client = MagicMock() + mock_cohere.Client.return_value = mock_client + + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + mock_cohere.TooManyRequestsError = type( + "TooManyRequestsError", (Exception,), {} + ) + mock_client.chat.side_effect = mock_cohere.TooManyRequestsError( + "rate limited" + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestClientSideRateLimitTranslation: + """Client base class: error type 'too-many-requests' → TooManyRequests""" + + def test_error_type_mapping(self): + """The wire format error type string is 'too-many-requests'.""" + from trustgraph.schema import Error + err = Error(type="too-many-requests", message="slow down") + assert err.type == "too-many-requests" diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index daa2cc5c..dc1405ac 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -59,7 +59,7 @@ from .flow import Flow, FlowInstance from .async_flow import AsyncFlow, AsyncFlowInstance # WebSocket clients -from .socket_client import SocketClient, SocketFlowInstance +from .socket_client import SocketClient, SocketFlowInstance, build_term from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance # Bulk operation clients @@ -70,6 +70,23 @@ from .async_bulk_client import AsyncBulkClient from .metrics import Metrics from .async_metrics import AsyncMetrics +# Explainability +from .explainability import ( + ExplainabilityClient, + ExplainEntity, + Question, + Grounding, + Exploration, + Focus, + Synthesis, + Reflection, + Analysis, + Conclusion, + EdgeSelection, + wire_triples_to_tuples, + extract_term_value, +) + # Types from .types import ( Triple, @@ -85,6 +102,7 @@ from .types import ( AgentObservation, AgentAnswer, RAGChunk, + ProvenanceEvent, ) # Exceptions @@ -124,6 +142,7 @@ __all__ = [ "SocketFlowInstance", "AsyncSocketClient", "AsyncSocketFlowInstance", + "build_term", # Bulk operation clients "BulkClient", @@ -133,6 +152,19 @@ __all__ = [ "Metrics", "AsyncMetrics", + # Explainability + "ExplainabilityClient", + "ExplainEntity", + "Question", + "Exploration", + "Focus", + "Synthesis", + "Analysis", + "Conclusion", + "EdgeSelection", + "wire_triples_to_tuples", + "extract_term_value", + # Types "Triple", "Uri", @@ -147,6 +179,7 @@ __all__ = [ "AgentObservation", "AgentAnswer", "RAGChunk", + "ProvenanceEvent", # Exceptions "ProtocolException", diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 440cebae..2ff37307 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -612,12 +612,12 @@ class AsyncFlowInstance: print(f"{entity['name']}: {entity['score']}") ``` """ - # First convert text to embeddings vectors - emb_result = await self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = await self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -626,20 +626,20 @@ class AsyncFlowInstance: return await self.request("graph-embeddings", request_data) - async def embeddings(self, text: str, **kwargs: Any): + async def embeddings(self, texts: list, **kwargs: Any): """ - Generate embeddings for input text. + Generate embeddings for input texts. - Converts text into a numerical vector representation using the flow's + Converts texts into numerical vector representations using the flow's configured embedding model. Useful for semantic search and similarity comparisons. Args: - text: Input text to embed + texts: List of input texts to embed **kwargs: Additional service-specific parameters Returns: - dict: Response containing embedding vector and metadata + dict: Response containing embedding vectors Example: ```python @@ -647,12 +647,12 @@ class AsyncFlowInstance: flow = async_flow.id("default") # Generate embeddings - result = await flow.embeddings(text="Sample text to embed") - vector = result.get("embedding") - print(f"Embedding dimension: {len(vector)}") + result = await flow.embeddings(texts=["Sample text to embed"]) + vectors = result.get("vectors") + print(f"Embedding dimension: {len(vectors[0][0])}") ``` """ - request_data = {"text": text} + request_data = {"texts": texts} request_data.update(kwargs) return await self.request("embeddings", request_data) @@ -810,12 +810,12 @@ class AsyncFlowInstance: print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") ``` """ - # First convert text to embeddings vectors - emb_result = await self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = await self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 3241e0f7..3279609b 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -110,15 +110,25 @@ class AsyncSocketClient: # Parse different chunk types chunk = self._parse_chunk(resp) - yield chunk + if chunk is not None: # Skip provenance messages in streaming + yield chunk - # Check if this is the final chunk - if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + # Check if this is the final message + # end_of_session indicates entire session is complete (including provenance) + # end_of_dialog is for agent dialogs + # complete is from the gateway envelope + if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break def _parse_chunk(self, resp: Dict[str, Any]): - """Parse response chunk into appropriate type""" + """Parse response chunk into appropriate type. Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") + message_type = resp.get("message_type") + + # Handle new GraphRAG message format with message_type + if message_type == "provenance": + # Provenance messages are not yielded to user - they're metadata + return None if chunk_type == "thought": return AgentThought( @@ -143,7 +153,7 @@ class AsyncSocketClient: end_of_message=resp.get("end_of_message", False) ) else: - # RAG-style chunk (or generic chunk) + # RAG-style chunk (or generic chunk with message_type="chunk") # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( @@ -282,12 +292,12 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" - # First convert text to embeddings vectors - emb_result = await self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = await self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -296,9 +306,9 @@ class AsyncSocketFlowInstance: return await self.client._send_request("graph-embeddings", self.flow_id, request) - async def embeddings(self, text: str, **kwargs): + async def embeddings(self, texts: list, **kwargs): """Generate text embeddings""" - request = {"text": text} + request = {"texts": texts} request.update(kwargs) return await self.client._send_request("embeddings", self.flow_id, request) @@ -352,12 +362,12 @@ class AsyncSocketFlowInstance: limit: int = 10, **kwargs ): """Query row embeddings for semantic search on structured data""" - # First convert text to embeddings vectors - emb_result = await self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = await self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 3dfb0fba..75999550 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -322,8 +322,8 @@ class BulkClient: # Generate document embeddings to import def doc_embedding_generator(): - yield {"id": "doc1-chunk1", "embedding": [0.1, 0.2, ...]} - yield {"id": "doc1-chunk2", "embedding": [0.3, 0.4, ...]} + yield {"chunk_id": "doc1/p0/c0", "embedding": [0.1, 0.2, ...]} + yield {"chunk_id": "doc1/p0/c1", "embedding": [0.3, 0.4, ...]} # ... more embeddings bulk.import_document_embeddings( @@ -363,9 +363,9 @@ class BulkClient: # Export and process document embeddings for embedding in bulk.export_document_embeddings(flow="default"): - doc_id = embedding.get("id") + chunk_id = embedding.get("chunk_id") vector = embedding.get("embedding") - print(f"{doc_id}: {len(vector)} dimensions") + print(f"{chunk_id}: {len(vector)} dimensions") ``` """ async_gen = self._export_document_embeddings_async(flow) diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py new file mode 100644 index 00000000..1c986efb --- /dev/null +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -0,0 +1,1090 @@ +""" +Explainability support for TrustGraph API. + +Provides classes for explainability entities (Question, Exploration, Focus, +Synthesis, Analysis, Conclusion) and utilities for fetching them with +eventual consistency handling. +""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Tuple, Union + +# Provenance predicates +TG = "https://trustgraph.ai/ns/" +TG_QUERY = TG + "query" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_DOCUMENT = TG + "document" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" +TG_CHUNK_COUNT = TG + "chunkCount" +TG_SELECTED_CHUNK = TG + "selectedChunk" +TG_THOUGHT = TG + "thought" +TG_ACTION = TG + "action" +TG_ARGUMENTS = TG + "arguments" +TG_OBSERVATION = TG + "observation" + +# Entity types +TG_QUESTION = TG + "Question" +TG_GROUNDING = TG + "Grounding" +TG_EXPLORATION = TG + "Exploration" +TG_FOCUS = TG + "Focus" +TG_SYNTHESIS = TG + "Synthesis" +TG_ANALYSIS = TG + "Analysis" +TG_CONCLUSION = TG + "Conclusion" +TG_ANSWER_TYPE = TG + "Answer" +TG_REFLECTION_TYPE = TG + "Reflection" +TG_THOUGHT_TYPE = TG + "Thought" +TG_OBSERVATION_TYPE = TG + "Observation" +TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" +TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" +TG_AGENT_QUESTION = TG + "AgentQuestion" + +# PROV-O predicates +PROV = "http://www.w3.org/ns/prov#" +PROV_STARTED_AT_TIME = PROV + "startedAtTime" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" + +RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" +RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + +@dataclass +class EdgeSelection: + """A selected edge with reasoning from GraphRAG Focus step.""" + uri: str + edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...} + reasoning: str = "" + + +@dataclass +class ExplainEntity: + """Base class for explainability entities.""" + uri: str + entity_type: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "ExplainEntity": + """Parse triples into the appropriate entity type.""" + # Determine entity type from rdf:type triples + types = [o for s, p, o in triples if p == RDF_TYPE] + + if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types: + return Question.from_triples(uri, triples, types) + elif TG_GROUNDING in types: + return Grounding.from_triples(uri, triples) + elif TG_EXPLORATION in types: + return Exploration.from_triples(uri, triples) + elif TG_FOCUS in types: + return Focus.from_triples(uri, triples) + elif TG_SYNTHESIS in types: + return Synthesis.from_triples(uri, triples) + elif TG_REFLECTION_TYPE in types: + return Reflection.from_triples(uri, triples) + elif TG_ANALYSIS in types: + return Analysis.from_triples(uri, triples) + elif TG_CONCLUSION in types: + return Conclusion.from_triples(uri, triples) + else: + # Generic entity + return ExplainEntity(uri=uri, entity_type="unknown") + + +@dataclass +class Question(ExplainEntity): + """Question entity - the user's query that started the session.""" + query: str = "" + timestamp: str = "" + question_type: str = "" # "graph-rag", "document-rag", "agent" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]], + types: List[str]) -> "Question": + query = "" + timestamp = "" + question_type = "unknown" + + for s, p, o in triples: + if p == TG_QUERY: + query = o + elif p == PROV_STARTED_AT_TIME: + timestamp = o + + if TG_GRAPH_RAG_QUESTION in types: + question_type = "graph-rag" + elif TG_DOC_RAG_QUESTION in types: + question_type = "document-rag" + elif TG_AGENT_QUESTION in types: + question_type = "agent" + + return cls( + uri=uri, + entity_type="question", + query=query, + timestamp=timestamp, + question_type=question_type + ) + + +@dataclass +class Grounding(ExplainEntity): + """Grounding entity - concept decomposition of the query.""" + concepts: List[str] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Grounding": + concepts = [] + + for s, p, o in triples: + if p == TG_CONCEPT: + concepts.append(o) + + return cls( + uri=uri, + entity_type="grounding", + concepts=concepts + ) + + +@dataclass +class Exploration(ExplainEntity): + """Exploration entity - edges/chunks retrieved from the knowledge store.""" + edge_count: int = 0 + chunk_count: int = 0 + entities: List[str] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration": + edge_count = 0 + chunk_count = 0 + entities = [] + + for s, p, o in triples: + if p == TG_EDGE_COUNT: + try: + edge_count = int(o) + except (ValueError, TypeError): + pass + elif p == TG_CHUNK_COUNT: + try: + chunk_count = int(o) + except (ValueError, TypeError): + pass + elif p == TG_ENTITY: + entities.append(o) + + return cls( + uri=uri, + entity_type="exploration", + edge_count=edge_count, + chunk_count=chunk_count, + entities=entities + ) + + +@dataclass +class Focus(ExplainEntity): + """Focus entity - selected edges with LLM reasoning (GraphRAG only).""" + selected_edge_uris: List[str] = field(default_factory=list) + edge_selections: List[EdgeSelection] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Focus": + selected_edge_uris = [] + + for s, p, o in triples: + if p == TG_SELECTED_EDGE and isinstance(o, str): + selected_edge_uris.append(o) + + return cls( + uri=uri, + entity_type="focus", + selected_edge_uris=selected_edge_uris, + edge_selections=[] # Populated separately by fetching each edge URI + ) + + +@dataclass +class Synthesis(ExplainEntity): + """Synthesis entity - the final answer.""" + document: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis": + document = "" + + for s, p, o in triples: + if p == TG_DOCUMENT: + document = o + + return cls( + uri=uri, + entity_type="synthesis", + document=document + ) + + +@dataclass +class Reflection(ExplainEntity): + """Reflection entity - intermediate commentary (Thought or Observation).""" + document: str = "" + reflection_type: str = "" # "thought" or "observation" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Reflection": + document = "" + reflection_type = "" + + types = [o for s, p, o in triples if p == RDF_TYPE] + + if TG_THOUGHT_TYPE in types: + reflection_type = "thought" + elif TG_OBSERVATION_TYPE in types: + reflection_type = "observation" + + for s, p, o in triples: + if p == TG_DOCUMENT: + document = o + + return cls( + uri=uri, + entity_type="reflection", + document=document, + reflection_type=reflection_type + ) + + +@dataclass +class Analysis(ExplainEntity): + """Analysis entity - one think/act/observe cycle (Agent only).""" + action: str = "" + arguments: str = "" # JSON string + thought: str = "" + observation: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis": + action = "" + arguments = "" + thought = "" + observation = "" + + for s, p, o in triples: + if p == TG_ACTION: + action = o + elif p == TG_ARGUMENTS: + arguments = o + elif p == TG_THOUGHT: + thought = o + elif p == TG_OBSERVATION: + observation = o + + return cls( + uri=uri, + entity_type="analysis", + action=action, + arguments=arguments, + thought=thought, + observation=observation + ) + + +@dataclass +class Conclusion(ExplainEntity): + """Conclusion entity - final answer (Agent only).""" + document: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion": + document = "" + + for s, p, o in triples: + if p == TG_DOCUMENT: + document = o + + return cls( + uri=uri, + entity_type="conclusion", + document=document + ) + + +def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSelection: + """Parse triples for an edge selection entity.""" + uri = triples[0][0] if triples else "" + edge = None + reasoning = "" + + for s, p, o in triples: + if p == TG_EDGE and isinstance(o, dict): + edge = o + elif p == TG_REASONING: + reasoning = o + + return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning) + + +def extract_term_value(term: Dict[str, Any]) -> Any: + """Extract value from a wire-format Term dict.""" + t = term.get("t") or term.get("type") + + if t == "i": + return term.get("i") or term.get("iri", "") + elif t == "l": + return term.get("v") or term.get("value", "") + elif t == "t": + # Quoted triple - return as dict + tr = term.get("tr") or term.get("triple", {}) + return { + "s": extract_term_value(tr.get("s", {})), + "p": extract_term_value(tr.get("p", {})), + "o": extract_term_value(tr.get("o", {})), + } + else: + # Unknown format, try common keys + return term.get("i") or term.get("v") or term.get("iri") or term.get("value") or str(term) + + +def wire_triples_to_tuples(wire_triples: List[Dict[str, Any]]) -> List[Tuple[str, str, Any]]: + """Convert wire-format triples to (s, p, o) tuples.""" + result = [] + for t in wire_triples: + s = extract_term_value(t.get("s", {})) + p = extract_term_value(t.get("p", {})) + o = extract_term_value(t.get("o", {})) + result.append((s, p, o)) + return result + + +class ExplainabilityClient: + """ + Client for fetching explainability entities with eventual consistency handling. + + Uses quiescence detection: fetch, wait, fetch again, compare. + If results are the same, data is stable. + """ + + def __init__(self, flow_instance, retry_delay: float = 0.2, max_retries: int = 10): + """ + Initialize explainability client. + + Args: + flow_instance: A SocketFlowInstance for querying triples + retry_delay: Delay between retries in seconds (default: 0.2) + max_retries: Maximum retry attempts (default: 10) + """ + self.flow = flow_instance + self.retry_delay = retry_delay + self.max_retries = max_retries + self._label_cache: Dict[str, str] = {} + + def fetch_entity( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[ExplainEntity]: + """ + Fetch an explainability entity by URI with eventual consistency handling. + + Uses quiescence detection: + 1. Fetch triples for URI + 2. If zero results, retry + 3. If non-zero results, wait and fetch again + 4. If same results, data is stable - parse and return + 5. If different results, data still being written - retry + + Args: + uri: The entity URI to fetch + graph: Named graph to query (e.g., "urn:graph:retrieval") + user: User/keyspace identifier + collection: Collection identifier + + Returns: + ExplainEntity subclass or None if not found + """ + prev_triples = None + + for attempt in range(self.max_retries): + # Fetch triples for this URI + wire_triples = self.flow.triples_query( + s=uri, + g=graph, + user=user, + collection=collection, + limit=100 + ) + + if not wire_triples: + # Zero results - definitely retry + time.sleep(self.retry_delay) + continue + + # Convert to comparable format + triples = wire_triples_to_tuples(wire_triples) + triples_set = frozenset((s, p, str(o)) for s, p, o in triples) + + if prev_triples is None: + # First non-empty result - wait and check for stability + prev_triples = triples_set + time.sleep(self.retry_delay) + continue + + if triples_set == prev_triples: + # Same as before - data is stable + return ExplainEntity.from_triples(uri, triples) + else: + # Different - still being written, update and retry + prev_triples = triples_set + time.sleep(self.retry_delay) + continue + + # Max retries reached - return what we have if anything + if prev_triples: + # Re-fetch and parse + wire_triples = self.flow.triples_query( + s=uri, g=graph, user=user, collection=collection, limit=100 + ) + if wire_triples: + triples = wire_triples_to_tuples(wire_triples) + return ExplainEntity.from_triples(uri, triples) + + return None + + def fetch_edge_selection( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[EdgeSelection]: + """ + Fetch an edge selection entity (used by Focus). + + Args: + uri: The edge selection URI + graph: Named graph to query + user: User/keyspace identifier + collection: Collection identifier + + Returns: + EdgeSelection or None if not found + """ + wire_triples = self.flow.triples_query( + s=uri, + g=graph, + user=user, + collection=collection, + limit=100 + ) + + if not wire_triples: + return None + + triples = wire_triples_to_tuples(wire_triples) + return parse_edge_selection_triples(triples) + + def fetch_focus_with_edges( + self, + uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Optional[Focus]: + """ + Fetch a Focus entity and all its edge selections. + + Args: + uri: The Focus entity URI + graph: Named graph to query + user: User/keyspace identifier + collection: Collection identifier + + Returns: + Focus with populated edge_selections, or None + """ + entity = self.fetch_entity(uri, graph, user, collection) + + if not isinstance(entity, Focus): + return None + + # Fetch each edge selection + for edge_uri in entity.selected_edge_uris: + edge_sel = self.fetch_edge_selection(edge_uri, graph, user, collection) + if edge_sel: + entity.edge_selections.append(edge_sel) + + return entity + + def resolve_label( + self, + uri: str, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> str: + """ + Resolve rdfs:label for a URI, with caching. + + Args: + uri: The URI to get label for + user: User/keyspace identifier + collection: Collection identifier + + Returns: + The label if found, otherwise the URI itself + """ + if not uri or not uri.startswith(("http://", "https://", "urn:")): + return uri + + if uri in self._label_cache: + return self._label_cache[uri] + + wire_triples = self.flow.triples_query( + s=uri, + p=RDFS_LABEL, + user=user, + collection=collection, + limit=1 + ) + + if wire_triples: + triples = wire_triples_to_tuples(wire_triples) + if triples: + label = triples[0][2] + self._label_cache[uri] = label + return label + + self._label_cache[uri] = uri + return uri + + def resolve_edge_labels( + self, + edge: Dict[str, str], + user: Optional[str] = None, + collection: Optional[str] = None + ) -> Tuple[str, str, str]: + """ + Resolve labels for all components of an edge triple. + + Args: + edge: Dict with "s", "p", "o" keys + user: User/keyspace identifier + collection: Collection identifier + + Returns: + Tuple of (s_label, p_label, o_label) + """ + s_label = self.resolve_label(edge.get("s", ""), user, collection) + p_label = self.resolve_label(edge.get("p", ""), user, collection) + o_label = self.resolve_label(edge.get("o", ""), user, collection) + return (s_label, p_label, o_label) + + def fetch_document_content( + self, + document_uri: str, + api: Any, + user: Optional[str] = None, + max_content: int = 10000 + ) -> str: + """ + Fetch content from the librarian by document URI. + + Args: + document_uri: The document URI in the librarian + api: TrustGraph Api instance for librarian access + user: User identifier for librarian + max_content: Maximum content length to return + + Returns: + The document content as a string + """ + if not document_uri: + return "" + + doc_id = document_uri + + # Retry fetching from librarian for eventual consistency + for attempt in range(self.max_retries): + try: + library = api.library() + content_bytes = library.get_document_content(user=user, id=doc_id) + + # Decode as text + try: + content = content_bytes.decode('utf-8') + if len(content) > max_content: + return content[:max_content] + "... [truncated]" + return content + except UnicodeDecodeError: + return f"[Binary: {len(content_bytes)} bytes]" + + except Exception as e: + if attempt < self.max_retries - 1: + time.sleep(self.retry_delay) + continue + return f"[Error fetching content: {e}]" + + return "" + + + def fetch_graphrag_trace( + self, + question_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete GraphRAG trace starting from a question URI. + + Follows the provenance chain: Question -> Grounding -> Exploration -> Focus -> Synthesis + + Args: + question_uri: The question entity URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for synthesis + + Returns: + Dict with question, grounding, exploration, focus, synthesis entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "grounding": None, + "exploration": None, + "focus": None, + "synthesis": None, + } + + # Fetch question + question = self.fetch_entity(question_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Find grounding: ?grounding prov:wasGeneratedBy question_uri + grounding_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=question_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if grounding_triples: + grounding_uris = [ + extract_term_value(t.get("s", {})) + for t in grounding_triples + ] + for gnd_uri in grounding_uris: + grounding = self.fetch_entity(gnd_uri, graph, user, collection) + if isinstance(grounding, Grounding): + trace["grounding"] = grounding + break + + if not trace["grounding"]: + return trace + + # Find exploration: ?exploration prov:wasDerivedFrom grounding_uri + exploration_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["grounding"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if exploration_triples: + exploration_uris = [ + extract_term_value(t.get("s", {})) + for t in exploration_triples + ] + for exp_uri in exploration_uris: + exploration = self.fetch_entity(exp_uri, graph, user, collection) + if isinstance(exploration, Exploration): + trace["exploration"] = exploration + break + + if not trace["exploration"]: + return trace + + # Find focus: ?focus prov:wasDerivedFrom exploration_uri + focus_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["exploration"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if focus_triples: + focus_uris = [ + extract_term_value(t.get("s", {})) + for t in focus_triples + ] + for focus_uri in focus_uris: + focus = self.fetch_focus_with_edges(focus_uri, graph, user, collection) + if focus: + trace["focus"] = focus + break + + if not trace["focus"]: + return trace + + # Find synthesis: ?synthesis prov:wasDerivedFrom focus_uri + synthesis_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["focus"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if synthesis_triples: + synthesis_uris = [ + extract_term_value(t.get("s", {})) + for t in synthesis_triples + ] + for synth_uri in synthesis_uris: + synthesis = self.fetch_entity(synth_uri, graph, user, collection) + if isinstance(synthesis, Synthesis): + trace["synthesis"] = synthesis + break + + return trace + + def fetch_docrag_trace( + self, + question_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete DocumentRAG trace starting from a question URI. + + Follows the provenance chain: + Question -> Grounding -> Exploration -> Synthesis + + Args: + question_uri: The question entity URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for synthesis + + Returns: + Dict with question, grounding, exploration, synthesis entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "grounding": None, + "exploration": None, + "synthesis": None, + } + + # Fetch question + question = self.fetch_entity(question_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Find grounding: ?grounding prov:wasGeneratedBy question_uri + grounding_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=question_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if grounding_triples: + grounding_uris = [ + extract_term_value(t.get("s", {})) + for t in grounding_triples + ] + for gnd_uri in grounding_uris: + grounding = self.fetch_entity(gnd_uri, graph, user, collection) + if isinstance(grounding, Grounding): + trace["grounding"] = grounding + break + + if not trace["grounding"]: + return trace + + # Find exploration: ?exploration prov:wasDerivedFrom grounding_uri + exploration_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["grounding"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if exploration_triples: + exploration_uris = [ + extract_term_value(t.get("s", {})) + for t in exploration_triples + ] + for exp_uri in exploration_uris: + exploration = self.fetch_entity(exp_uri, graph, user, collection) + if isinstance(exploration, Exploration): + trace["exploration"] = exploration + break + + if not trace["exploration"]: + return trace + + # Find synthesis: ?synthesis prov:wasDerivedFrom exploration_uri + synthesis_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["exploration"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if synthesis_triples: + synthesis_uris = [ + extract_term_value(t.get("s", {})) + for t in synthesis_triples + ] + for synth_uri in synthesis_uris: + synthesis = self.fetch_entity(synth_uri, graph, user, collection) + if isinstance(synthesis, Synthesis): + trace["synthesis"] = synthesis + break + + return trace + + def fetch_agent_trace( + self, + session_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + api: Any = None, + max_content: int = 10000 + ) -> Dict[str, Any]: + """ + Fetch the complete Agent trace starting from a session URI. + + Follows the provenance chain: Question -> Analysis(s) -> Conclusion + + Args: + session_uri: The agent session/question URI + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + api: TrustGraph Api instance for librarian access (optional) + max_content: Maximum content length for conclusion + + Returns: + Dict with question, iterations (Analysis list), conclusion entities + """ + if graph is None: + graph = "urn:graph:retrieval" + + trace = { + "question": None, + "iterations": [], + "conclusion": None, + } + + # Fetch question/session + question = self.fetch_entity(session_uri, graph, user, collection) + if not isinstance(question, Question): + return trace + trace["question"] = question + + # Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after + current_uri = session_uri + is_first = True + max_iterations = 50 # Safety limit + + for _ in range(max_iterations): + # First hop uses wasGeneratedBy (entity←activity), + # subsequent hops use wasDerivedFrom (entity←entity) + if is_first: + derived_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + # Fall back to wasDerivedFrom for backwards compatibility + if not derived_triples: + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + is_first = False + else: + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + + if not derived_triples: + break + + derived_uri = extract_term_value(derived_triples[0].get("s", {})) + if not derived_uri: + break + + entity = self.fetch_entity(derived_uri, graph, user, collection) + + if isinstance(entity, Analysis): + trace["iterations"].append(entity) + current_uri = derived_uri + elif isinstance(entity, Conclusion): + trace["conclusion"] = entity + break + else: + # Unknown entity type, stop + break + + return trace + + def list_sessions( + self, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + limit: int = 50 + ) -> List[Question]: + """ + List all explainability sessions (questions) in a collection. + + Args: + graph: Named graph (default: urn:graph:retrieval) + user: User/keyspace identifier + collection: Collection identifier + limit: Maximum number of sessions to return + + Returns: + List of Question entities sorted by timestamp (newest first) + """ + if graph is None: + graph = "urn:graph:retrieval" + + # Query for all triples with predicate = tg:query + query_triples = self.flow.triples_query( + p=TG_QUERY, + g=graph, + user=user, + collection=collection, + limit=limit + ) + + questions = [] + for t in query_triples: + question_uri = extract_term_value(t.get("s", {})) + if question_uri: + entity = self.fetch_entity(question_uri, graph, user, collection) + if isinstance(entity, Question): + questions.append(entity) + + # Sort by timestamp (newest first) + questions.sort(key=lambda q: q.timestamp or "", reverse=True) + + return questions + + def detect_session_type( + self, + session_uri: str, + graph: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None + ) -> str: + """ + Detect whether a session is GraphRAG or Agent type. + + Args: + session_uri: The session/question URI + graph: Named graph + user: User/keyspace identifier + collection: Collection identifier + + Returns: + "graphrag" or "agent" + """ + if graph is None: + graph = "urn:graph:retrieval" + + # Fast path: check URI pattern + if "agent" in session_uri: + return "agent" + if "question" in session_uri: + return "graphrag" + if "docrag" in session_uri: + return "docrag" + + # Check what's derived from this entity + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=session_uri, + g=graph, + user=user, + collection=collection, + limit=5 + ) + + generated_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=session_uri, + g=graph, + user=user, + collection=collection, + limit=5 + ) + + all_child_uris = [ + extract_term_value(t.get("s", {})) + for t in (derived_triples + generated_triples) + ] + + for child_uri in all_child_uris: + entity = self.fetch_entity(child_uri, graph, user, collection) + if isinstance(entity, Analysis): + return "agent" + if isinstance(entity, Exploration): + return "graphrag" + + return "graphrag" # Default diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index cc07f794..f20d4d56 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -9,26 +9,45 @@ including LLM operations, RAG queries, knowledge graph management, and more. import json import base64 -from .. knowledge import hash, Uri, Literal -from .. schema import IRI, LITERAL +from .. knowledge import hash, Uri, Literal, QuotedTriple +from .. schema import IRI, LITERAL, TRIPLE from . types import Triple from . exceptions import ProtocolException def to_value(x): - """Convert wire format to Uri or Literal.""" + """Convert wire format to Uri, Literal, or QuotedTriple.""" if x.get("t") == IRI: return Uri(x.get("i", "")) elif x.get("t") == LITERAL: return Literal(x.get("v", "")) + elif x.get("t") == TRIPLE: + # Wire format uses "tr" key for nested triple dict + triple_data = x.get("tr") + if triple_data: + return QuotedTriple( + s=to_value(triple_data.get("s", {})), + p=to_value(triple_data.get("p", {})), + o=to_value(triple_data.get("o", {})), + ) + return Literal("") # Fallback for any other type return Literal(x.get("v", x.get("i", ""))) def from_value(v): - """Convert Uri or Literal to wire format.""" + """Convert Uri, Literal, or QuotedTriple to wire format.""" if isinstance(v, Uri): return {"t": IRI, "i": str(v)} + elif isinstance(v, QuotedTriple): + return { + "t": TRIPLE, + "tr": { + "s": from_value(v.s), + "p": from_value(v.p), + "o": from_value(v.o), + } + } else: return {"t": LITERAL, "v": str(v)} @@ -525,30 +544,29 @@ class FlowInstance: input )["response"] - def embeddings(self, text): + def embeddings(self, texts): """ - Generate vector embeddings for text. + Generate vector embeddings for one or more texts. - Converts text into dense vector representations suitable for semantic + Converts texts into dense vector representations suitable for semantic search and similarity comparison. Args: - text: Input text to embed + texts: List of input texts to embed Returns: - list[float]: Vector embedding + list[list[list[float]]]: Vector embeddings, one set per input text Example: ```python flow = api.flow().id("default") - vectors = flow.embeddings("quantum computing") - print(f"Embedding dimension: {len(vectors)}") + vectors = flow.embeddings(["quantum computing"]) + print(f"Embedding dimension: {len(vectors[0][0])}") ``` """ - # The input consists of a text block input = { - "text": text + "texts": texts } return self.request( @@ -581,16 +599,17 @@ class FlowInstance: collection="scientists", limit=5 ) + # results contains {"entities": [{"entity": {...}, "score": 0.95}, ...]} ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] # Query graph embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -615,7 +634,7 @@ class FlowInstance: limit: Maximum number of results (default: 10) Returns: - dict: Query results with similar document chunks + dict: Query results with chunks containing chunk_id and score Example: ```python @@ -626,16 +645,17 @@ class FlowInstance: collection="research-papers", limit=5 ) + # results contains {"chunks": [{"chunk_id": "doc1/p0/c0", "score": 0.95}, ...]} ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] # Query document embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -1343,13 +1363,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] # Query row embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index 1fae350c..84f98918 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -9,17 +9,27 @@ into flows for use in queries and RAG operations. import json import base64 -from .. knowledge import hash, Uri, Literal -from .. schema import IRI, LITERAL +from .. knowledge import hash, Uri, Literal, QuotedTriple +from .. schema import IRI, LITERAL, TRIPLE from . types import Triple def to_value(x): - """Convert wire format to Uri or Literal.""" + """Convert wire format to Uri, Literal, or QuotedTriple.""" if x.get("t") == IRI: return Uri(x.get("i", "")) elif x.get("t") == LITERAL: return Literal(x.get("v", "")) + elif x.get("t") == TRIPLE: + # Wire format uses "tr" key for nested triple dict + triple_data = x.get("tr") + if triple_data: + return QuotedTriple( + s=to_value(triple_data.get("s", {})), + p=to_value(triple_data.get("p", {})), + o=to_value(triple_data.get("o", {})), + ) + return Literal("") # Fallback for any other type return Literal(x.get("v", x.get("i", ""))) diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index e50dc0aa..396d64e0 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -6,32 +6,59 @@ including document storage, metadata management, and processing workflow coordin """ import datetime +import math import time import base64 import logging from . types import DocumentMetadata, ProcessingMetadata, Triple -from .. knowledge import hash, Uri, Literal -from .. schema import IRI, LITERAL +from .. knowledge import hash, Uri, Literal, QuotedTriple +from .. schema import IRI, LITERAL, TRIPLE from . exceptions import * logger = logging.getLogger(__name__) +# Threshold for switching to chunked upload (2MB) +# Lower threshold provides progress feedback and resumability on slower connections +CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 + +# Default chunk size (5MB - S3 multipart minimum) +DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 + def to_value(x): - """Convert wire format to Uri or Literal.""" + """Convert wire format to Uri, Literal, or QuotedTriple.""" if x.get("t") == IRI: return Uri(x.get("i", "")) elif x.get("t") == LITERAL: return Literal(x.get("v", "")) + elif x.get("t") == TRIPLE: + # Wire format uses "tr" key for nested triple dict + triple_data = x.get("tr") + if triple_data: + return QuotedTriple( + s=to_value(triple_data.get("s", {})), + p=to_value(triple_data.get("p", {})), + o=to_value(triple_data.get("o", {})), + ) + return Literal("") # Fallback for any other type return Literal(x.get("v", x.get("i", ""))) def from_value(v): - """Convert Uri or Literal to wire format.""" + """Convert Uri, Literal, or QuotedTriple to wire format.""" if isinstance(v, Uri): return {"t": IRI, "i": str(v)} + elif isinstance(v, QuotedTriple): + return { + "t": TRIPLE, + "tr": { + "s": from_value(v.s), + "p": from_value(v.p), + "o": from_value(v.o), + } + } else: return {"t": LITERAL, "v": str(v)} @@ -67,13 +94,14 @@ class Library: def add_document( self, document, id, metadata, user, title, comments, - kind="text/plain", tags=[], + kind="text/plain", tags=[], on_progress=None, ): """ Add a document to the library. Stores a document with associated metadata in the library for - retrieval and processing. + retrieval and processing. For large documents (> 10MB), automatically + uses chunked upload for better reliability and progress tracking. Args: document: Document content as bytes @@ -84,6 +112,7 @@ class Library: comments: Document description or comments kind: MIME type of the document (default: "text/plain") tags: List of tags for categorization (default: []) + on_progress: Optional callback(bytes_sent, total_bytes) for progress updates Returns: dict: Response from the add operation @@ -107,6 +136,22 @@ class Library: kind="application/pdf", tags=["research", "physics"] ) + + # Add a large document with progress tracking + def progress(sent, total): + print(f"Uploaded {sent}/{total} bytes ({100*sent//total}%)") + + with open("large_document.pdf", "rb") as f: + library.add_document( + document=f.read(), + id="large-doc-001", + metadata=[], + user="trustgraph", + title="Large Document", + comments="A very large document", + kind="application/pdf", + on_progress=progress + ) ``` """ @@ -124,6 +169,21 @@ class Library: if not title: title = "" if not comments: comments = "" + # Check if we should use chunked upload + if len(document) >= CHUNKED_UPLOAD_THRESHOLD: + return self._add_document_chunked( + document=document, + id=id, + metadata=metadata, + user=user, + title=title, + comments=comments, + kind=kind, + tags=tags, + on_progress=on_progress, + ) + + # Small document: use single operation (existing behavior) triples = [] def emit(t): @@ -167,14 +227,111 @@ class Library: return self.request(input) - def get_documents(self, user): + def _add_document_chunked( + self, document, id, metadata, user, title, comments, + kind, tags, on_progress=None, + ): + """ + Add a large document using chunked upload. + + Internal method that handles multipart upload for large documents. + """ + total_size = len(document) + chunk_size = DEFAULT_CHUNK_SIZE + + logger.info(f"Starting chunked upload for document {id} ({total_size} bytes)") + + # Begin upload session + begin_request = { + "operation": "begin-upload", + "document-metadata": { + "id": id, + "time": int(time.time()), + "kind": kind, + "title": title, + "comments": comments, + "user": user, + "tags": tags, + }, + "total-size": total_size, + "chunk-size": chunk_size, + } + + begin_response = self.request(begin_request) + + upload_id = begin_response.get("upload-id") + if not upload_id: + raise RuntimeError("Failed to begin upload: no upload_id returned") + + actual_chunk_size = begin_response.get("chunk-size", chunk_size) + total_chunks = begin_response.get("total-chunks", math.ceil(total_size / actual_chunk_size)) + + logger.info(f"Upload session {upload_id} created, {total_chunks} chunks") + + try: + # Upload chunks + bytes_sent = 0 + for chunk_index in range(total_chunks): + start = chunk_index * actual_chunk_size + end = min(start + actual_chunk_size, total_size) + chunk_data = document[start:end] + + chunk_request = { + "operation": "upload-chunk", + "upload-id": upload_id, + "chunk-index": chunk_index, + "content": base64.b64encode(chunk_data).decode("utf-8"), + "user": user, + } + + chunk_response = self.request(chunk_request) + + bytes_sent = end + + # Call progress callback if provided + if on_progress: + on_progress(bytes_sent, total_size) + + logger.debug(f"Chunk {chunk_index + 1}/{total_chunks} uploaded") + + # Complete upload + complete_request = { + "operation": "complete-upload", + "upload-id": upload_id, + "user": user, + } + + complete_response = self.request(complete_request) + + logger.info(f"Chunked upload completed for document {id}") + + return complete_response + + except Exception as e: + # Try to abort on failure + logger.error(f"Chunked upload failed: {e}") + try: + abort_request = { + "operation": "abort-upload", + "upload-id": upload_id, + "user": user, + } + self.request(abort_request) + logger.info(f"Aborted failed upload {upload_id}") + except Exception as abort_error: + logger.warning(f"Failed to abort upload: {abort_error}") + raise + + def get_documents(self, user, include_children=False): """ List all documents for a user. Retrieves metadata for all documents owned by the specified user. + By default, only returns top-level documents (not child/extracted documents). Args: user: User identifier + include_children: If True, also include child documents (default: False) Returns: list[DocumentMetadata]: List of document metadata objects @@ -185,18 +342,24 @@ class Library: Example: ```python library = api.library() + + # Get only top-level documents docs = library.get_documents(user="trustgraph") for doc in docs: print(f"{doc.id}: {doc.title} ({doc.kind})") print(f" Uploaded: {doc.time}") print(f" Tags: {', '.join(doc.tags)}") + + # Get all documents including extracted pages + all_docs = library.get_documents(user="trustgraph", include_children=True) ``` """ input = { "operation": "list-documents", "user": user, + "include-children": include_children, } object = self.request(input) @@ -218,7 +381,9 @@ class Library: for w in v["metadata"] ], user = v["user"], - tags = v["tags"] + tags = v["tags"], + parent_id = v.get("parent-id", ""), + document_type = v.get("document-type", "source"), ) for v in object["document-metadatas"] ] @@ -261,7 +426,7 @@ class Library: doc = object["document-metadata"] try: - DocumentMetadata( + return DocumentMetadata( id = doc["id"], time = datetime.datetime.fromtimestamp(doc["time"]), kind = doc["kind"], @@ -276,7 +441,9 @@ class Library: for w in doc["metadata"] ], user = doc["user"], - tags = doc["tags"] + tags = doc["tags"], + parent_id = doc.get("parent-id", ""), + document_type = doc.get("document-type", "source"), ) except Exception as e: logger.error("Failed to parse document response", exc_info=True) @@ -535,3 +702,447 @@ class Library: logger.error("Failed to parse processing list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") + # Chunked upload management methods + + def get_pending_uploads(self, user): + """ + List all pending (in-progress) uploads for a user. + + Retrieves information about chunked uploads that have been started + but not yet completed. + + Args: + user: User identifier + + Returns: + list[dict]: List of pending upload information + + Example: + ```python + library = api.library() + pending = library.get_pending_uploads(user="trustgraph") + + for upload in pending: + print(f"Upload {upload['upload_id']}:") + print(f" Document: {upload['document_id']}") + print(f" Progress: {upload['chunks_received']}/{upload['total_chunks']}") + ``` + """ + input = { + "operation": "list-uploads", + "user": user, + } + + response = self.request(input) + + return response.get("upload-sessions", []) + + def get_upload_status(self, upload_id, user): + """ + Get the status of a specific upload. + + Retrieves detailed status information about a chunked upload, + including which chunks have been received and which are missing. + + Args: + upload_id: Upload session identifier + user: User identifier + + Returns: + dict: Upload status information including: + - upload_id: The upload session ID + - state: "in-progress", "completed", or "expired" + - chunks_received: Number of chunks received + - total_chunks: Total number of chunks expected + - received_chunks: List of received chunk indices + - missing_chunks: List of missing chunk indices + - bytes_received: Total bytes received + - total_bytes: Total expected bytes + + Example: + ```python + library = api.library() + status = library.get_upload_status( + upload_id="abc-123", + user="trustgraph" + ) + + if status['state'] == 'in-progress': + print(f"Missing chunks: {status['missing_chunks']}") + ``` + """ + input = { + "operation": "get-upload-status", + "upload-id": upload_id, + "user": user, + } + + return self.request(input) + + def abort_upload(self, upload_id, user): + """ + Abort an in-progress upload. + + Cancels a chunked upload and cleans up any uploaded chunks. + + Args: + upload_id: Upload session identifier + user: User identifier + + Returns: + dict: Empty response on success + + Example: + ```python + library = api.library() + library.abort_upload(upload_id="abc-123", user="trustgraph") + ``` + """ + input = { + "operation": "abort-upload", + "upload-id": upload_id, + "user": user, + } + + return self.request(input) + + def resume_upload(self, upload_id, document, user, on_progress=None): + """ + Resume an interrupted upload. + + Continues a chunked upload that was previously interrupted, + uploading only the missing chunks. + + Args: + upload_id: Upload session identifier to resume + document: Complete document content as bytes + user: User identifier + on_progress: Optional callback(bytes_sent, total_bytes) for progress updates + + Returns: + dict: Response from completing the upload + + Example: + ```python + library = api.library() + + # Check what's missing + status = library.get_upload_status( + upload_id="abc-123", + user="trustgraph" + ) + + if status['state'] == 'in-progress': + # Resume with the same document + with open("large_document.pdf", "rb") as f: + library.resume_upload( + upload_id="abc-123", + document=f.read(), + user="trustgraph" + ) + ``` + """ + # Get current status + status = self.get_upload_status(upload_id, user) + + if status.get("upload-state") == "expired": + raise RuntimeError("Upload session has expired, please start a new upload") + + if status.get("upload-state") == "completed": + return {"message": "Upload already completed"} + + missing_chunks = status.get("missing-chunks", []) + total_chunks = status.get("total-chunks", 0) + total_bytes = status.get("total-bytes", len(document)) + chunk_size = total_bytes // total_chunks if total_chunks > 0 else DEFAULT_CHUNK_SIZE + + logger.info(f"Resuming upload {upload_id}, {len(missing_chunks)} chunks remaining") + + # Upload missing chunks + for chunk_index in missing_chunks: + start = chunk_index * chunk_size + end = min(start + chunk_size, len(document)) + chunk_data = document[start:end] + + chunk_request = { + "operation": "upload-chunk", + "upload-id": upload_id, + "chunk-index": chunk_index, + "content": base64.b64encode(chunk_data).decode("utf-8"), + "user": user, + } + + self.request(chunk_request) + + if on_progress: + # Estimate progress including previously uploaded chunks + uploaded = total_chunks - len(missing_chunks) + missing_chunks.index(chunk_index) + 1 + bytes_sent = min(uploaded * chunk_size, total_bytes) + on_progress(bytes_sent, total_bytes) + + logger.debug(f"Resumed chunk {chunk_index}") + + # Complete upload + complete_request = { + "operation": "complete-upload", + "upload-id": upload_id, + "user": user, + } + + return self.request(complete_request) + + # Child document methods + + def add_child_document( + self, document, id, parent_id, user, title, comments, + kind="text/plain", tags=[], metadata=None, + ): + """ + Add a child document linked to a parent document. + + Child documents are typically extracted content (e.g., pages from a PDF). + They are automatically marked with document_type="extracted" and linked + to their parent via parent_id. + + Args: + document: Document content as bytes + id: Document identifier (auto-generated if None) + parent_id: Parent document identifier (required) + user: User/owner identifier + title: Document title + comments: Document description or comments + kind: MIME type of the document (default: "text/plain") + tags: List of tags for categorization (default: []) + metadata: Optional metadata as list of Triple objects + + Returns: + dict: Response from the add operation + + Raises: + RuntimeError: If parent_id is not provided + + Example: + ```python + library = api.library() + + # Add extracted page from a PDF + library.add_child_document( + document=page_text.encode('utf-8'), + id="doc-123-page-1", + parent_id="doc-123", + user="trustgraph", + title="Page 1 of Research Paper", + comments="First page extracted from PDF", + kind="text/plain", + tags=["extracted", "page"] + ) + ``` + """ + if not parent_id: + raise RuntimeError("parent_id is required for child documents") + + if id is None: + id = hash(document) + + if not title: + title = "" + if not comments: + comments = "" + + triples = [] + if metadata: + if isinstance(metadata, list): + triples = [ + { + "s": from_value(t.s), + "p": from_value(t.p), + "o": from_value(t.o), + } + for t in metadata + ] + + input = { + "operation": "add-child-document", + "document-metadata": { + "id": id, + "time": int(time.time()), + "kind": kind, + "title": title, + "comments": comments, + "metadata": triples, + "user": user, + "tags": tags, + "parent-id": parent_id, + "document-type": "extracted", + }, + "content": base64.b64encode(document).decode("utf-8"), + } + + return self.request(input) + + def list_children(self, document_id, user): + """ + List all child documents for a given parent document. + + Args: + document_id: Parent document identifier + user: User identifier + + Returns: + list[DocumentMetadata]: List of child document metadata objects + + Example: + ```python + library = api.library() + children = library.list_children( + document_id="doc-123", + user="trustgraph" + ) + + for child in children: + print(f"{child.id}: {child.title}") + ``` + """ + input = { + "operation": "list-children", + "document-id": document_id, + "user": user, + } + + response = self.request(input) + + try: + return [ + DocumentMetadata( + id=v["id"], + time=datetime.datetime.fromtimestamp(v["time"]), + kind=v["kind"], + title=v["title"], + comments=v.get("comments", ""), + metadata=[ + Triple( + s=to_value(w["s"]), + p=to_value(w["p"]), + o=to_value(w["o"]) + ) + for w in v.get("metadata", []) + ], + user=v["user"], + tags=v.get("tags", []), + parent_id=v.get("parent-id", ""), + document_type=v.get("document-type", "source"), + ) + for v in response.get("document-metadatas", []) + ] + except Exception as e: + logger.error("Failed to parse children response", exc_info=True) + raise ProtocolException("Response not formatted correctly") + + def get_document_content(self, user, id): + """ + Get the content of a document. + + Retrieves the full content of a document as bytes. + + Args: + user: User identifier + id: Document identifier + + Returns: + bytes: Document content + + Example: + ```python + library = api.library() + content = library.get_document_content( + user="trustgraph", + id="doc-123" + ) + + # Write to file + with open("output.pdf", "wb") as f: + f.write(content) + ``` + """ + input = { + "operation": "get-document-content", + "user": user, + "document-id": id, + } + + response = self.request(input) + content_b64 = response.get("content", "") + + return base64.b64decode(content_b64) + + def stream_document_to_file(self, user, id, file_path, chunk_size=1024*1024, on_progress=None): + """ + Stream document content to a file. + + Downloads document content in chunks and writes directly to a file, + enabling memory-efficient handling of large documents. + + Args: + user: User identifier + id: Document identifier + file_path: Path to write the document content + chunk_size: Size of each chunk to download (default 1MB) + on_progress: Optional callback(bytes_received, total_bytes) for progress updates + + Returns: + int: Total bytes written + + Example: + ```python + library = api.library() + + def progress(received, total): + print(f"Downloaded {received}/{total} bytes") + + library.stream_document_to_file( + user="trustgraph", + id="large-doc-123", + file_path="/tmp/document.pdf", + on_progress=progress + ) + ``` + """ + chunk_index = 0 + total_bytes_written = 0 + total_bytes = None + + with open(file_path, "wb") as f: + while True: + input = { + "operation": "stream-document", + "user": user, + "document-id": id, + "chunk-index": chunk_index, + "chunk-size": chunk_size, + } + + response = self.request(input) + + content_b64 = response.get("content", "") + chunk_data = base64.b64decode(content_b64) + + if not chunk_data: + break + + f.write(chunk_data) + total_bytes_written += len(chunk_data) + + total_chunks = response.get("total-chunks", 1) + total_bytes = response.get("total-bytes", total_bytes_written) + + if on_progress: + on_progress(total_bytes_written, total_bytes) + + # Check if we've received all chunks + if chunk_index >= total_chunks - 1: + break + + chunk_index += 1 + + return total_bytes_written + diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index e8de442a..4e09351a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -11,10 +11,67 @@ import websockets from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock -from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent from . exceptions import ProtocolException, raise_from_error_dict +def build_term(value: Any, term_type: Optional[str] = None, + datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Build wire-format Term dict from a value. + + Auto-detection rules (when term_type is None): + - Already a dict with 't' key -> return as-is (already a Term) + - Starts with http://, https://, urn: -> IRI + - Wrapped in <> (e.g., ) -> IRI (angle brackets stripped) + - Anything else -> literal + + Args: + value: The term value (string, dict, or None) + term_type: One of 'iri', 'literal', or None for auto-detect + datatype: Datatype for literal objects (e.g., xsd:integer) + language: Language tag for literal objects (e.g., en) + + Returns: + dict: Wire-format Term dict, or None if value is None + """ + if value is None: + return None + + # If already a Term dict, return as-is + if isinstance(value, dict) and "t" in value: + return value + + # Convert to string for processing + value = str(value) + + # Auto-detect type if not specified + if term_type is None: + if value.startswith("<") and value.endswith(">") and not value.startswith("<<"): + # Angle-bracket wrapped IRI: + value = value[1:-1] # Strip < and > + term_type = "iri" + elif value.startswith(("http://", "https://", "urn:")): + term_type = "iri" + else: + term_type = "literal" + + if term_type == "iri": + # Strip angle brackets if present + if value.startswith("<") and value.endswith(">"): + value = value[1:-1] + return {"t": "i", "i": value} + elif term_type == "literal": + result = {"t": "l", "v": value} + if datatype: + result["dt"] = datatype + if language: + result["ln"] = language + return result + else: + raise ValueError(f"Unknown term type: {term_type}") + + class SocketClient: """ Synchronous WebSocket client for streaming operations. @@ -91,9 +148,19 @@ class SocketClient: service: str, flow: Optional[str], request: Dict[str, Any], - streaming: bool = False - ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: - """Synchronous wrapper around async WebSocket communication""" + streaming: bool = False, + streaming_raw: bool = False, + include_provenance: bool = False + ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: + """Synchronous wrapper around async WebSocket communication. + + Args: + service: Service name + flow: Flow ID (optional) + request: Request payload + streaming: Use parsed streaming (for agent/RAG chunk types) + streaming_raw: Use raw streaming (for data batches like triples) + """ # Create event loop if needed try: loop = asyncio.get_event_loop() @@ -105,12 +172,14 @@ class SocketClient: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - if streaming: - # For streaming, we need to return an iterator - # Create a generator that runs async code - return self._streaming_generator(service, flow, request, loop) + if streaming_raw: + # Raw streaming for data batches (triples, rows, etc.) + return self._streaming_generator_raw(service, flow, request, loop) + elif streaming: + # Parsed streaming for agent/RAG chunk types + return self._streaming_generator(service, flow, request, loop, include_provenance) else: - # For non-streaming, just run the async code and return result + # Non-streaming single response return loop.run_until_complete(self._send_request_async(service, flow, request)) def _streaming_generator( @@ -118,10 +187,11 @@ class SocketClient: service: str, flow: Optional[str], request: Dict[str, Any], - loop: asyncio.AbstractEventLoop + loop: asyncio.AbstractEventLoop, + include_provenance: bool = False ) -> Iterator[StreamingChunk]: - """Generator that yields streaming chunks""" - async_gen = self._send_request_async_streaming(service, flow, request) + """Generator that yields streaming chunks (for agent/RAG responses)""" + async_gen = self._send_request_async_streaming(service, flow, request, include_provenance) try: while True: @@ -137,6 +207,74 @@ class SocketClient: except: pass + def _streaming_generator_raw( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + loop: asyncio.AbstractEventLoop + ) -> Iterator[Dict[str, Any]]: + """Generator that yields raw response dicts (for data streaming like triples)""" + async_gen = self._send_request_async_streaming_raw(service, flow, request) + + try: + while True: + try: + data = loop.run_until_complete(async_gen.__anext__()) + yield data + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _send_request_async_streaming_raw( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any] + ) -> Iterator[Dict[str, Any]]: + """Async streaming that yields raw response dicts without parsing. + + Used for data streaming (triples, rows, etc.) where responses are + just batches of data, not agent/RAG chunk types. + """ + with self._lock: + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != request_id: + continue + + if "error" in response: + raise_from_error_dict(response["error"]) + + if "response" in response: + yield response["response"] + + if response.get("complete"): + break + async def _send_request_async( self, service: str, @@ -186,7 +324,8 @@ class SocketClient: self, service: str, flow: Optional[str], - request: Dict[str, Any] + request: Dict[str, Any], + include_provenance: bool = False ) -> Iterator[StreamingChunk]: """Async implementation of WebSocket request (streaming)""" # Generate unique request ID @@ -230,16 +369,41 @@ class SocketClient: raise_from_error_dict(resp["error"]) # Parse different chunk types - chunk = self._parse_chunk(resp) - yield chunk + chunk = self._parse_chunk(resp, include_provenance=include_provenance) + if chunk is not None: # Skip provenance messages unless include_provenance + yield chunk - # Check if this is the final chunk - if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + # Check if this is the final message + # end_of_session indicates entire session is complete (including provenance) + # end_of_dialog is for agent dialogs + # complete is from the gateway envelope + if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break - def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk: - """Parse response chunk into appropriate type""" + def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: + """Parse response chunk into appropriate type. Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") + message_type = resp.get("message_type") + + # Handle GraphRAG/DocRAG message format with message_type + if message_type == "explain": + if include_provenance: + # Return provenance event for explainability + return ProvenanceEvent( + explain_id=resp.get("explain_id", ""), + explain_graph=resp.get("explain_graph", "") + ) + # Provenance messages are not yielded to user - they're metadata + return None + + # Handle Agent message format with chunk_type="explain" + if chunk_type == "explain": + if include_provenance: + return ProvenanceEvent( + explain_id=resp.get("explain_id", ""), + explain_graph=resp.get("explain_graph", "") + ) + return None if chunk_type == "thought": return AgentThought( @@ -281,7 +445,7 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) else: - # RAG-style chunk (or generic chunk) + # RAG-style chunk (or generic chunk with message_type="chunk") # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( @@ -385,6 +549,95 @@ class SocketFlowInstance: # regardless of streaming flag, so always use the streaming code path return self.client._send_request_sync("agent", self.flow_id, request, streaming=True) + def agent_explain( + self, + question: str, + user: str, + collection: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any + ) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]: + """ + Execute an agent operation with explainability support. + + Streams both content chunks (AgentThought, AgentObservation, AgentAnswer) + and provenance events (ProvenanceEvent). Provenance events contain URIs + that can be fetched using ExplainabilityClient to get detailed information + about the agent's reasoning process. + + Agent trace consists of: + - Session: The initial question and session metadata + - Iterations: Each thought/action/observation cycle + - Conclusion: The final answer + + Args: + question: User question or instruction + user: User identifier + collection: Collection identifier for provenance storage + state: Optional state dictionary for stateful conversations + group: Optional group identifier for multi-user contexts + history: Optional conversation history as list of message dicts + **kwargs: Additional parameters passed to the agent service + + Yields: + Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent + from trustgraph.api import AgentThought, AgentObservation, AgentAnswer + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + provenance_ids = [] + for item in flow.agent_explain( + question="What is the capital of France?", + user="trustgraph", + collection="default" + ): + if isinstance(item, AgentThought): + print(f"[Thought] {item.content}") + elif isinstance(item, AgentObservation): + print(f"[Observation] {item.content}") + elif isinstance(item, AgentAnswer): + print(f"[Answer] {item.content}") + elif isinstance(item, ProvenanceEvent): + provenance_ids.append(item.explain_id) + + # Fetch session trace after completion + if provenance_ids: + trace = explain_client.fetch_agent_trace( + provenance_ids[0], # Session URI is first + graph="urn:graph:retrieval", + user="trustgraph", + collection="default" + ) + ``` + """ + request = { + "question": question, + "user": user, + "collection": collection, + "streaming": True # Always streaming for explain + } + if state is not None: + request["state"] = state + if group is not None: + request["group"] = group + if history is not None: + request["history"] = history + request.update(kwargs) + + # Use streaming with provenance enabled + return self.client._send_request_sync( + "agent", self.flow_id, request, + streaming=True, include_provenance=True + ) + def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: """ Execute text completion with optional streaming. @@ -504,6 +757,86 @@ class SocketFlowInstance: else: return result.get("response", "") + def graph_rag_explain( + self, + query: str, + user: str, + collection: str, + max_subgraph_size: int = 1000, + max_subgraph_count: int = 5, + max_entity_distance: int = 3, + **kwargs: Any + ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: + """ + Execute graph-based RAG query with explainability support. + + Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). + Provenance events contain URIs that can be fetched using ExplainabilityClient + to get detailed information about how the response was generated. + + Args: + query: Natural language query + user: User/keyspace identifier + collection: Collection identifier + max_subgraph_size: Maximum total triples in subgraph (default: 1000) + max_subgraph_count: Maximum number of subgraphs (default: 5) + max_entity_distance: Maximum traversal depth (default: 3) + **kwargs: Additional parameters passed to the service + + Yields: + Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + provenance_ids = [] + response_text = "" + + for item in flow.graph_rag_explain( + query="Tell me about Marie Curie", + user="trustgraph", + collection="scientists" + ): + if isinstance(item, RAGChunk): + response_text += item.content + print(item.content, end='', flush=True) + elif isinstance(item, ProvenanceEvent): + provenance_ids.append(item.provenance_id) + + # Fetch explainability details + for prov_id in provenance_ids: + entity = explain_client.fetch_entity( + prov_id, + graph="urn:graph:retrieval", + user="trustgraph", + collection="scientists" + ) + print(f"Entity: {entity}") + ``` + """ + request = { + "query": query, + "user": user, + "collection": collection, + "max-subgraph-size": max_subgraph_size, + "max-subgraph-count": max_subgraph_count, + "max-entity-distance": max_entity_distance, + "streaming": True, + "explainable": True, # Enable explainability mode + } + request.update(kwargs) + + # Use streaming with provenance events included + return self.client._send_request_sync( + "graph-rag", self.flow_id, request, + streaming=True, include_provenance=True + ) + def document_rag( self, query: str, @@ -562,6 +895,79 @@ class SocketFlowInstance: else: return result.get("response", "") + def document_rag_explain( + self, + query: str, + user: str, + collection: str, + doc_limit: int = 10, + **kwargs: Any + ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: + """ + Execute document-based RAG query with explainability support. + + Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). + Provenance events contain URIs that can be fetched using ExplainabilityClient + to get detailed information about how the response was generated. + + Document RAG trace consists of: + - Question: The user's query + - Exploration: Chunks retrieved from document store (chunk_count) + - Synthesis: The generated answer + + Args: + query: Natural language query + user: User/keyspace identifier + collection: Collection identifier + doc_limit: Maximum document chunks to retrieve (default: 10) + **kwargs: Additional parameters passed to the service + + Yields: + Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events + + Example: + ```python + from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent + + socket = api.socket() + flow = socket.flow("default") + explain_client = ExplainabilityClient(flow) + + for item in flow.document_rag_explain( + query="Summarize the key findings", + user="trustgraph", + collection="research-papers", + doc_limit=5 + ): + if isinstance(item, RAGChunk): + print(item.content, end='', flush=True) + elif isinstance(item, ProvenanceEvent): + # Fetch entity details + entity = explain_client.fetch_entity( + item.explain_id, + graph=item.explain_graph, + user="trustgraph", + collection="research-papers" + ) + print(f"Event: {entity}", file=sys.stderr) + ``` + """ + request = { + "query": query, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + "streaming": True, + "explainable": True, + } + request.update(kwargs) + + # Use streaming with provenance events included + return self.client._send_request_sync( + "document-rag", self.flow_id, request, + streaming=True, include_provenance=True + ) + def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: """Generator for RAG streaming (graph-rag and document-rag)""" for chunk in result: @@ -649,12 +1055,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -682,7 +1088,7 @@ class SocketFlowInstance: **kwargs: Additional parameters passed to the service Returns: - dict: Query results with similar document chunks + dict: Query results with chunk_ids of matching document chunks Example: ```python @@ -695,14 +1101,15 @@ class SocketFlowInstance: collection="research-papers", limit=5 ) + # results contains {"chunks": [{"chunk_id": "...", "score": 0.95}, ...]} ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -711,55 +1118,57 @@ class SocketFlowInstance: return self.client._send_request_sync("document-embeddings", self.flow_id, request, False) - def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]: + def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]: """ - Generate vector embeddings for text. + Generate vector embeddings for one or more texts. Args: - text: Input text to embed + texts: List of input texts to embed **kwargs: Additional parameters passed to the service Returns: - dict: Response containing vectors + dict: Response containing vectors (one set per input text) Example: ```python socket = api.socket() flow = socket.flow("default") - result = flow.embeddings("quantum computing") + result = flow.embeddings(["quantum computing"]) vectors = result.get("vectors", []) ``` """ - request = {"text": text} + request = {"texts": texts} request.update(kwargs) return self.client._send_request_sync("embeddings", self.flow_id, request, False) def triples_query( self, - s: Optional[str] = None, - p: Optional[str] = None, - o: Optional[str] = None, + s: Optional[Union[str, Dict[str, Any]]] = None, + p: Optional[Union[str, Dict[str, Any]]] = None, + o: Optional[Union[str, Dict[str, Any]]] = None, + g: Optional[str] = None, user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, **kwargs: Any - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: """ Query knowledge graph triples using pattern matching. Args: - s: Subject URI (optional, use None for wildcard) - p: Predicate URI (optional, use None for wildcard) - o: Object URI or Literal (optional, use None for wildcard) + s: Subject filter - URI string, Term dict, or None for wildcard + p: Predicate filter - URI string, Term dict, or None for wildcard + o: Object filter - URI/literal string, Term dict, or None for wildcard + g: Named graph filter - URI string or None for all graphs user: User/keyspace identifier (optional) collection: Collection identifier (optional) limit: Maximum results to return (default: 100) **kwargs: Additional parameters passed to the service Returns: - dict: Query results with matching triples + List[Dict]: List of matching triples in wire format Example: ```python @@ -767,27 +1176,125 @@ class SocketFlowInstance: flow = socket.flow("default") # Find all triples about a specific subject - result = flow.triples_query( + triples = flow.triples_query( s="http://example.org/person/marie-curie", user="trustgraph", collection="scientists" ) + + # Query with named graph filter + triples = flow.triples_query( + s="urn:trustgraph:session:abc123", + g="urn:graph:retrieval", + user="trustgraph", + collection="default" + ) ``` """ request = {"limit": limit} - if s is not None: - request["s"] = str(s) - if p is not None: - request["p"] = str(p) - if o is not None: - request["o"] = str(o) + + # Build Term dicts for s/p/o (auto-converts strings) + s_term = build_term(s) + p_term = build_term(p) + o_term = build_term(o) + + if s_term is not None: + request["s"] = s_term + if p_term is not None: + request["p"] = p_term + if o_term is not None: + request["o"] = o_term + if g is not None: + request["g"] = g if user is not None: request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) - return self.client._send_request_sync("triples", self.flow_id, request, False) + result = self.client._send_request_sync("triples", self.flow_id, request, False) + # Return the triples list from the response + if isinstance(result, dict) and "response" in result: + return result["response"] + return result + + def triples_query_stream( + self, + s: Optional[Union[str, Dict[str, Any]]] = None, + p: Optional[Union[str, Dict[str, Any]]] = None, + o: Optional[Union[str, Dict[str, Any]]] = None, + g: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + limit: int = 100, + batch_size: int = 20, + **kwargs: Any + ) -> Iterator[List[Dict[str, Any]]]: + """ + Query knowledge graph triples with streaming batches. + + Yields batches of triples as they arrive, reducing time-to-first-result + and memory overhead for large result sets. + + Args: + s: Subject filter - URI string, Term dict, or None for wildcard + p: Predicate filter - URI string, Term dict, or None for wildcard + o: Object filter - URI/literal string, Term dict, or None for wildcard + g: Named graph filter - URI string or None for all graphs + user: User/keyspace identifier (optional) + collection: Collection identifier (optional) + limit: Maximum results to return (default: 100) + batch_size: Triples per batch (default: 20) + **kwargs: Additional parameters passed to the service + + Yields: + List[Dict]: Batches of triples in wire format + + Example: + ```python + socket = api.socket() + flow = socket.flow("default") + + for batch in flow.triples_query_stream( + user="trustgraph", + collection="default" + ): + for triple in batch: + print(triple["s"], triple["p"], triple["o"]) + ``` + """ + request = { + "limit": limit, + "streaming": True, + "batch-size": batch_size, + } + + # Build Term dicts for s/p/o (auto-converts strings) + s_term = build_term(s) + p_term = build_term(p) + o_term = build_term(o) + + if s_term is not None: + request["s"] = s_term + if p_term is not None: + request["p"] = p_term + if o_term is not None: + request["o"] = o_term + if g is not None: + request["g"] = g + if user is not None: + request["user"] = user + if collection is not None: + request["collection"] = collection + request.update(kwargs) + + # Use raw streaming - yields response dicts directly without parsing + for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True): + # Response is {"response": [...triples...]} from translator + if isinstance(response, dict) and "response" in response: + yield response["response"] + else: + yield response def rows_query( self, @@ -935,12 +1442,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors - emb_result = self.embeddings(text=text) - vectors = emb_result.get("vectors", []) + # First convert text to embedding vector + emb_result = self.embeddings(texts=[text]) + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 3b4e476e..d39310f2 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -64,6 +64,8 @@ class DocumentMetadata: metadata: List of RDF triples providing structured metadata user: User/owner identifier tags: List of tags for categorization + parent_id: Parent document ID for child documents (empty for top-level docs) + document_type: "source" for uploaded documents, "extracted" for derived content """ id : str time : datetime.datetime @@ -73,6 +75,8 @@ class DocumentMetadata: metadata : List[Triple] user : str tags : List[str] + parent_id : str = "" + document_type : str = "source" @dataclasses.dataclass class ProcessingMetadata: @@ -198,3 +202,31 @@ class RAGChunk(StreamingChunk): chunk_type: str = "rag" end_of_stream: bool = False error: Optional[Dict[str, str]] = None + +@dataclasses.dataclass +class ProvenanceEvent: + """ + Provenance event for explainability. + + Emitted during GraphRAG queries when explainable mode is enabled. + Each event represents a provenance node created during query processing. + + Attributes: + explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123) + explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval) + event_type: Type of provenance event (question, exploration, focus, synthesis) + """ + explain_id: str + explain_graph: str = "" + event_type: str = "" # Derived from explain_id + + def __post_init__(self): + # Extract event type from explain_id + if "question" in self.explain_id: + self.event_type = "question" + elif "exploration" in self.explain_id: + self.event_type = "exploration" + elif "focus" in self.explain_id: + self.event_type = "focus" + elif "synthesis" in self.explain_id: + self.event_type = "synthesis" diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 557109a2..f9f38060 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -32,6 +32,8 @@ from . agent_service import AgentService from . graph_rag_client import GraphRagClientSpec from . tool_service import ToolService from . tool_client import ToolClientSpec +from . dynamic_tool_service import DynamicToolService +from . tool_service_client import ToolServiceClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py index 03939dc3..f48fd024 100644 --- a/trustgraph-base/trustgraph/base/agent_client.py +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -4,10 +4,57 @@ from .. schema import AgentRequest, AgentResponse from .. knowledge import Uri, Literal class AgentClient(RequestResponse): - async def invoke(self, recipient, question, plan=None, state=None, - history=[], timeout=300): - - resp = await self.request( + async def invoke(self, question, plan=None, state=None, + history=[], think=None, observe=None, answer_callback=None, + timeout=300): + """ + Invoke the agent with optional streaming callbacks. + + Args: + question: The question to ask + plan: Optional plan context + state: Optional state context + history: Conversation history + think: Optional async callback(content, end_of_message) for thought chunks + observe: Optional async callback(content, end_of_message) for observation chunks + answer_callback: Optional async callback(content, end_of_message) for answer chunks + timeout: Request timeout in seconds + + Returns: + Complete answer text (accumulated from all answer chunks) + """ + accumulated_answer = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + # Handle thought chunks + if resp.chunk_type == 'thought': + if think: + await think(resp.content, resp.end_of_message) + return False # Continue receiving + + # Handle observation chunks + if resp.chunk_type == 'observation': + if observe: + await observe(resp.content, resp.end_of_message) + return False # Continue receiving + + # Handle answer chunks + if resp.chunk_type == 'answer': + if resp.content: + accumulated_answer.append(resp.content) + if answer_callback: + await answer_callback(resp.content, resp.end_of_message) + + # Complete when dialog ends + if resp.end_of_dialog: + return True + + return False # Continue receiving + + await self.request( AgentRequest( question = question, plan = plan, @@ -18,10 +65,7 @@ class AgentClient(RequestResponse): timeout=timeout, ) - if resp.error: - raise RuntimeError(resp.error.message) - - return resp.answer + return "".join(accumulated_answer) class AgentClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 2e18a933..6ba73e08 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -1,20 +1,37 @@ """ Base chunking service that provides parameter specification functionality -for chunk-size and chunk-overlap parameters +for chunk-size and chunk-overlap parameters, and librarian client for +fetching large document content. """ +import asyncio +import base64 import logging +import uuid + from .flow_processor import FlowProcessor from .parameter_spec import ParameterSpec +from .consumer import Consumer +from .producer import Producer +from .metrics import ConsumerMetrics, ProducerMetrics + +from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ..schema import librarian_request_queue, librarian_response_queue # Module logger logger = logging.getLogger(__name__) +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue + + class ChunkingService(FlowProcessor): """Base service for chunking processors with parameter specification support""" def __init__(self, **params): + id = params.get("id", "chunker") + # Call parent constructor super(ChunkingService, self).__init__(**params) @@ -27,8 +44,183 @@ class ChunkingService(FlowProcessor): ParameterSpec(name="chunk-overlap") ) + # Librarian client for fetching document content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + logger.debug("ChunkingService initialized with parameter specifications") + async def start(self): + await super(ChunkingService, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_document_content(self, document_id, user, timeout=120): + """ + Fetch document content from librarian via Pulsar. + """ + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=document_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + return response.content + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching document {document_id}") + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="chunk", title=None, timeout=120): + """ + Save a child document (chunk) to the librarian. + + Args: + doc_id: ID for the new child document + parent_id: ID of the parent document + user: User ID + content: Document content (bytes or str) + document_type: Type of document ("chunk", etc.) + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + if isinstance(content, str): + content = content.encode("utf-8") + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + request = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving chunk: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving chunk {doc_id}") + + async def get_document_text(self, doc): + """ + Get text content from a TextDocument, fetching from librarian if needed. + + Args: + doc: TextDocument with either inline text or document_id + + Returns: + str: The document text content + """ + if doc.document_id and not doc.text: + logger.info(f"Fetching document {doc.document_id} from librarian...") + content = await self.fetch_document_content( + document_id=doc.document_id, + user=doc.metadata.user, + ) + # Content is base64 encoded + if isinstance(content, str): + content = content.encode('utf-8') + text = base64.b64decode(content).decode("utf-8") + logger.info(f"Fetched {len(text)} characters from librarian") + return text + else: + return doc.text.decode("utf-8") + async def chunk_document(self, msg, consumer, flow, default_chunk_size, default_chunk_overlap): """ Extract chunk parameters from flow and return effective values @@ -59,4 +251,16 @@ class ChunkingService(FlowProcessor): @staticmethod def add_args(parser): """Add chunking service arguments to parser""" - FlowProcessor.add_args(parser) \ No newline at end of file + FlowProcessor.add_args(parser) + + parser.add_argument( + '--librarian-request-queue', + default=default_librarian_request_queue, + help=f'Librarian request queue (default: {default_librarian_request_queue})', + ) + + parser.add_argument( + '--librarian-response-queue', + default=default_librarian_response_queue, + help=f'Librarian response queue (default: {default_librarian_response_queue})', + ) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index e76a6da6..dd985eab 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal logger = logging.getLogger(__name__) class DocumentEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( DocumentEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -27,6 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) + # Return ChunkMatch objects with chunk_id and score return resp.chunks class DocumentEmbeddingsClientSpec(RequestResponseSpec): diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index f04f2c60..c7aef104 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "doc-embeddings-query" +default_concurrency = 10 class DocumentEmbeddingsQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(DocumentEmbeddingsQueryService, self).__init__( **params | { "id": id } @@ -32,7 +34,8 @@ class DocumentEmbeddingsQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = DocumentEmbeddingsRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -73,7 +76,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): type = "document-embeddings-query-error", message = str(e), ), - chunks=None, + chunks=[], ) await flow("response").send(r, properties={"id": id}) @@ -83,6 +86,13 @@ class DocumentEmbeddingsQueryService(FlowProcessor): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-base/trustgraph/base/dynamic_tool_service.py b/trustgraph-base/trustgraph/base/dynamic_tool_service.py new file mode 100644 index 00000000..f3fda6dd --- /dev/null +++ b/trustgraph-base/trustgraph/base/dynamic_tool_service.py @@ -0,0 +1,184 @@ + +""" +Base class for dynamically pluggable tool services. + +Tool services are Pulsar services that can be invoked as agent tools. +They receive a ToolServiceRequest with user, config, and arguments, +and return a ToolServiceResponse with the result. + +Uses direct Pulsar topics (no flow configuration required): +- Request: non-persistent://tg/request/{topic} +- Response: non-persistent://tg/response/{topic} +""" + +import json +import logging +import asyncio +import argparse +from prometheus_client import Counter + +from .. schema import ToolServiceRequest, ToolServiceResponse, Error +from .. exceptions import TooManyRequests +from . async_processor import AsyncProcessor +from . consumer import Consumer +from . producer import Producer +from . metrics import ConsumerMetrics, ProducerMetrics + +# Module logger +logger = logging.getLogger(__name__) + +default_concurrency = 1 +default_topic = "tool" + + +class DynamicToolService(AsyncProcessor): + """ + Base class for implementing dynamic tool services. + + Subclasses should override the `invoke` method to implement + the tool's logic. + + The invoke method receives: + - user: The user context for multi-tenancy + - config: Dict of config values from the tool descriptor + - arguments: Dict of arguments from the LLM + + And should return a string response (the observation). + """ + + def __init__(self, **params): + + super(DynamicToolService, self).__init__(**params) + + self.id = params.get("id") + topic = params.get("topic", default_topic) + + # Build direct Pulsar topic paths + request_topic = f"non-persistent://tg/request/{topic}" + response_topic = f"non-persistent://tg/response/{topic}" + + logger.info(f"Tool service topics: request={request_topic}, response={response_topic}") + + # Create consumer for requests + consumer_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="request" + ) + + self.consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + subscriber=f"{self.id}-request", + flow=None, + topic=request_topic, + schema=ToolServiceRequest, + handler=self.on_request, + metrics=consumer_metrics, + ) + + # Create producer for responses + producer_metrics = ProducerMetrics( + processor=self.id, flow=None, name="response" + ) + + self.producer = Producer( + backend=self.pubsub, + topic=response_topic, + schema=ToolServiceResponse, + metrics=producer_metrics, + ) + + if not hasattr(__class__, "tool_service_metric"): + __class__.tool_service_metric = Counter( + 'dynamic_tool_service_invocation_count', + 'Dynamic tool service invocation count', + ["id"], + ) + + async def start(self): + await super(DynamicToolService, self).start() + await self.producer.start() + await self.consumer.start() + logger.info(f"Tool service {self.id} started") + + async def on_request(self, msg, consumer, flow): + + id = None + + try: + + request = msg.value() + + # Sender-produced ID for correlation + id = msg.properties().get("id", "unknown") + + # Parse the request + user = request.user or "trustgraph" + config = json.loads(request.config) if request.config else {} + arguments = json.loads(request.arguments) if request.arguments else {} + + logger.debug(f"Tool service request: user={user}, config={config}, arguments={arguments}") + + # Invoke the tool implementation + response = await self.invoke(user, config, arguments) + + # Send success response + await self.producer.send( + ToolServiceResponse( + error=None, + response=response if isinstance(response, str) else json.dumps(response), + end_of_stream=True, + ), + properties={"id": id} + ) + + __class__.tool_service_metric.labels( + id=self.id, + ).inc() + + except TooManyRequests as e: + raise e + + except Exception as e: + + logger.error(f"Exception in dynamic tool service: {e}", exc_info=True) + + logger.info("Sending error response...") + + await self.producer.send( + ToolServiceResponse( + error=Error( + type="tool-service-error", + message=str(e), + ), + response="", + end_of_stream=True, + ), + properties={"id": id if id else "unknown"} + ) + + async def invoke(self, user, config, arguments): + """ + Invoke the tool service. + + Override this method in subclasses to implement the tool's logic. + + Args: + user: The user context for multi-tenancy + config: Dict of config values from the tool descriptor + arguments: Dict of arguments from the LLM + + Returns: + A string response (the observation) or a dict/list that will be JSON-encoded + """ + raise NotImplementedError("Subclasses must implement invoke()") + + @staticmethod + def add_args(parser): + + AsyncProcessor.add_args(parser) + + parser.add_argument( + '-t', '--topic', + default=default_topic, + help=f'Topic name for request/response (default: {default_topic})' + ) diff --git a/trustgraph-base/trustgraph/base/embeddings_client.py b/trustgraph-base/trustgraph/base/embeddings_client.py index ceb08eb2..faaa192d 100644 --- a/trustgraph-base/trustgraph/base/embeddings_client.py +++ b/trustgraph-base/trustgraph/base/embeddings_client.py @@ -3,11 +3,11 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import EmbeddingsRequest, EmbeddingsResponse class EmbeddingsClient(RequestResponse): - async def embed(self, text, timeout=30): + async def embed(self, texts, timeout=300): resp = await self.request( EmbeddingsRequest( - text = text + texts = texts ), timeout=timeout ) diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py index a1442d41..7ae63521 100644 --- a/trustgraph-base/trustgraph/base/embeddings_service.py +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -65,7 +65,7 @@ class EmbeddingsService(FlowProcessor): # Pass model from request if specified (non-empty), otherwise use default model = flow("model") - vectors = await self.on_embeddings(request.text, model=model) + vectors = await self.on_embeddings(request.texts, model=model) await flow("response").send( EmbeddingsResponse( @@ -94,7 +94,7 @@ class EmbeddingsService(FlowProcessor): type = "embeddings-error", message = str(e), ), - vectors=None, + vectors=[], ), properties={"id": id} ) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index 07eb2bc7..fec82378 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -19,12 +19,12 @@ def to_value(x): return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( GraphEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return [ - to_value(v) - for v in resp.entities - ] + # Return EntityMatch objects with entity and score + return resp.entities class GraphEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index d429b3a5..cbbef4f2 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "graph-embeddings-query" +default_concurrency = 10 class GraphEmbeddingsQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(GraphEmbeddingsQueryService, self).__init__( **params | { "id": id } @@ -32,7 +34,8 @@ class GraphEmbeddingsQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = GraphEmbeddingsRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -83,6 +86,13 @@ class GraphEmbeddingsQueryService(FlowProcessor): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index c4f3f7ab..66dbad1e 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -4,20 +4,58 @@ from .. schema import GraphRagQuery, GraphRagResponse class GraphRagClient(RequestResponse): async def rag(self, query, user="trustgraph", collection="default", + chunk_callback=None, explain_callback=None, timeout=600): - resp = await self.request( + """ + Execute a graph RAG query with optional streaming callbacks. + + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional async callback(text, end_of_stream) for text chunks + explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + # Handle explain notifications + if resp.message_type == 'explain': + if explain_callback and resp.explain_id: + await explain_callback(resp.explain_id, resp.explain_graph) + return False # Continue receiving + + # Handle text chunks + if resp.message_type == 'chunk': + if resp.response: + accumulated_response.append(resp.response) + if chunk_callback: + await chunk_callback(resp.response, resp.end_of_stream) + + # Complete when session ends + if resp.end_of_session: + return True + + return False # Continue receiving + + await self.request( GraphRagQuery( query = query, user = user, collection = collection, ), - timeout=timeout + timeout=timeout, + recipient=recipient, ) - if resp.error: - raise RuntimeError(resp.error.message) - - return resp.response + return "".join(accumulated_response) class GraphRagClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index c6248622..a3c3debd 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -12,7 +12,7 @@ import logging import base64 import types from dataclasses import asdict, is_dataclass -from typing import Any +from typing import Any, get_type_hints from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message @@ -58,6 +58,7 @@ def dict_to_dataclass(data: dict, cls: type) -> Any: Convert a dictionary back to a dataclass instance. Handles nested dataclasses and missing fields. + Uses get_type_hints() to resolve forward references (string annotations). """ if data is None: return None @@ -65,8 +66,13 @@ def dict_to_dataclass(data: dict, cls: type) -> Any: if not is_dataclass(cls): return data - # Get field types from the dataclass - field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + # Get field types from the dataclass, resolving forward references + # get_type_hints() evaluates string annotations like "Triple | None" + try: + field_types = get_type_hints(cls) + except Exception: + # Fallback if get_type_hints fails (shouldn't happen normally) + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} kwargs = {} for key, value in data.items(): diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py index 0141da31..811adf40 100644 --- a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse class RowEmbeddingsQueryClient(RequestResponse): async def row_embeddings_query( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=600 ): request = RowEmbeddingsRequest( - vectors=vectors, + vector=vector, schema_name=schema_name, user=user, collection=collection, diff --git a/trustgraph-base/trustgraph/base/tool_service_client.py b/trustgraph-base/trustgraph/base/tool_service_client.py new file mode 100644 index 00000000..81930ba0 --- /dev/null +++ b/trustgraph-base/trustgraph/base/tool_service_client.py @@ -0,0 +1,90 @@ + +import json +import logging + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import ToolServiceRequest, ToolServiceResponse + +logger = logging.getLogger(__name__) + + +class ToolServiceClient(RequestResponse): + """Client for invoking dynamically configured tool services.""" + + async def call(self, user, config, arguments, timeout=600): + """ + Call a tool service. + + Args: + user: User context for multi-tenancy + config: Dict of config values (e.g., {"collection": "customers"}) + arguments: Dict of arguments from LLM + timeout: Request timeout in seconds + + Returns: + Response string from the tool service + """ + resp = await self.request( + ToolServiceRequest( + user=user, + config=json.dumps(config) if config else "{}", + arguments=json.dumps(arguments) if arguments else "{}", + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.response + + async def call_streaming(self, user, config, arguments, callback, timeout=600): + """ + Call a tool service with streaming response. + + Args: + user: User context for multi-tenancy + config: Dict of config values + arguments: Dict of arguments from LLM + callback: Async function called with each response chunk + timeout: Request timeout in seconds + + Returns: + Final response string + """ + result = [] + + async def handle_response(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + if resp.response: + result.append(resp.response) + await callback(resp.response) + + return resp.end_of_stream + + await self.request( + ToolServiceRequest( + user=user, + config=json.dumps(config) if config else "{}", + arguments=json.dumps(arguments) if arguments else "{}", + ), + timeout=timeout, + recipient=handle_response + ) + + return "".join(result) + + +class ToolServiceClientSpec(RequestResponseSpec): + """Specification for a tool service client.""" + + def __init__(self, request_name, response_name): + super(ToolServiceClientSpec, self).__init__( + request_name=request_name, + request_schema=ToolServiceRequest, + response_name=response_name, + response_schema=ToolServiceResponse, + impl=ToolServiceClient, + ) diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 7258d3ca..e661f46d 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,6 +1,6 @@ from . request_response_spec import RequestResponse, RequestResponseSpec -from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL +from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL, TRIPLE from .. knowledge import Uri, Literal @@ -22,18 +22,28 @@ def to_value(x): def from_value(x): - """Convert Uri or Literal to schema Term.""" + """Convert Uri, Literal, string, or Term to schema Term.""" if x is None: return None + if isinstance(x, Term): + return x if isinstance(x, Uri): return Term(type=IRI, iri=str(x)) + elif isinstance(x, Literal): + return Term(type=LITERAL, value=str(x)) + elif isinstance(x, str): + # Detect IRIs by common prefixes + if x.startswith("http://") or x.startswith("https://") or x.startswith("urn:"): + return Term(type=IRI, iri=x) + else: + return Term(type=LITERAL, value=x) else: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): async def query(self, s=None, p=None, o=None, limit=20, user="trustgraph", collection="default", - timeout=30): + timeout=30, g=None): resp = await self.request( TriplesQueryRequest( @@ -43,6 +53,7 @@ class TriplesClient(RequestResponse): limit = limit, user = user, collection = collection, + g = g, ), timeout=timeout ) @@ -57,6 +68,64 @@ class TriplesClient(RequestResponse): return triples + async def query_stream(self, s=None, p=None, o=None, limit=20, + user="trustgraph", collection="default", + batch_size=20, timeout=30, + batch_callback=None, g=None): + """ + Streaming triple query - calls callback for each batch as it arrives. + + Args: + s, p, o: Triple pattern (None for wildcard) + limit: Maximum total triples to return + user: User/keyspace + collection: Collection name + batch_size: Triples per batch + timeout: Request timeout in seconds + batch_callback: Async callback(batch, is_final) called for each batch + g: Graph filter. ""=default graph only, None=all graphs, + or a specific graph IRI. + + Returns: + List[Triple]: All triples (flattened) if no callback provided + """ + all_triples = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + + if batch_callback: + await batch_callback(batch, resp.is_final) + else: + all_triples.extend(batch) + + return resp.is_final + + await self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + user=user, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + ) + + if not batch_callback: + return all_triples + class TriplesClientSpec(RequestResponseSpec): def __init__( self, request_name, response_name, diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index b156ef55..09f36652 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "triples-query" +default_concurrency = 10 class TriplesQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(TriplesQueryService, self).__init__(**params | { "id": id }) @@ -30,7 +32,8 @@ class TriplesQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = TriplesQueryRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -52,13 +55,23 @@ class TriplesQueryService(FlowProcessor): logger.debug(f"Handling triples query request {id}...") - triples = await self.query_triples(request) - - logger.debug("Sending triples query response...") - r = TriplesQueryResponse(triples=triples, error=None) - await flow("response").send(r, properties={"id": id}) - - logger.debug("Triples query request completed") + if request.streaming: + # Streaming mode: send batches + async for batch, is_final in self.query_triples_stream(request): + r = TriplesQueryResponse( + triples=batch, + error=None, + is_final=is_final, + ) + await flow("response").send(r, properties={"id": id}) + logger.debug("Triples query streaming completed") + else: + # Non-streaming mode: single response + triples = await self.query_triples(request) + logger.debug("Sending triples query response...") + r = TriplesQueryResponse(triples=triples, error=None) + await flow("response").send(r, properties={"id": id}) + logger.debug("Triples query request completed") except Exception as e: @@ -76,11 +89,36 @@ class TriplesQueryService(FlowProcessor): await flow("response").send(r, properties={"id": id}) + async def query_triples_stream(self, request): + """ + Streaming query - yields (batch, is_final) tuples. + Default implementation batches results from query_triples. + Override for true streaming from backend. + """ + triples = await self.query_triples(request) + batch_size = request.batch_size if request.batch_size > 0 else 20 + + for i in range(0, len(triples), batch_size): + batch = triples[i:i + batch_size] + is_final = (i + batch_size >= len(triples)) + yield batch, is_final + + # Handle empty result + if len(triples) == 0: + yield [], True + @staticmethod def add_args(parser): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index b31b4e36..17ff5a09 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -42,25 +42,59 @@ class AgentClient(BaseClient): question, think=None, observe=None, + answer_callback=None, + error_callback=None, timeout=300 ): + """ + Request an agent query with optional streaming callbacks. + + Args: + question: The question to ask + think: Optional callback(content, end_of_message) for thought chunks + observe: Optional callback(content, end_of_message) for observation chunks + answer_callback: Optional callback(content, end_of_message) for answer chunks + error_callback: Optional callback(content) for error messages + timeout: Request timeout in seconds + + Returns: + Complete answer text (accumulated from all answer chunks) + """ + accumulated_answer = [] def inspect(x): + # Handle errors + if x.chunk_type == 'error' or x.error: + if error_callback: + error_callback(x.content or (x.error.message if x.error else "")) + # Continue to check end_of_dialog - if x.thought and think: - think(x.thought) - return + # Handle thought chunks + elif x.chunk_type == 'thought': + if think: + think(x.content, x.end_of_message) - if x.observation and observe: - observe(x.observation) - return + # Handle observation chunks + elif x.chunk_type == 'observation': + if observe: + observe(x.content, x.end_of_message) - if x.answer: + # Handle answer chunks + elif x.chunk_type == 'answer': + if x.content: + accumulated_answer.append(x.content) + if answer_callback: + answer_callback(x.content, x.end_of_message) + + # Complete when dialog ends + if x.end_of_dialog: return True - return False + return False # Continue receiving - return self.call( + self.call( question=question, inspect=inspect, timeout=timeout - ).answer + ) + + return "".join(accumulated_answer) diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index 124cf3c8..1ab47aab 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).chunks diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 6cbafa9b..946b1a6c 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -40,9 +40,47 @@ class DocumentRagClient(BaseClient): output_schema=DocumentRagResponse, ) - def request(self, query, timeout=300): + def request(self, query, user="trustgraph", collection="default", + chunk_callback=None, explain_callback=None, timeout=300): + """ + Request a document RAG query with optional streaming callbacks. - return self.call( - query=query, timeout=timeout - ).response + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional callback(text, end_of_stream) for text chunks + explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + def inspect(x): + # Handle explain notifications (response is None/empty, explain_id present) + if x.explain_id and not x.response: + if explain_callback: + explain_callback(x.explain_id, x.explain_graph) + return False # Continue receiving + + # Handle text chunks + if x.response: + accumulated_response.append(x.response) + if chunk_callback: + chunk_callback(x.response, x.end_of_stream) + + # Complete when stream ends + if x.end_of_stream: + return True + + return False # Continue receiving + + self.call( + query=query, user=user, collection=collection, + inspect=inspect, timeout=timeout + ) + + return "".join(accumulated_response) diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 1a7a9512..f85c91ee 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 77102e36..42ffce0c 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -42,10 +42,50 @@ class GraphRagClient(BaseClient): def request( self, query, user="trustgraph", collection="default", + chunk_callback=None, + explain_callback=None, timeout=500 ): + """ + Request a graph RAG query with optional streaming callbacks. - return self.call( - user=user, collection=collection, query=query, timeout=timeout - ).response + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional callback(text, end_of_stream) for text chunks + explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + def inspect(x): + # Handle explain notifications + if x.message_type == 'explain': + if explain_callback and x.explain_id: + explain_callback(x.explain_id, x.explain_graph) + return False # Continue receiving + + # Handle text chunks + if x.message_type == 'chunk': + if x.response: + accumulated_response.append(x.response) + if chunk_callback: + chunk_callback(x.response, x.end_of_stream) + + # Complete when session ends + if x.end_of_session: + return True + + return False # Continue receiving + + self.call( + user=user, collection=collection, query=query, + inspect=inspect, timeout=timeout + ) + + return "".join(accumulated_response) diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 4f911e3c..19d4b338 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient): ) def request( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=300 ): kwargs = dict( user=user, collection=collection, - vectors=vectors, schema_name=schema_name, + vector=vector, schema_name=schema_name, limit=limit, timeout=timeout ) if index_name: diff --git a/trustgraph-base/trustgraph/knowledge/defs.py b/trustgraph-base/trustgraph/knowledge/defs.py index d6290930..4c7eb41a 100644 --- a/trustgraph-base/trustgraph/knowledge/defs.py +++ b/trustgraph-base/trustgraph/knowledge/defs.py @@ -26,8 +26,40 @@ KEYWORD = 'https://schema.org/keywords' class Uri(str): def is_uri(self): return True def is_literal(self): return False + def is_triple(self): return False class Literal(str): def is_uri(self): return False def is_literal(self): return True + def is_triple(self): return False + +class QuotedTriple: + """ + RDF-star quoted triple (reification). + + Represents a triple that can be used as the object of another triple, + enabling statements about statements. + + Example: + # subgraph:123 tg:contains <<:Hope skos:definition "A feeling...">> + qt = QuotedTriple( + s=Uri("https://example.org/Hope"), + p=Uri("http://www.w3.org/2004/02/skos/core#definition"), + o=Literal("A feeling of expectation") + ) + """ + def __init__(self, s, p, o): + self.s = s # Uri, Literal, or QuotedTriple + self.p = p # Uri + self.o = o # Uri, Literal, or QuotedTriple + + def is_uri(self): return False + def is_literal(self): return False + def is_triple(self): return True + + def __repr__(self): + return f"<<{self.s} {self.p} {self.o}>>" + + def __str__(self): + return f"<<{self.s} {self.p} {self.o}>>" diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 4289df0a..378bdb41 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -13,7 +13,9 @@ class AgentRequestTranslator(MessageTranslator): group=data.get("group", None), history=data.get("history", []), user=data.get("user", "trustgraph"), - streaming=data.get("streaming", False) + collection=data.get("collection", "default"), + streaming=data.get("streaming", False), + session_id=data.get("session_id", ""), ) def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: @@ -23,7 +25,9 @@ class AgentRequestTranslator(MessageTranslator): "group": obj.group, "history": obj.history, "user": obj.user, - "streaming": getattr(obj, "streaming", False) + "collection": getattr(obj, "collection", "default"), + "streaming": getattr(obj, "streaming", False), + "session_id": getattr(obj, "session_id", ""), } @@ -55,6 +59,15 @@ class AgentResponseTranslator(MessageTranslator): result["end_of_message"] = getattr(obj, "end_of_message", False) result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) + # Include explainability fields if present + explain_id = getattr(obj, "explain_id", None) + if explain_id: + result["explain_id"] = explain_id + + explain_graph = getattr(obj, "explain_graph", None) + if explain_graph is not None: + result["explain_graph"] = explain_graph + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "code": obj.error.code} diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 3dfef718..7c2a013f 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -2,190 +2,171 @@ import base64 from typing import Dict, Any from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEmbeddings from .base import SendTranslator -from .metadata import DocumentMetadataTranslator -from .primitives import SubgraphTranslator class DocumentTranslator(SendTranslator): """Translator for Document schema objects (PDF docs etc.)""" - - def __init__(self): - self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> Document: - metadata = data.get("metadata", []) - # Handle base64 content validation doc = base64.b64decode(data["data"]) - + from ...schema import Metadata return Document( metadata=Metadata( id=data.get("id"), - metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + root=data.get("root", ""), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), data=base64.b64encode(doc).decode("utf-8") ) - + def from_pulsar(self, obj: Document) -> Dict[str, Any]: result = { "data": obj.data } - + if obj.metadata: metadata_dict = {} if obj.metadata.id: metadata_dict["id"] = obj.metadata.id + if obj.metadata.root: + metadata_dict["root"] = obj.metadata.root if obj.metadata.user: metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection - if obj.metadata.metadata: - metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) - + result["metadata"] = metadata_dict - + return result class TextDocumentTranslator(SendTranslator): """Translator for TextDocument schema objects""" - - def __init__(self): - self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> TextDocument: - metadata = data.get("metadata", []) charset = data.get("charset", "utf-8") - + # Text is base64 encoded in input text = base64.b64decode(data["text"]).decode(charset) - + from ...schema import Metadata return TextDocument( metadata=Metadata( id=data.get("id"), - metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + root=data.get("root", ""), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), text=text.encode("utf-8") ) - + def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]: result = { "text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text } - + if obj.metadata: metadata_dict = {} if obj.metadata.id: metadata_dict["id"] = obj.metadata.id + if obj.metadata.root: + metadata_dict["root"] = obj.metadata.root if obj.metadata.user: metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection - if obj.metadata.metadata: - metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) - + result["metadata"] = metadata_dict - + return result class ChunkTranslator(SendTranslator): """Translator for Chunk schema objects""" - - def __init__(self): - self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> Chunk: - metadata = data.get("metadata", []) - from ...schema import Metadata return Chunk( metadata=Metadata( id=data.get("id"), - metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + root=data.get("root", ""), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] ) - + def from_pulsar(self, obj: Chunk) -> Dict[str, Any]: result = { "chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk } - + if obj.metadata: metadata_dict = {} if obj.metadata.id: metadata_dict["id"] = obj.metadata.id + if obj.metadata.root: + metadata_dict["root"] = obj.metadata.root if obj.metadata.user: metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection - if obj.metadata.metadata: - metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) - + result["metadata"] = metadata_dict - + return result class DocumentEmbeddingsTranslator(SendTranslator): """Translator for DocumentEmbeddings schema objects""" - - def __init__(self): - self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings: metadata = data.get("metadata", {}) - + chunks = [ ChunkEmbeddings( - chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"], + chunk_id=chunk["chunk_id"], vectors=chunk["vectors"] ) for chunk in data.get("chunks", []) ] - + from ...schema import Metadata return DocumentEmbeddings( metadata=Metadata( id=metadata.get("id"), - metadata=self.subgraph_translator.to_pulsar(metadata.get("metadata", [])), + root=metadata.get("root", ""), user=metadata.get("user", "trustgraph"), collection=metadata.get("collection", "default"), ), chunks=chunks ) - + def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]: result = { "chunks": [ { - "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk, - "vectors": chunk.vectors + "chunk_id": chunk.chunk_id, + "vector": chunk.vector } for chunk in obj.chunks ] } - + if obj.metadata: metadata_dict = {} if obj.metadata.id: metadata_dict["id"] = obj.metadata.id + if obj.metadata.root: + metadata_dict["root"] = obj.metadata.root if obj.metadata.user: metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection - if obj.metadata.metadata: - metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) - + result["metadata"] = metadata_dict - + return result \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings.py b/trustgraph-base/trustgraph/messaging/translators/embeddings.py index 7e6eff83..454ce733 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings.py @@ -5,15 +5,15 @@ from .base import MessageTranslator class EmbeddingsRequestTranslator(MessageTranslator): """Translator for EmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest: return EmbeddingsRequest( - text=data["text"] + texts=data["texts"] ) - + def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]: return { - "text": obj.text + "texts": obj.texts } diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index 141a7330..f10ca4c6 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -10,18 +10,18 @@ from .primitives import ValueTranslator class DocumentEmbeddingsRequestTranslator(MessageTranslator): """Translator for DocumentEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: return DocumentEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -30,21 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator): """Translator for DocumentEmbeddingsResponse schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} - + if obj.chunks is not None: result["chunks"] = [ - chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + { + "chunk_id": chunk.chunk_id, + "score": chunk.score + } for chunk in obj.chunks ] - + return result - + def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -52,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): class GraphEmbeddingsRequestTranslator(MessageTranslator): """Translator for GraphEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: return GraphEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -72,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): class GraphEmbeddingsResponseTranslator(MessageTranslator): """Translator for GraphEmbeddingsResponse schema objects""" - + def __init__(self): self.value_translator = ValueTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: result = {} - + if obj.entities is not None: result["entities"] = [ - self.value_translator.from_pulsar(entity) - for entity in obj.entities + { + "entity": self.value_translator.from_pulsar(match.entity), + "score": match.score + } + for match in obj.entities ] - + return result - + def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -100,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: return RowEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), @@ -110,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: result = { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection, diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 5377cbd4..0043d1e4 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -1,43 +1,38 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( - KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, Metadata, EntityEmbeddings ) from .base import MessageTranslator from .primitives import ValueTranslator, SubgraphTranslator -from .metadata import DocumentMetadataTranslator class KnowledgeRequestTranslator(MessageTranslator): """Translator for KnowledgeRequest schema objects""" - + def __init__(self): self.value_translator = ValueTranslator() self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest: triples = None if "triples" in data: triples = Triples( metadata=Metadata( id=data["triples"]["metadata"]["id"], - metadata=self.subgraph_translator.to_pulsar( - data["triples"]["metadata"]["metadata"] - ), + root=data["triples"]["metadata"].get("root", ""), user=data["triples"]["metadata"]["user"], collection=data["triples"]["metadata"]["collection"] ), triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]), ) - + graph_embeddings = None if "graph-embeddings" in data: graph_embeddings = GraphEmbeddings( metadata=Metadata( id=data["graph-embeddings"]["metadata"]["id"], - metadata=self.subgraph_translator.to_pulsar( - data["graph-embeddings"]["metadata"]["metadata"] - ), + root=data["graph-embeddings"]["metadata"].get("root", ""), user=data["graph-embeddings"]["metadata"]["user"], collection=data["graph-embeddings"]["metadata"]["collection"] ), @@ -49,7 +44,7 @@ class KnowledgeRequestTranslator(MessageTranslator): for ent in data["graph-embeddings"]["entities"] ] ) - + return KnowledgeRequest( operation=data.get("operation"), user=data.get("user"), @@ -59,10 +54,10 @@ class KnowledgeRequestTranslator(MessageTranslator): triples=triples, graph_embeddings=graph_embeddings, ) - + def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]: result = {} - + if obj.operation: result["operation"] = obj.operation if obj.user: @@ -73,99 +68,91 @@ class KnowledgeRequestTranslator(MessageTranslator): result["flow"] = obj.flow if obj.collection: result["collection"] = obj.collection - + if obj.triples: result["triples"] = { "metadata": { "id": obj.triples.metadata.id, - "metadata": self.subgraph_translator.from_pulsar( - obj.triples.metadata.metadata - ), + "root": obj.triples.metadata.root, "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), } - + if obj.graph_embeddings: result["graph-embeddings"] = { "metadata": { "id": obj.graph_embeddings.metadata.id, - "metadata": self.subgraph_translator.from_pulsar( - obj.graph_embeddings.metadata.metadata - ), + "root": obj.graph_embeddings.metadata.root, "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ { - "vectors": entity.vectors, + "vector": entity.vector, "entity": self.value_translator.from_pulsar(entity.entity), } for entity in obj.graph_embeddings.entities ], } - + return result class KnowledgeResponseTranslator(MessageTranslator): """Translator for KnowledgeResponse schema objects""" - + def __init__(self): self.value_translator = ValueTranslator() self.subgraph_translator = SubgraphTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]: # Response to list operation if obj.ids is not None: return {"ids": obj.ids} - + # Streaming triples response if obj.triples: return { "triples": { "metadata": { "id": obj.triples.metadata.id, - "metadata": self.subgraph_translator.from_pulsar( - obj.triples.metadata.metadata - ), + "root": obj.triples.metadata.root, "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), } } - + # Streaming graph embeddings response if obj.graph_embeddings: return { "graph-embeddings": { "metadata": { "id": obj.graph_embeddings.metadata.id, - "metadata": self.subgraph_translator.from_pulsar( - obj.graph_embeddings.metadata.metadata - ), + "root": obj.graph_embeddings.metadata.root, "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ { - "vectors": entity.vectors, + "vector": entity.vector, "entity": self.value_translator.from_pulsar(entity.entity), } for entity in obj.graph_embeddings.entities ], } } - + # End of stream marker if obj.eos is True: return {"eos": True} - + # Empty response (successful delete) return {} diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index fc355dda..c7e849aa 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -44,14 +44,21 @@ class LibraryRequestTranslator(MessageTranslator): return LibrarianRequest( operation=data.get("operation"), - document_id=data.get("document-id"), - processing_id=data.get("processing-id"), + document_id=data.get("document-id", ""), + processing_id=data.get("processing-id", ""), document_metadata=doc_metadata, processing_metadata=proc_metadata, content=content, - user=data.get("user"), - collection=data.get("collection"), - criteria=criteria + user=data.get("user", ""), + collection=data.get("collection", ""), + criteria=criteria, + # Chunked upload fields + total_size=data.get("total-size", 0), + chunk_size=data.get("chunk-size", 0), + upload_id=data.get("upload-id", ""), + chunk_index=data.get("chunk-index", 0), + # List documents filtering + include_children=data.get("include-children", False), ) def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]: @@ -98,27 +105,73 @@ class LibraryResponseTranslator(MessageTranslator): def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]: result = {} - + + if obj.error: + result["error"] = { + "type": obj.error.type, + "message": obj.error.message, + } + if obj.document_metadata: result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata) - + if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content - + if obj.document_metadatas is not None: result["document-metadatas"] = [ self.doc_metadata_translator.from_pulsar(dm) for dm in obj.document_metadatas ] - + if obj.processing_metadatas is not None: result["processing-metadatas"] = [ self.proc_metadata_translator.from_pulsar(pm) for pm in obj.processing_metadatas ] - + + # Chunked upload response fields + if obj.upload_id: + result["upload-id"] = obj.upload_id + if obj.chunk_size: + result["chunk-size"] = obj.chunk_size + if obj.total_chunks: + result["total-chunks"] = obj.total_chunks + if obj.chunk_index: + result["chunk-index"] = obj.chunk_index + if obj.chunks_received: + result["chunks-received"] = obj.chunks_received + if obj.bytes_received: + result["bytes-received"] = obj.bytes_received + if obj.total_bytes: + result["total-bytes"] = obj.total_bytes + if obj.document_id: + result["document-id"] = obj.document_id + if obj.object_id: + result["object-id"] = obj.object_id + if obj.upload_state: + result["upload-state"] = obj.upload_state + if obj.received_chunks: + result["received-chunks"] = obj.received_chunks + if obj.missing_chunks: + result["missing-chunks"] = obj.missing_chunks + if obj.upload_sessions: + result["upload-sessions"] = [ + { + "upload-id": s.upload_id, + "document-id": s.document_id, + "document-metadata-json": s.document_metadata_json, + "total-size": s.total_size, + "chunk-size": s.chunk_size, + "total-chunks": s.total_chunks, + "chunks-received": s.chunks_received, + "created-at": s.created_at, + } + for s in obj.upload_sessions + ] + return result def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.from_pulsar(obj), obj.is_final diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 006b222c..46a28d0a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -20,12 +20,14 @@ class DocumentMetadataTranslator(Translator): comments=data.get("comments"), metadata=self.subgraph_translator.to_pulsar(metadata) if metadata is not None else [], user=data.get("user"), - tags=data.get("tags") + tags=data.get("tags"), + parent_id=data.get("parent-id", ""), + document_type=data.get("document-type", "source"), ) def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]: result = {} - + if obj.id: result["id"] = obj.id if obj.time: @@ -42,7 +44,11 @@ class DocumentMetadataTranslator(Translator): result["user"] = obj.user if obj.tags is not None: result["tags"] = obj.tags - + if obj.parent_id: + result["parent-id"] = obj.parent_id + if obj.document_type: + result["document-type"] = obj.document_type + return result diff --git a/trustgraph-base/trustgraph/messaging/translators/primitives.py b/trustgraph-base/trustgraph/messaging/translators/primitives.py index 790ae8f7..d54efc49 100644 --- a/trustgraph-base/trustgraph/messaging/translators/primitives.py +++ b/trustgraph-base/trustgraph/messaging/translators/primitives.py @@ -82,6 +82,7 @@ def _triple_translator_to_pulsar(data: Dict[str, Any]) -> Triple: def _triple_translator_from_pulsar(obj: Triple) -> Dict[str, Any]: + """Convert Triple object to wire format dict.""" term_translator = TermTranslator() result: Dict[str, Any] = {} diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 22166bd9..b7ff818c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -34,13 +34,31 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: result = {} - # Include response content (even if empty string) + # Include message_type for distinguishing chunk vs explain messages + message_type = getattr(obj, "message_type", "") + if message_type: + result["message_type"] = message_type + + # Include response content for chunk messages if obj.response is not None: result["response"] = obj.response - # Include end_of_stream flag + # Include explain_id for explain messages + explain_id = getattr(obj, "explain_id", None) + if explain_id: + result["explain_id"] = explain_id + + # Include explain_graph for explain messages (named graph filter) + explain_graph = getattr(obj, "explain_graph", None) + if explain_graph is not None: + result["explain_graph"] = explain_graph + + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) + # Include end_of_session flag (entire session complete) + result["end_of_session"] = getattr(obj, "end_of_session", False) + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} @@ -49,7 +67,8 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - is_final = getattr(obj, 'end_of_stream', False) + # Session is complete when end_of_session is True + is_final = getattr(obj, 'end_of_session', False) return self.from_pulsar(obj), is_final @@ -65,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator): triple_limit=int(data.get("triple-limit", 30)), max_subgraph_size=int(data.get("max-subgraph-size", 1000)), max_path_length=int(data.get("max-path-length", 2)), + edge_limit=int(data.get("edge-limit", 25)), streaming=data.get("streaming", False) ) @@ -77,6 +97,7 @@ class GraphRagRequestTranslator(MessageTranslator): "triple-limit": obj.triple_limit, "max-subgraph-size": obj.max_subgraph_size, "max-path-length": obj.max_path_length, + "edge-limit": obj.edge_limit, "streaming": getattr(obj, "streaming", False) } @@ -90,13 +111,31 @@ class GraphRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: result = {} - # Include response content (even if empty string) + # Include message_type + message_type = getattr(obj, "message_type", "") + if message_type: + result["message_type"] = message_type + + # Include response content for chunk messages if obj.response is not None: result["response"] = obj.response - # Include end_of_stream flag + # Include explain_id for explain messages + explain_id = getattr(obj, "explain_id", None) + if explain_id: + result["explain_id"] = explain_id + + # Include explain_graph for explain messages (named graph filter) + explain_graph = getattr(obj, "explain_graph", None) + if explain_graph is not None: + result["explain_graph"] = explain_graph + + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) + # Include end_of_session flag (entire session complete) + result["end_of_session"] = getattr(obj, "end_of_session", False) + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} @@ -105,5 +144,6 @@ class GraphRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - is_final = getattr(obj, 'end_of_stream', False) + # Session is complete when end_of_session is True + is_final = getattr(obj, 'end_of_session', False) return self.from_pulsar(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 2b01b1bc..2f29aa56 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator): g=g, limit=int(data.get("limit", 10000)), user=data.get("user", "trustgraph"), - collection=data.get("collection", "default") + collection=data.get("collection", "default"), + streaming=data.get("streaming", False), + batch_size=int(data.get("batch-size", 20)), ) def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]: result = { "limit": obj.limit, "user": obj.user, - "collection": obj.collection + "collection": obj.collection, + "streaming": obj.streaming, + "batch-size": obj.batch_size, } if obj.s: @@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.from_pulsar(obj), obj.is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py new file mode 100644 index 00000000..18ecb0e8 --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -0,0 +1,221 @@ +""" +Provenance module for extraction-time provenance support. + +Provides helpers for: +- URI generation for documents, pages, chunks, activities, subgraphs +- PROV-O triple building for provenance metadata +- Vocabulary bootstrap for per-collection initialization + +Usage example: + + from trustgraph.provenance import ( + document_uri, page_uri, chunk_uri_from_page, + document_triples, derived_entity_triples, + get_vocabulary_triples, + ) + + # Generate URIs + doc_uri = document_uri("my-doc-123") + page_uri = page_uri("my-doc-123", page_number=1) + + # Build provenance triples + triples = document_triples( + doc_uri, + title="My Document", + mime_type="application/pdf", + page_count=10, + ) + + # Get vocabulary bootstrap triples (once per collection) + vocab_triples = get_vocabulary_triples() +""" + +# URI generation +from . uris import ( + TRUSTGRAPH_BASE, + document_uri, + page_uri, + chunk_uri_from_page, + chunk_uri_from_doc, + activity_uri, + subgraph_uri, + agent_uri, + # Query-time provenance URIs (GraphRAG) + question_uri, + grounding_uri, + exploration_uri, + focus_uri, + synthesis_uri, + # Agent provenance URIs + agent_session_uri, + agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, + agent_final_uri, + # Document RAG provenance URIs + docrag_question_uri, + docrag_grounding_uri, + docrag_exploration_uri, + docrag_synthesis_uri, +) + +# Namespace constants +from . namespaces import ( + # PROV-O + PROV, PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + # Dublin Core + DC, DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + # RDF/RDFS + RDF, RDF_TYPE, RDFS, RDFS_LABEL, + # TrustGraph + TG, TG_CONTAINS, TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER, + TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH, + TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, + TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL, + TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH, + # Extraction provenance entity types + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, + # Query-time provenance predicates (GraphRAG) + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, + # Query-time provenance predicates (DocumentRAG) + TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Explainability entity types + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANALYSIS, TG_CONCLUSION, + # Unifying types + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + # Question subtypes (to distinguish retrieval mechanism) + TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, + # Agent provenance predicates + TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + # Document reference predicate + TG_DOCUMENT, + # Named graphs + GRAPH_DEFAULT, GRAPH_SOURCE, GRAPH_RETRIEVAL, +) + +# Triple builders +from . triples import ( + document_triples, + derived_entity_triples, + subgraph_provenance_triples, + # Query-time provenance triple builders (GraphRAG) + question_triples, + grounding_triples, + exploration_triples, + focus_triples, + synthesis_triples, + # Query-time provenance triple builders (DocumentRAG) + docrag_question_triples, + docrag_exploration_triples, + docrag_synthesis_triples, + # Utility + set_graph, +) + +# Agent provenance triple builders +from . agent import ( + agent_session_triples, + agent_iteration_triples, + agent_final_triples, +) + +# Vocabulary bootstrap +from . vocabulary import ( + get_vocabulary_triples, + PROV_CLASS_LABELS, + PROV_PREDICATE_LABELS, + DC_PREDICATE_LABELS, + TG_CLASS_LABELS, + TG_PREDICATE_LABELS, +) + +__all__ = [ + # URIs + "TRUSTGRAPH_BASE", + "document_uri", + "page_uri", + "chunk_uri_from_page", + "chunk_uri_from_doc", + "activity_uri", + "subgraph_uri", + "agent_uri", + # Query-time provenance URIs + "question_uri", + "grounding_uri", + "exploration_uri", + "focus_uri", + "synthesis_uri", + # Agent provenance URIs + "agent_session_uri", + "agent_iteration_uri", + "agent_thought_uri", + "agent_observation_uri", + "agent_final_uri", + # Document RAG provenance URIs + "docrag_question_uri", + "docrag_grounding_uri", + "docrag_exploration_uri", + "docrag_synthesis_uri", + # Namespaces + "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", + "PROV_WAS_DERIVED_FROM", "PROV_WAS_GENERATED_BY", + "PROV_USED", "PROV_WAS_ASSOCIATED_WITH", "PROV_STARTED_AT_TIME", + "DC", "DC_TITLE", "DC_SOURCE", "DC_DATE", "DC_CREATOR", + "RDF", "RDF_TYPE", "RDFS", "RDFS_LABEL", + "TG", "TG_CONTAINS", "TG_PAGE_COUNT", "TG_MIME_TYPE", "TG_PAGE_NUMBER", + "TG_CHUNK_INDEX", "TG_CHAR_OFFSET", "TG_CHAR_LENGTH", + "TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION", + "TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL", + "TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH", + # Extraction provenance entity types + "TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE", + # Query-time provenance predicates (GraphRAG) + "TG_QUERY", "TG_CONCEPT", "TG_ENTITY", + "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", + # Query-time provenance predicates (DocumentRAG) + "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", + # Explainability entity types + "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", + "TG_ANALYSIS", "TG_CONCLUSION", + # Unifying types + "TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE", + # Question subtypes + "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", + # Agent provenance predicates + "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", + # Document reference predicate + "TG_DOCUMENT", + # Named graphs + "GRAPH_DEFAULT", "GRAPH_SOURCE", "GRAPH_RETRIEVAL", + # Triple builders + "document_triples", + "derived_entity_triples", + "subgraph_provenance_triples", + # Query-time provenance triple builders (GraphRAG) + "question_triples", + "grounding_triples", + "exploration_triples", + "focus_triples", + "synthesis_triples", + # Query-time provenance triple builders (DocumentRAG) + "docrag_question_triples", + "docrag_exploration_triples", + "docrag_synthesis_triples", + # Agent provenance triple builders + "agent_session_triples", + "agent_iteration_triples", + "agent_final_triples", + # Utility + "set_graph", + # Vocabulary + "get_vocabulary_triples", + "PROV_CLASS_LABELS", + "PROV_PREDICATE_LABELS", + "DC_PREDICATE_LABELS", + "TG_CLASS_LABELS", + "TG_PREDICATE_LABELS", +] diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py new file mode 100644 index 00000000..f1aeab0d --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -0,0 +1,205 @@ +""" +Helper functions to build PROV-O triples for agent provenance. + +Agent provenance tracks the reasoning trace of ReAct agent sessions: +- Question: The root activity with query and timestamp +- Analysis: Each think/act/observe cycle +- Conclusion: The final answer +""" + +import json +from datetime import datetime +from typing import List, Optional, Dict, Any + +from .. schema import Triple, Term, IRI, LITERAL + +from . namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_AGENT_QUESTION, +) + + +def _iri(uri: str) -> Term: + """Create an IRI term.""" + return Term(type=IRI, iri=uri) + + +def _literal(value) -> Term: + """Create a literal term.""" + return Term(type=LITERAL, value=str(value)) + + +def _triple(s: str, p: str, o_term: Term) -> Triple: + """Create a triple with IRI subject and predicate.""" + return Triple(s=_iri(s), p=_iri(p), o=o_term) + + +def agent_session_triples( + session_uri: str, + query: str, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for an agent session start (Question). + + Creates: + - Activity declaration with tg:Question type + - Query text and timestamp + + Args: + session_uri: URI of the session (from agent_session_uri) + query: The user's query text + timestamp: ISO timestamp (defaults to now) + + Returns: + List of Triple objects + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + return [ + _triple(session_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(session_uri, RDF_TYPE, _iri(TG_QUESTION)), + _triple(session_uri, RDF_TYPE, _iri(TG_AGENT_QUESTION)), + _triple(session_uri, RDFS_LABEL, _literal("Agent Question")), + _triple(session_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(session_uri, TG_QUERY, _literal(query)), + ] + + +def agent_iteration_triples( + iteration_uri: str, + question_uri: Optional[str] = None, + previous_uri: Optional[str] = None, + action: str = "", + arguments: Dict[str, Any] = None, + thought_uri: Optional[str] = None, + thought_document_id: Optional[str] = None, + observation_uri: Optional[str] = None, + observation_document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for one agent iteration (Analysis - think/act/observe cycle). + + Creates: + - Entity declaration with tg:Analysis type + - wasGeneratedBy link to question (if first iteration) + - wasDerivedFrom link to previous iteration (if not first) + - Action and arguments metadata + - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document + - Observation sub-entity (tg:Reflection, tg:Observation) with librarian document + + Args: + iteration_uri: URI of this iteration (from agent_iteration_uri) + question_uri: URI of the question activity (for first iteration) + previous_uri: URI of the previous iteration (for subsequent iterations) + action: The tool/action name + arguments: Arguments passed to the tool (will be JSON-encoded) + thought_uri: URI for the thought sub-entity + thought_document_id: Document URI for thought in librarian + observation_uri: URI for the observation sub-entity + observation_document_id: Document URI for observation in librarian + + Returns: + List of Triple objects + """ + if arguments is None: + arguments = {} + + triples = [ + _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), + _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), + _triple(iteration_uri, TG_ACTION, _literal(action)), + _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), + ] + + if question_uri: + triples.append( + _triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + ) + elif previous_uri: + triples.append( + _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)) + ) + + # Thought sub-entity + if thought_uri: + triples.extend([ + _triple(iteration_uri, TG_THOUGHT, _iri(thought_uri)), + _triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), + _triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)), + _triple(thought_uri, RDFS_LABEL, _literal("Thought")), + _triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), + ]) + if thought_document_id: + triples.append( + _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) + ) + + # Observation sub-entity + if observation_uri: + triples.extend([ + _triple(iteration_uri, TG_OBSERVATION, _iri(observation_uri)), + _triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), + _triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)), + _triple(observation_uri, RDFS_LABEL, _literal("Observation")), + _triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), + ]) + if observation_document_id: + triples.append( + _triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id)) + ) + + return triples + + +def agent_final_triples( + final_uri: str, + question_uri: Optional[str] = None, + previous_uri: Optional[str] = None, + document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for an agent final answer (Conclusion). + + Creates: + - Entity declaration with tg:Conclusion and tg:Answer types + - wasGeneratedBy link to question (if no iterations) + - wasDerivedFrom link to last iteration (if iterations exist) + - Document reference to librarian + + Args: + final_uri: URI of the final answer (from agent_final_uri) + question_uri: URI of the question activity (if no iterations) + previous_uri: URI of the last iteration (if iterations exist) + document_id: Librarian document ID for the answer content + + Returns: + List of Triple objects + """ + triples = [ + _triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)), + _triple(final_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(final_uri, RDFS_LABEL, _literal("Conclusion")), + ] + + if question_uri: + triples.append( + _triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + ) + elif previous_uri: + triples.append( + _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)) + ) + + if document_id: + triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) + + return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py new file mode 100644 index 00000000..066e893b --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -0,0 +1,111 @@ +""" +RDF namespace constants for provenance. + +Includes PROV-O, Dublin Core, and TrustGraph namespace URIs. +""" + +# PROV-O namespace (W3C Provenance Ontology) +PROV = "http://www.w3.org/ns/prov#" +PROV_ENTITY = PROV + "Entity" +PROV_ACTIVITY = PROV + "Activity" +PROV_AGENT = PROV + "Agent" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" +PROV_USED = PROV + "used" +PROV_WAS_ASSOCIATED_WITH = PROV + "wasAssociatedWith" +PROV_STARTED_AT_TIME = PROV + "startedAtTime" + +# Dublin Core namespace +DC = "http://purl.org/dc/elements/1.1/" +DC_TITLE = DC + "title" +DC_SOURCE = DC + "source" +DC_DATE = DC + "date" +DC_CREATOR = DC + "creator" + +# RDF/RDFS namespace (also in rdf.py, but included here for completeness) +RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDF_TYPE = RDF + "type" +RDFS = "http://www.w3.org/2000/01/rdf-schema#" +RDFS_LABEL = RDFS + "label" + +# Schema.org namespace +SCHEMA = "https://schema.org/" +SCHEMA_DIGITAL_DOCUMENT = SCHEMA + "DigitalDocument" +SCHEMA_DESCRIPTION = SCHEMA + "description" +SCHEMA_KEYWORDS = SCHEMA + "keywords" +SCHEMA_NAME = SCHEMA + "name" + +# SKOS namespace +SKOS = "http://www.w3.org/2004/02/skos/core#" +SKOS_DEFINITION = SKOS + "definition" + +# TrustGraph namespace for custom predicates +TG = "https://trustgraph.ai/ns/" +TG_CONTAINS = TG + "contains" +TG_PAGE_COUNT = TG + "pageCount" +TG_MIME_TYPE = TG + "mimeType" +TG_PAGE_NUMBER = TG + "pageNumber" +TG_CHUNK_INDEX = TG + "chunkIndex" +TG_CHAR_OFFSET = TG + "charOffset" +TG_CHAR_LENGTH = TG + "charLength" +TG_CHUNK_SIZE = TG + "chunkSize" +TG_CHUNK_OVERLAP = TG + "chunkOverlap" +TG_COMPONENT_VERSION = TG + "componentVersion" +TG_LLM_MODEL = TG + "llmModel" +TG_ONTOLOGY = TG + "ontology" +TG_EMBEDDING_MODEL = TG + "embeddingModel" +TG_SOURCE_TEXT = TG + "sourceText" +TG_SOURCE_CHAR_OFFSET = TG + "sourceCharOffset" +TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength" + +# Query-time provenance predicates (GraphRAG) +TG_QUERY = TG + "query" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_DOCUMENT = TG + "document" # Reference to document in librarian + +# Query-time provenance predicates (DocumentRAG) +TG_CHUNK_COUNT = TG + "chunkCount" +TG_SELECTED_CHUNK = TG + "selectedChunk" + +# Extraction provenance entity types +TG_DOCUMENT_TYPE = TG + "Document" +TG_PAGE_TYPE = TG + "Page" +TG_CHUNK_TYPE = TG + "Chunk" +TG_SUBGRAPH_TYPE = TG + "Subgraph" + +# Explainability entity types (shared) +TG_QUESTION = TG + "Question" +TG_GROUNDING = TG + "Grounding" +TG_EXPLORATION = TG + "Exploration" +TG_FOCUS = TG + "Focus" +TG_SYNTHESIS = TG + "Synthesis" +TG_ANALYSIS = TG + "Analysis" +TG_CONCLUSION = TG + "Conclusion" + +# Unifying types for answer and intermediate commentary +TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion) +TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation) +TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning +TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result + +# Question subtypes (to distinguish retrieval mechanism) +TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" +TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" +TG_AGENT_QUESTION = TG + "AgentQuestion" + +# Agent provenance predicates +TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity +TG_ACTION = TG + "action" +TG_ARGUMENTS = TG + "arguments" +TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity + +# Named graph URIs for RDF datasets +# These separate different types of data while keeping them in the same collection +GRAPH_DEFAULT = "" # Core knowledge facts (triples extracted from documents) +GRAPH_SOURCE = "urn:graph:source" # Extraction provenance (which document/chunk a triple came from) +GRAPH_RETRIEVAL = "urn:graph:retrieval" # Query-time explainability (question, exploration, focus, synthesis) diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py new file mode 100644 index 00000000..60d8d8f6 --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -0,0 +1,646 @@ +""" +Helper functions to build PROV-O triples for extraction-time provenance. +""" + +from datetime import datetime +from typing import List, Optional + +from .. schema import Triple, Term, IRI, LITERAL, TRIPLE + +from . namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER, + TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH, + TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, + TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS, + # Extraction provenance entity types + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, + # Query-time provenance predicates (GraphRAG) + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, + # Query-time provenance predicates (DocumentRAG) + TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Explainability entity types + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + # Unifying types + TG_ANSWER_TYPE, + # Question subtypes + TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, +) + +from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri + + +def set_graph(triples: List[Triple], graph: str) -> List[Triple]: + """ + Set the named graph on a list of triples. + + This creates new Triple objects with the graph field set, + leaving the original triples unchanged. + + Args: + triples: List of Triple objects + graph: Named graph URI (e.g., "urn:graph:retrieval") + + Returns: + List of Triple objects with graph field set + """ + return [ + Triple(s=t.s, p=t.p, o=t.o, g=graph) + for t in triples + ] + + +def _iri(uri: str) -> Term: + """Create an IRI term.""" + return Term(type=IRI, iri=uri) + + +def _literal(value) -> Term: + """Create a literal term.""" + return Term(type=LITERAL, value=str(value)) + + +def _triple(s: str, p: str, o_term: Term) -> Triple: + """Create a triple with IRI subject and predicate.""" + return Triple(s=_iri(s), p=_iri(p), o=o_term) + + +def document_triples( + doc_uri: str, + title: Optional[str] = None, + source: Optional[str] = None, + date: Optional[str] = None, + creator: Optional[str] = None, + page_count: Optional[int] = None, + mime_type: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a source document entity. + + Args: + doc_uri: The document URI (from uris.document_uri) + title: Document title + source: Source URL/path + date: Document date + creator: Author/creator + page_count: Number of pages (for PDFs) + mime_type: MIME type + + Returns: + List of Triple objects + """ + triples = [ + _triple(doc_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(doc_uri, RDF_TYPE, _iri(TG_DOCUMENT_TYPE)), + ] + + if title: + triples.append(_triple(doc_uri, DC_TITLE, _literal(title))) + triples.append(_triple(doc_uri, RDFS_LABEL, _literal(title))) + + if source: + triples.append(_triple(doc_uri, DC_SOURCE, _iri(source))) + + if date: + triples.append(_triple(doc_uri, DC_DATE, _literal(date))) + + if creator: + triples.append(_triple(doc_uri, DC_CREATOR, _literal(creator))) + + if page_count is not None: + triples.append(_triple(doc_uri, TG_PAGE_COUNT, _literal(page_count))) + + if mime_type: + triples.append(_triple(doc_uri, TG_MIME_TYPE, _literal(mime_type))) + + return triples + + +def derived_entity_triples( + entity_uri: str, + parent_uri: str, + component_name: str, + component_version: str, + label: Optional[str] = None, + page_number: Optional[int] = None, + chunk_index: Optional[int] = None, + char_offset: Optional[int] = None, + char_length: Optional[int] = None, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a derived entity (page or chunk) with full PROV-O provenance. + + Creates: + - Entity declaration + - wasDerivedFrom relationship to parent + - Activity for the extraction + - Agent for the component + + Args: + entity_uri: URI of the derived entity (page or chunk) + parent_uri: URI of the parent entity + component_name: Name of TG component (e.g., "pdf-extractor", "chunker") + component_version: Version of the component + label: Human-readable label + page_number: Page number (for pages) + chunk_index: Chunk index (for chunks) + char_offset: Character offset in parent (for chunks) + char_length: Character length (for chunks) + chunk_size: Configured chunk size (for chunking activity) + chunk_overlap: Configured chunk overlap (for chunking activity) + timestamp: ISO timestamp (defaults to now) + + Returns: + List of Triple objects + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + act_uri = activity_uri() + agt_uri = agent_uri(component_name) + + # Determine specific type from parameters + if page_number is not None: + specific_type = TG_PAGE_TYPE + elif chunk_index is not None: + specific_type = TG_CHUNK_TYPE + else: + specific_type = None + + triples = [ + # Entity declaration + _triple(entity_uri, RDF_TYPE, _iri(PROV_ENTITY)), + ] + + if specific_type: + triples.append(_triple(entity_uri, RDF_TYPE, _iri(specific_type))) + + triples.extend([ + # Derivation from parent + _triple(entity_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), + + # Generation by activity + _triple(entity_uri, PROV_WAS_GENERATED_BY, _iri(act_uri)), + + # Activity declaration + _triple(act_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(act_uri, RDFS_LABEL, _literal(f"{component_name} extraction")), + _triple(act_uri, PROV_USED, _iri(parent_uri)), + _triple(act_uri, PROV_WAS_ASSOCIATED_WITH, _iri(agt_uri)), + _triple(act_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(act_uri, TG_COMPONENT_VERSION, _literal(component_version)), + + # Agent declaration + _triple(agt_uri, RDF_TYPE, _iri(PROV_AGENT)), + _triple(agt_uri, RDFS_LABEL, _literal(component_name)), + ]) + + if label: + triples.append(_triple(entity_uri, RDFS_LABEL, _literal(label))) + + if page_number is not None: + triples.append(_triple(entity_uri, TG_PAGE_NUMBER, _literal(page_number))) + + if chunk_index is not None: + triples.append(_triple(entity_uri, TG_CHUNK_INDEX, _literal(chunk_index))) + + if char_offset is not None: + triples.append(_triple(entity_uri, TG_CHAR_OFFSET, _literal(char_offset))) + + if char_length is not None: + triples.append(_triple(entity_uri, TG_CHAR_LENGTH, _literal(char_length))) + + if chunk_size is not None: + triples.append(_triple(act_uri, TG_CHUNK_SIZE, _literal(chunk_size))) + + if chunk_overlap is not None: + triples.append(_triple(act_uri, TG_CHUNK_OVERLAP, _literal(chunk_overlap))) + + return triples + + +def subgraph_provenance_triples( + subgraph_uri: str, + extracted_triples: List[Triple], + chunk_uri: str, + component_name: str, + component_version: str, + llm_model: Optional[str] = None, + ontology_uri: Optional[str] = None, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build provenance triples for a subgraph of extracted knowledge. + + One subgraph per chunk extraction, shared across all triples produced + from that chunk. This replaces per-triple reification with a + containment model. + + Creates: + - tg:contains link for each extracted triple (RDF-star quoted) + - One prov:wasDerivedFrom link to source chunk + - One activity with agent metadata + + Args: + subgraph_uri: URI for the extraction subgraph + extracted_triples: The extracted Triple objects to include + chunk_uri: URI of source chunk + component_name: Name of extractor component + component_version: Version of the component + llm_model: LLM model used for extraction + ontology_uri: Ontology URI used for extraction + timestamp: ISO timestamp + + Returns: + List of Triple objects for the provenance + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + act_uri = activity_uri() + agt_uri = agent_uri(component_name) + + triples = [] + + # Containment: subgraph tg:contains <> for each extracted triple + for extracted_triple in extracted_triples: + triple_term = Term(type=TRIPLE, triple=extracted_triple) + triples.append(Triple( + s=_iri(subgraph_uri), + p=_iri(TG_CONTAINS), + o=triple_term + )) + + # Subgraph provenance + triples.extend([ + _triple(subgraph_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(subgraph_uri, RDF_TYPE, _iri(TG_SUBGRAPH_TYPE)), + _triple(subgraph_uri, PROV_WAS_DERIVED_FROM, _iri(chunk_uri)), + _triple(subgraph_uri, PROV_WAS_GENERATED_BY, _iri(act_uri)), + + # Activity + _triple(act_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(act_uri, RDFS_LABEL, _literal(f"{component_name} extraction")), + _triple(act_uri, PROV_USED, _iri(chunk_uri)), + _triple(act_uri, PROV_WAS_ASSOCIATED_WITH, _iri(agt_uri)), + _triple(act_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(act_uri, TG_COMPONENT_VERSION, _literal(component_version)), + + # Agent + _triple(agt_uri, RDF_TYPE, _iri(PROV_AGENT)), + _triple(agt_uri, RDFS_LABEL, _literal(component_name)), + ]) + + if llm_model: + triples.append(_triple(act_uri, TG_LLM_MODEL, _literal(llm_model))) + + if ontology_uri: + triples.append(_triple(act_uri, TG_ONTOLOGY, _iri(ontology_uri))) + + return triples + + +# Query-time provenance triple builders +# +# Terminology: +# Question - What was asked, the anchor for everything +# Exploration - Casting wide, what do we know about this space +# Focus - Closing down, what's actually relevant here +# Synthesis - Weaving the relevant pieces into an answer + +def question_triples( + question_uri: str, + query: str, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a question activity. + + Creates: + - Activity declaration for the question + - Query text and timestamp + + Args: + question_uri: URI of the question (from question_uri) + query: The user's query text + timestamp: ISO timestamp (defaults to now) + + Returns: + List of Triple objects + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + return [ + _triple(question_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(question_uri, RDF_TYPE, _iri(TG_QUESTION)), + _triple(question_uri, RDF_TYPE, _iri(TG_GRAPH_RAG_QUESTION)), + _triple(question_uri, RDFS_LABEL, _literal("GraphRAG Question")), + _triple(question_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(question_uri, TG_QUERY, _literal(query)), + ] + + +def grounding_triples( + grounding_uri: str, + question_uri: str, + concepts: List[str], +) -> List[Triple]: + """ + Build triples for a grounding entity (concept decomposition of query). + + Creates: + - Entity declaration for grounding + - wasGeneratedBy link to question + - Concept literals for each extracted concept + + Args: + grounding_uri: URI of the grounding entity (from grounding_uri) + question_uri: URI of the parent question + concepts: List of concept strings extracted from the query + + Returns: + List of Triple objects + """ + triples = [ + _triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)), + _triple(grounding_uri, RDFS_LABEL, _literal("Grounding")), + _triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), + ] + + for concept in concepts: + triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept))) + + return triples + + +def exploration_triples( + exploration_uri: str, + grounding_uri: str, + edge_count: int, + entities: Optional[List[str]] = None, +) -> List[Triple]: + """ + Build triples for an exploration entity (all edges retrieved from subgraph). + + Creates: + - Entity declaration for exploration + - wasDerivedFrom link to grounding + - Edge count metadata + - Entity IRIs for each seed entity + + Args: + exploration_uri: URI of the exploration entity (from exploration_uri) + grounding_uri: URI of the parent grounding entity + edge_count: Number of edges retrieved + entities: Optional list of seed entity URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), + _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), + _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)), + _triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)), + ] + + if entities: + for entity in entities: + triples.append(_triple(exploration_uri, TG_ENTITY, _iri(entity))) + + return triples + + +def _quoted_triple(s: str, p: str, o: str) -> Term: + """Create a quoted triple term (RDF-star) from string values.""" + return Term( + type=TRIPLE, + triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o)) + ) + + +def focus_triples( + focus_uri: str, + exploration_uri: str, + selected_edges_with_reasoning: List[dict], + session_id: str = "", +) -> List[Triple]: + """ + Build triples for a focus entity (selected edges with reasoning). + + Creates: + - Entity declaration for focus + - wasDerivedFrom link to exploration + - For each selected edge: an edge selection entity with quoted triple and reasoning + + Structure: + tg:selectedEdge . + tg:edge <<

    >> . + tg:reasoning "reason" . + + Args: + focus_uri: URI of the focus entity (from focus_uri) + exploration_uri: URI of the parent exploration entity + selected_edges_with_reasoning: List of dicts with 'edge' (s,p,o tuple) and 'reasoning' + session_id: Session UUID for generating edge selection URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)), + _triple(focus_uri, RDFS_LABEL, _literal("Focus")), + _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), + ] + + # Add each selected edge with its reasoning via intermediate entity + for idx, edge_info in enumerate(selected_edges_with_reasoning): + edge = edge_info.get("edge") + reasoning = edge_info.get("reasoning", "") + + if edge: + s, p, o = edge + + # Create intermediate entity for this edge selection + edge_sel_uri = edge_selection_uri(session_id, idx) + + # Link focus to edge selection entity + triples.append( + _triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri)) + ) + + # Attach quoted triple to edge selection entity + quoted = _quoted_triple(s, p, o) + triples.append( + Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted) + ) + + # Attach reasoning to edge selection entity + if reasoning: + triples.append( + _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) + ) + + return triples + + +def synthesis_triples( + synthesis_uri: str, + focus_uri: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a synthesis entity (final answer). + + Creates: + - Entity declaration for synthesis with tg:Answer type + - wasDerivedFrom link to focus + - Document reference to librarian + + Args: + synthesis_uri: URI of the synthesis entity (from synthesis_uri) + focus_uri: URI of the parent focus entity + document_id: Librarian document ID for the answer content + + Returns: + List of Triple objects + """ + triples = [ + _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), + _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)), + ] + + if document_id: + triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + + return triples + + +# Document RAG provenance triple builders +# +# Document RAG uses a subset of GraphRAG's model: +# Question - What was asked +# Exploration - Chunks retrieved from document store +# Synthesis - The final answer (no Focus step) + +def docrag_question_triples( + question_uri: str, + query: str, + timestamp: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a document RAG question activity. + + Creates: + - Activity declaration with tg:Question type + - Query text and timestamp + + Args: + question_uri: URI of the question (from docrag_question_uri) + query: The user's query text + timestamp: ISO timestamp (defaults to now) + + Returns: + List of Triple objects + """ + if timestamp is None: + timestamp = datetime.utcnow().isoformat() + "Z" + + return [ + _triple(question_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(question_uri, RDF_TYPE, _iri(TG_QUESTION)), + _triple(question_uri, RDF_TYPE, _iri(TG_DOC_RAG_QUESTION)), + _triple(question_uri, RDFS_LABEL, _literal("DocumentRAG Question")), + _triple(question_uri, PROV_STARTED_AT_TIME, _literal(timestamp)), + _triple(question_uri, TG_QUERY, _literal(query)), + ] + + +def docrag_exploration_triples( + exploration_uri: str, + grounding_uri: str, + chunk_count: int, + chunk_ids: Optional[List[str]] = None, +) -> List[Triple]: + """ + Build triples for a document RAG exploration entity (chunks retrieved). + + Creates: + - Entity declaration with tg:Exploration type + - wasDerivedFrom link to grounding + - Chunk count and optional chunk references + + Args: + exploration_uri: URI of the exploration entity + grounding_uri: URI of the parent grounding entity + chunk_count: Number of chunks retrieved + chunk_ids: Optional list of chunk URIs/IDs + + Returns: + List of Triple objects + """ + triples = [ + _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), + _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), + _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)), + _triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)), + ] + + # Add references to selected chunks + if chunk_ids: + for chunk_id in chunk_ids: + triples.append(_triple(exploration_uri, TG_SELECTED_CHUNK, _iri(chunk_id))) + + return triples + + +def docrag_synthesis_triples( + synthesis_uri: str, + exploration_uri: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for a document RAG synthesis entity (final answer). + + Creates: + - Entity declaration with tg:Synthesis and tg:Answer types + - wasDerivedFrom link to exploration (skips focus step) + - Document reference to librarian + + Args: + synthesis_uri: URI of the synthesis entity + exploration_uri: URI of the parent exploration entity + document_id: Librarian document ID for the answer content + + Returns: + List of Triple objects + """ + triples = [ + _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), + _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), + ] + + if document_id: + triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + + return triples diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py new file mode 100644 index 00000000..670143df --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -0,0 +1,286 @@ +""" +URI generation for provenance entities. + +Document IDs are already IRIs (e.g., https://trustgraph.ai/doc/abc123). +Child entities (pages, chunks) append path segments to the parent IRI: +- Document: {doc_iri} (as provided) +- Page: {doc_iri}/p{page_number} +- Chunk: {page_iri}/c{chunk_index} (from page) + {doc_iri}/c{chunk_index} (from text doc) +- Activity: https://trustgraph.ai/activity/{uuid} +- Subgraph: https://trustgraph.ai/subgraph/{uuid} +""" + +import uuid +import urllib.parse + +# Base URI prefix for generated URIs (activities, statements, agents) +TRUSTGRAPH_BASE = "https://trustgraph.ai" + + +def _encode_id(id_str: str) -> str: + """URL-encode an ID component for safe inclusion in URIs.""" + return urllib.parse.quote(str(id_str), safe='') + + +def document_uri(doc_iri: str) -> str: + """Return the document IRI as-is (already a full URI).""" + return doc_iri + + +def page_uri(doc_iri: str, page_number: int) -> str: + """Generate URI for a page by appending to document IRI.""" + return f"{doc_iri}/p{page_number}" + + +def chunk_uri_from_page(doc_iri: str, page_number: int, chunk_index: int) -> str: + """Generate URI for a chunk extracted from a page.""" + return f"{doc_iri}/p{page_number}/c{chunk_index}" + + +def chunk_uri_from_doc(doc_iri: str, chunk_index: int) -> str: + """Generate URI for a chunk extracted directly from a text document.""" + return f"{doc_iri}/c{chunk_index}" + + +def activity_uri(activity_id: str = None) -> str: + """Generate URI for a PROV-O activity. Auto-generates UUID if not provided.""" + if activity_id is None: + activity_id = str(uuid.uuid4()) + return f"{TRUSTGRAPH_BASE}/activity/{_encode_id(activity_id)}" + + +def subgraph_uri(subgraph_id: str = None) -> str: + """Generate URI for an extraction subgraph. Auto-generates UUID if not provided.""" + if subgraph_id is None: + subgraph_id = str(uuid.uuid4()) + return f"{TRUSTGRAPH_BASE}/subgraph/{_encode_id(subgraph_id)}" + + +def agent_uri(component_name: str) -> str: + """Generate URI for a TrustGraph component agent.""" + return f"{TRUSTGRAPH_BASE}/agent/{_encode_id(component_name)}" + + +# Query-time provenance URIs +# These URIs use the urn:trustgraph: namespace to distinguish query-time +# provenance from extraction-time provenance (which uses https://trustgraph.ai/) +# +# Terminology: +# Question - What was asked, the anchor for everything +# Grounding - Decomposing the question into concepts +# Exploration - Casting wide, what do we know about this space +# Focus - Closing down, what's actually relevant here +# Synthesis - Weaving the relevant pieces into an answer + +def question_uri(session_id: str = None) -> str: + """ + Generate URI for a question activity. + + Args: + session_id: Optional UUID string. Auto-generates if not provided. + + Returns: + URN in format: urn:trustgraph:question:{uuid} + """ + if session_id is None: + session_id = str(uuid.uuid4()) + return f"urn:trustgraph:question:{session_id}" + + +def grounding_uri(session_id: str) -> str: + """ + Generate URI for a grounding entity (concept decomposition of query). + + Args: + session_id: The session UUID (same as question_uri). + + Returns: + URN in format: urn:trustgraph:prov:grounding:{uuid} + """ + return f"urn:trustgraph:prov:grounding:{session_id}" + + +def exploration_uri(session_id: str) -> str: + """ + Generate URI for an exploration entity (edges retrieved from subgraph). + + Args: + session_id: The session UUID (same as question_uri). + + Returns: + URN in format: urn:trustgraph:prov:exploration:{uuid} + """ + return f"urn:trustgraph:prov:exploration:{session_id}" + + +def focus_uri(session_id: str) -> str: + """ + Generate URI for a focus entity (selected edges with reasoning). + + Args: + session_id: The session UUID (same as question_uri). + + Returns: + URN in format: urn:trustgraph:prov:focus:{uuid} + """ + return f"urn:trustgraph:prov:focus:{session_id}" + + +def synthesis_uri(session_id: str) -> str: + """ + Generate URI for a synthesis entity (final answer text). + + Args: + session_id: The session UUID (same as question_uri). + + Returns: + URN in format: urn:trustgraph:prov:synthesis:{uuid} + """ + return f"urn:trustgraph:prov:synthesis:{session_id}" + + +def edge_selection_uri(session_id: str, edge_index: int) -> str: + """ + Generate URI for an edge selection item (links edge to reasoning). + + Args: + session_id: The session UUID. + edge_index: Index of this edge in the selection (0-based). + + Returns: + URN in format: urn:trustgraph:prov:edge:{uuid}:{index} + """ + return f"urn:trustgraph:prov:edge:{session_id}:{edge_index}" + + +# Agent provenance URIs +# These URIs use the urn:trustgraph:agent: namespace to distinguish agent +# provenance from GraphRAG question provenance + +def agent_session_uri(session_id: str = None) -> str: + """ + Generate URI for an agent session. + + Args: + session_id: Optional UUID string. Auto-generates if not provided. + + Returns: + URN in format: urn:trustgraph:agent:{uuid} + """ + if session_id is None: + session_id = str(uuid.uuid4()) + return f"urn:trustgraph:agent:{session_id}" + + +def agent_iteration_uri(session_id: str, iteration_num: int) -> str: + """ + Generate URI for an agent iteration. + + Args: + session_id: The session UUID. + iteration_num: 1-based iteration number. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/i{num} + """ + return f"urn:trustgraph:agent:{session_id}/i{iteration_num}" + + +def agent_thought_uri(session_id: str, iteration_num: int) -> str: + """ + Generate URI for an agent thought sub-entity. + + Args: + session_id: The session UUID. + iteration_num: 1-based iteration number. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/i{num}/thought + """ + return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + + +def agent_observation_uri(session_id: str, iteration_num: int) -> str: + """ + Generate URI for an agent observation sub-entity. + + Args: + session_id: The session UUID. + iteration_num: 1-based iteration number. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/i{num}/observation + """ + return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + + +def agent_final_uri(session_id: str) -> str: + """ + Generate URI for an agent final answer. + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/final + """ + return f"urn:trustgraph:agent:{session_id}/final" + + +# Document RAG provenance URIs +# These URIs use the urn:trustgraph:docrag: namespace to distinguish +# document RAG provenance from graph RAG provenance + +def docrag_question_uri(session_id: str = None) -> str: + """ + Generate URI for a document RAG question activity. + + Args: + session_id: Optional UUID string. Auto-generates if not provided. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid} + """ + if session_id is None: + session_id = str(uuid.uuid4()) + return f"urn:trustgraph:docrag:{session_id}" + + +def docrag_grounding_uri(session_id: str) -> str: + """ + Generate URI for a document RAG grounding entity (concept decomposition). + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/grounding + """ + return f"urn:trustgraph:docrag:{session_id}/grounding" + + +def docrag_exploration_uri(session_id: str) -> str: + """ + Generate URI for a document RAG exploration entity (chunks retrieved). + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/exploration + """ + return f"urn:trustgraph:docrag:{session_id}/exploration" + + +def docrag_synthesis_uri(session_id: str) -> str: + """ + Generate URI for a document RAG synthesis entity (final answer). + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/synthesis + """ + return f"urn:trustgraph:docrag:{session_id}/synthesis" diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py new file mode 100644 index 00000000..018e2bfe --- /dev/null +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -0,0 +1,138 @@ +""" +Vocabulary bootstrap for provenance. + +The knowledge graph is ontology-neutral and initializes empty. When writing +PROV-O provenance data to a collection for the first time, the vocabulary +must be bootstrapped with RDF labels for all classes and predicates. +""" + +from typing import List + +from .. schema import Triple, Term, IRI, LITERAL + +from . namespaces import ( + RDFS_LABEL, + PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + SCHEMA_DIGITAL_DOCUMENT, SCHEMA_DESCRIPTION, + SCHEMA_KEYWORDS, SCHEMA_NAME, + SKOS_DEFINITION, + TG_CONTAINS, TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER, + TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH, + TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, + TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL, + TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH, + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, + TG_CONCEPT, TG_ENTITY, TG_GROUNDING, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, +) + + +def _label_triple(uri: str, label: str) -> Triple: + """Create a label triple for a URI.""" + return Triple( + s=Term(type=IRI, iri=uri), + p=Term(type=IRI, iri=RDFS_LABEL), + o=Term(type=LITERAL, value=label), + ) + + +# PROV-O class labels +PROV_CLASS_LABELS = [ + _label_triple(PROV_ENTITY, "Entity"), + _label_triple(PROV_ACTIVITY, "Activity"), + _label_triple(PROV_AGENT, "Agent"), +] + +# PROV-O predicate labels +PROV_PREDICATE_LABELS = [ + _label_triple(PROV_WAS_DERIVED_FROM, "was derived from"), + _label_triple(PROV_WAS_GENERATED_BY, "was generated by"), + _label_triple(PROV_USED, "used"), + _label_triple(PROV_WAS_ASSOCIATED_WITH, "was associated with"), + _label_triple(PROV_STARTED_AT_TIME, "started at"), +] + +# Dublin Core predicate labels +DC_PREDICATE_LABELS = [ + _label_triple(DC_TITLE, "title"), + _label_triple(DC_SOURCE, "source"), + _label_triple(DC_DATE, "date"), + _label_triple(DC_CREATOR, "creator"), +] + +# Schema.org labels +SCHEMA_LABELS = [ + _label_triple(SCHEMA_DIGITAL_DOCUMENT, "Digital Document"), + _label_triple(SCHEMA_DESCRIPTION, "description"), + _label_triple(SCHEMA_KEYWORDS, "keywords"), + _label_triple(SCHEMA_NAME, "name"), +] + +# SKOS labels +SKOS_LABELS = [ + _label_triple(SKOS_DEFINITION, "definition"), +] + +# TrustGraph class labels (extraction provenance) +TG_CLASS_LABELS = [ + _label_triple(TG_DOCUMENT_TYPE, "Document"), + _label_triple(TG_PAGE_TYPE, "Page"), + _label_triple(TG_CHUNK_TYPE, "Chunk"), + _label_triple(TG_SUBGRAPH_TYPE, "Subgraph"), + _label_triple(TG_GROUNDING, "Grounding"), + _label_triple(TG_ANSWER_TYPE, "Answer"), + _label_triple(TG_REFLECTION_TYPE, "Reflection"), + _label_triple(TG_THOUGHT_TYPE, "Thought"), + _label_triple(TG_OBSERVATION_TYPE, "Observation"), +] + +# TrustGraph predicate labels +TG_PREDICATE_LABELS = [ + _label_triple(TG_CONTAINS, "contains"), + _label_triple(TG_PAGE_COUNT, "page count"), + _label_triple(TG_MIME_TYPE, "MIME type"), + _label_triple(TG_PAGE_NUMBER, "page number"), + _label_triple(TG_CHUNK_INDEX, "chunk index"), + _label_triple(TG_CHAR_OFFSET, "character offset"), + _label_triple(TG_CHAR_LENGTH, "character length"), + _label_triple(TG_CHUNK_SIZE, "chunk size"), + _label_triple(TG_CHUNK_OVERLAP, "chunk overlap"), + _label_triple(TG_COMPONENT_VERSION, "component version"), + _label_triple(TG_LLM_MODEL, "LLM model"), + _label_triple(TG_ONTOLOGY, "ontology"), + _label_triple(TG_EMBEDDING_MODEL, "embedding model"), + _label_triple(TG_SOURCE_TEXT, "source text"), + _label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"), + _label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"), + _label_triple(TG_CONCEPT, "concept"), + _label_triple(TG_ENTITY, "entity"), +] + + +def get_vocabulary_triples() -> List[Triple]: + """ + Get all vocabulary bootstrap triples. + + Returns a list of triples that define labels for all PROV-O classes, + PROV-O predicates, Dublin Core predicates, and TrustGraph predicates + used in extraction-time provenance. + + This should be emitted to the knowledge graph once per collection + before any provenance data is written. The operation is idempotent - + re-emitting the same triples is harmless. + + Returns: + List of Triple objects defining vocabulary labels + """ + return ( + PROV_CLASS_LABELS + + PROV_PREDICATE_LABELS + + DC_PREDICATE_LABELS + + SCHEMA_LABELS + + SKOS_LABELS + + TG_CLASS_LABELS + + TG_PREDICATE_LABELS + ) diff --git a/trustgraph-base/trustgraph/rdf.py b/trustgraph-base/trustgraph/rdf.py index 32799b8d..1d3b7cba 100644 --- a/trustgraph-base/trustgraph/rdf.py +++ b/trustgraph-base/trustgraph/rdf.py @@ -2,7 +2,6 @@ RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" RDF_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" DEFINITION = "http://www.w3.org/2004/02/skos/core#definition" -SUBJECT_OF = "https://schema.org/subjectOf" TRUSTGRAPH_ENTITIES = "http://trustgraph.ai/e/" diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index 1888e612..a37a8d62 100644 --- a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -1,13 +1,12 @@ -from dataclasses import dataclass, field -from .primitives import Triple +from dataclasses import dataclass @dataclass class Metadata: # Source identifier id: str = "" - # Subgraph - metadata: list[Triple] = field(default_factory=list) + # Root document identifier (set by librarian, preserved through pipeline) + root: str = "" # Collection management user: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/document.py b/trustgraph-base/trustgraph/schema/knowledge/document.py index d8ce97b4..c75a1227 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/document.py +++ b/trustgraph-base/trustgraph/schema/knowledge/document.py @@ -10,6 +10,9 @@ from ..core.topic import topic class Document: metadata: Metadata | None = None data: bytes = b"" + # For large document streaming: if document_id is set, the receiver should + # fetch content from librarian instead of using inline data + document_id: str = "" ############################################################################ @@ -19,6 +22,9 @@ class Document: class TextDocument: metadata: Metadata | None = None text: bytes = b"" + # For large document streaming: if document_id is set, the receiver should + # fetch content from librarian instead of using inline text + document_id: str = "" ############################################################################ @@ -28,5 +34,9 @@ class TextDocument: class Chunk: metadata: Metadata | None = None chunk: bytes = b"" + # For provenance: document_id of this chunk in librarian + # Post-chunker optimization: both document_id AND chunk content are included + # so downstream processors have the ID for provenance and content to work with + document_id: str = "" ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index 93559056..a8bae35c 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -11,7 +11,9 @@ from ..core.topic import topic @dataclass class EntityEmbeddings: entity: Term | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) + # Provenance: which chunk this embedding was derived from + chunk_id: str = "" # This is a 'batching' mechanism for the above data @dataclass @@ -25,8 +27,8 @@ class GraphEmbeddings: @dataclass class ChunkEmbeddings: - chunk: bytes = b"" - vectors: list[list[float]] = field(default_factory=list) + chunk_id: str = "" + vector: list[float] = field(default_factory=list) # This is a 'batching' mechanism for the above data @dataclass @@ -42,7 +44,7 @@ class DocumentEmbeddings: @dataclass class ObjectEmbeddings: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) name: str = "" key_name: str = "" id: str = "" @@ -54,7 +56,7 @@ class ObjectEmbeddings: @dataclass class StructuredObjectEmbedding: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) schema_name: str = "" object_id: str = "" # Primary key value field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings @@ -70,7 +72,7 @@ class RowIndexEmbedding: index_name: str = "" # The indexed field name(s) index_value: list[str] = field(default_factory=list) # The field value(s) text: str = "" # Text that was embedded - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) @dataclass class RowEmbeddings: diff --git a/trustgraph-base/trustgraph/schema/knowledge/graph.py b/trustgraph-base/trustgraph/schema/knowledge/graph.py index 4ee8d2c0..b4a05084 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/graph.py +++ b/trustgraph-base/trustgraph/schema/knowledge/graph.py @@ -12,6 +12,8 @@ from ..core.topic import topic class EntityContext: entity: Term | None = None context: str = "" + # Provenance: which chunk this entity context was derived from + chunk_id: str = "" # This is a 'batching' mechanism for the above data @dataclass diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index 7b40ca0a..f246bc31 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -12,4 +12,5 @@ from .structured_query import * from .rows_query import * from .diagnosis import * from .collection import * -from .storage import * \ No newline at end of file +from .storage import * +from .tool_service import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 9f883ff2..91179047 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -23,16 +23,22 @@ class AgentRequest: group: list[str] | None = None history: list[AgentStep] = field(default_factory=list) user: str = "" # User context for multi-tenancy - streaming: bool = False # NEW: Enable streaming response delivery (default false) + collection: str = "default" # Collection for provenance traces + streaming: bool = False # Enable streaming response delivery (default false) + session_id: str = "" # For provenance tracking across iterations @dataclass class AgentResponse: # Streaming-first design - chunk_type: str = "" # "thought", "action", "observation", "answer", "error" + chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" content: str = "" # The actual content (interpretation depends on chunk_type) end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete end_of_dialog: bool = False # Entire agent dialog is complete + # Explainability fields + explain_id: str | None = None # Provenance URI (announced as created) + explain_graph: str | None = None # Named graph where explain was stored + # Legacy fields (deprecated but kept for backward compatibility) answer: str = "" error: Error | None = None diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index 391d49e1..f1ab360f 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -49,6 +49,36 @@ from ..core.metadata import Metadata # <- (processing_metadata[]) # <- (error) +# begin-upload +# -> (document_metadata, total_size, chunk_size) +# <- (upload_id, chunk_size, total_chunks) +# <- (error) + +# upload-chunk +# -> (upload_id, chunk_index, content) +# <- (upload_id, chunk_index, chunks_received, total_chunks, bytes_received, total_bytes) +# <- (error) + +# complete-upload +# -> (upload_id) +# <- (document_id, object_id) +# <- (error) + +# abort-upload +# -> (upload_id) +# <- () +# <- (error) + +# get-upload-status +# -> (upload_id) +# <- (upload_id, state, chunks_received, missing_chunks, total_chunks, bytes_received, total_bytes) +# <- (error) + +# list-uploads +# -> (user) +# <- (uploads[]) +# <- (error) + @dataclass class DocumentMetadata: id: str = "" @@ -59,6 +89,14 @@ class DocumentMetadata: metadata: list[Triple] = field(default_factory=list) user: str = "" tags: list[str] = field(default_factory=list) + # Child document support + parent_id: str = "" # Empty for top-level docs, set for children + # Document type vocabulary: + # "source" - original uploaded document + # "page" - page extracted from source (e.g., PDF page) + # "chunk" - text chunk derived from page or source + # "extracted" - legacy value, kept for backwards compatibility + document_type: str = "source" @dataclass class ProcessingMetadata: @@ -76,11 +114,33 @@ class Criteria: value: str = "" operator: str = "" +@dataclass +class UploadProgress: + """Progress information for chunked uploads.""" + upload_id: str = "" + chunks_received: int = 0 + total_chunks: int = 0 + bytes_received: int = 0 + total_bytes: int = 0 + +@dataclass +class UploadSession: + """Information about an in-progress upload.""" + upload_id: str = "" + document_id: str = "" + document_metadata_json: str = "" # JSON-encoded DocumentMetadata + total_size: int = 0 + chunk_size: int = 0 + total_chunks: int = 0 + chunks_received: int = 0 + created_at: str = "" + @dataclass class LibrarianRequest: # add-document, remove-document, update-document, get-document-metadata, # get-document-content, add-processing, remove-processing, list-documents, - # list-processing + # list-processing, begin-upload, upload-chunk, complete-upload, abort-upload, + # get-upload-status, list-uploads operation: str = "" # add-document, remove-document, update-document, get-document-metadata, @@ -90,16 +150,16 @@ class LibrarianRequest: # add-processing, remove-processing processing_id: str = "" - # add-document, update-document + # add-document, update-document, begin-upload document_metadata: DocumentMetadata | None = None # add-processing processing_metadata: ProcessingMetadata | None = None - # add-document + # add-document, upload-chunk content: bytes = b"" - # list-documents, list-processing + # list-documents, list-processing, list-uploads user: str = "" # list-documents?, list-processing? @@ -108,6 +168,19 @@ class LibrarianRequest: # criteria: list[Criteria] = field(default_factory=list) + # begin-upload + total_size: int = 0 + chunk_size: int = 0 + + # upload-chunk, complete-upload, abort-upload, get-upload-status + upload_id: str = "" + + # upload-chunk, stream-document + chunk_index: int = 0 + + # list-documents - whether to include child documents (default False) + include_children: bool = False + @dataclass class LibrarianResponse: error: Error | None = None @@ -116,6 +189,34 @@ class LibrarianResponse: document_metadatas: list[DocumentMetadata] = field(default_factory=list) processing_metadatas: list[ProcessingMetadata] = field(default_factory=list) + # begin-upload response + upload_id: str = "" + chunk_size: int = 0 + total_chunks: int = 0 + + # upload-chunk response + chunk_index: int = 0 + chunks_received: int = 0 + bytes_received: int = 0 + total_bytes: int = 0 + + # complete-upload response + document_id: str = "" + object_id: str = "" + + # get-upload-status response + upload_state: str = "" # "in-progress", "completed", "expired" + received_chunks: list[int] = field(default_factory=list) + missing_chunks: list[int] = field(default_factory=list) + + # list-uploads response + upload_sessions: list[UploadSession] = field(default_factory=list) + + # Protocol flag: True if this is the final response for a request. + # Default True since most operations are single request/response. + # Only stream-document sets False for intermediate chunks. + is_final: bool = True + # FIXME: Is this right? Using persistence on librarian so that # message chunking works diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 1261158e..681638c3 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -29,7 +29,7 @@ class TextCompletionResponse: @dataclass class EmbeddingsRequest: - text: str = "" + texts: list[str] = field(default_factory=list) @dataclass class EmbeddingsResponse: diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 50ec416a..7a65f775 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -9,15 +9,21 @@ from ..core.topic import topic @dataclass class GraphEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class EntityMatch: + """A matching entity from a semantic search with similarity score""" + entity: Term | None = None + score: float = 0.0 + @dataclass class GraphEmbeddingsResponse: error: Error | None = None - entities: list[Term] = field(default_factory=list) + entities: list[EntityMatch] = field(default_factory=list) ############################################################################ @@ -32,11 +38,14 @@ class TriplesQueryRequest: o: Term | None = None g: str | None = None # Graph IRI. None=default graph, "*"=all graphs limit: int = 0 + streaming: bool = False # Enable streaming mode (multiple batched responses) + batch_size: int = 20 # Triples per batch in streaming mode @dataclass class TriplesQueryResponse: error: Error | None = None triples: list[Triple] = field(default_factory=list) + is_final: bool = True # False for intermediate batches in streaming mode ############################################################################ @@ -44,15 +53,21 @@ class TriplesQueryResponse: @dataclass class DocumentEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class ChunkMatch: + """A matching chunk from a semantic search with similarity score""" + chunk_id: str = "" + score: float = 0.0 + @dataclass class DocumentEmbeddingsResponse: error: Error | None = None - chunks: list[str] = field(default_factory=list) + chunks: list[ChunkMatch] = field(default_factory=list) document_embeddings_request_queue = topic( "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' @@ -76,7 +91,7 @@ class RowIndexMatch: @dataclass class RowEmbeddingsRequest: """Request for row embeddings semantic search""" - vectors: list[list[float]] = field(default_factory=list) # Query vectors + vector: list[float] = field(default_factory=list) # Query vector limit: int = 10 # Max results to return user: str = "" # User/keyspace collection: str = "" # Collection name diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 4337cb9b..b3a9d58d 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -15,13 +15,18 @@ class GraphRagQuery: triple_limit: int = 0 max_subgraph_size: int = 0 max_path_length: int = 0 + edge_limit: int = 0 streaming: bool = False @dataclass class GraphRagResponse: error: Error | None = None response: str = "" - end_of_stream: bool = False + end_of_stream: bool = False # LLM response stream complete + explain_id: str | None = None # Single explain URI (announced as created) + explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) + message_type: str = "" # "chunk" or "explain" + end_of_session: bool = False # Entire session complete ############################################################################ @@ -38,5 +43,9 @@ class DocumentRagQuery: @dataclass class DocumentRagResponse: error: Error | None = None - response: str = "" - end_of_stream: bool = False + response: str | None = "" + end_of_stream: bool = False # LLM response stream complete + explain_id: str | None = None # Single explain URI (announced as created) + explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) + message_type: str = "" # "chunk" or "explain" + end_of_session: bool = False # Entire session complete diff --git a/trustgraph-base/trustgraph/schema/services/tool_service.py b/trustgraph-base/trustgraph/schema/services/tool_service.py new file mode 100644 index 00000000..18315f29 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/tool_service.py @@ -0,0 +1,25 @@ + +from dataclasses import dataclass + +from ..core.primitives import Error + + +@dataclass +class ToolServiceRequest: + """Request to a dynamically configured tool service.""" + # User context for multi-tenancy + user: str = "" + # Config values (collection, etc.) as JSON + config: str = "" + # Arguments from LLM as JSON + arguments: str = "" + + +@dataclass +class ToolServiceResponse: + """Response from a tool service.""" + error: Error | None = None + # Response text (the observation) + response: str = "" + # End of stream marker for streaming responses + end_of_stream: bool = False diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 4e093953..d8a55f3d 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 66df74f1..fb3402c9 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", "requests", "pulsar-client", "aiohttp", @@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" +tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" tg-init-trustgraph = "trustgraph.cli.init_trustgraph:main" tg-invoke-agent = "trustgraph.cli.invoke_agent:main" @@ -54,9 +55,7 @@ tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" tg-load-kg-core = "trustgraph.cli.load_kg_core:main" -tg-load-pdf = "trustgraph.cli.load_pdf:main" tg-load-sample-documents = "trustgraph.cli.load_sample_documents:main" -tg-load-text = "trustgraph.cli.load_text:main" tg-load-turtle = "trustgraph.cli.load_turtle:main" tg-load-knowledge = "trustgraph.cli.load_knowledge:main" tg-load-structured-data = "trustgraph.cli.load_structured_data:main" @@ -72,6 +71,7 @@ tg-show-config = "trustgraph.cli.show_config:main" tg-show-flow-blueprints = "trustgraph.cli.show_flow_blueprints:main" tg-show-flow-state = "trustgraph.cli.show_flow_state:main" tg-show-flows = "trustgraph.cli.show_flows:main" +tg-query-graph = "trustgraph.cli.query_graph:main" tg-show-graph = "trustgraph.cli.show_graph:main" tg-show-kg-cores = "trustgraph.cli.show_kg_cores:main" tg-show-library-documents = "trustgraph.cli.show_library_documents:main" @@ -96,6 +96,9 @@ tg-delete-config-item = "trustgraph.cli.delete_config_item:main" tg-list-collections = "trustgraph.cli.list_collections:main" tg-set-collection = "trustgraph.cli.set_collection:main" tg-delete-collection = "trustgraph.cli.delete_collection:main" +tg-show-extraction-provenance = "trustgraph.cli.show_extraction_provenance:main" +tg-list-explain-traces = "trustgraph.cli.list_explain_traces:main" +tg-show-explain-trace = "trustgraph.cli.show_explain_trace:main" [tool.setuptools.packages.find] include = ["trustgraph*"] diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py new file mode 100644 index 00000000..3d70f37d --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/get_document_content.py @@ -0,0 +1,87 @@ +""" +Gets document content from the library by document ID. +""" + +import argparse +import os +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_user = "trustgraph" + +def get_content(url, user, document_id, output_file, token=None): + + api = Api(url, token=token).library() + + content = api.get_document_content(user=user, id=document_id) + + if output_file: + with open(output_file, 'wb') as f: + f.write(content) + print(f"Written {len(content)} bytes to {output_file}") + else: + # Write to stdout + # Try to decode as text, fall back to binary info + try: + text = content.decode('utf-8') + print(text) + except UnicodeDecodeError: + print(f"Binary content: {len(content)} bytes", file=sys.stderr) + sys.stdout.buffer.write(content) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-get-document-content', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-o', '--output', + default=None, + help='Output file (default: stdout)' + ) + + parser.add_argument( + 'document_id', + help='Document ID (IRI) to retrieve', + ) + + args = parser.parse_args() + + try: + + get_content( + url=args.api_url, + user=args.user, + document_id=args.document_id, + output_file=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py index 1d34e39f..840f8574 100644 --- a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py +++ b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py @@ -1,6 +1,7 @@ """ Connects to the graph query service and dumps all graph edges in Turtle -format. +format with RDF-star support for quoted triples. +Uses streaming mode for lower time-to-first-processing. """ import rdflib @@ -9,48 +10,82 @@ import sys import argparse import os -from trustgraph.api import Api, Uri +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_graph(url, flow_id, user, collection): - api = Api(url).flow().id(flow_id) +def term_to_rdflib(term): + """Convert a wire-format term to an rdflib term.""" + if term is None: + return None - rows = api.triples_query( - s=None, p=None, o=None, - user=user, collection=collection, - limit=10_000) + t = term.get("t", "") + + if t == "i": # IRI + iri = term.get("i", "") + # Skip malformed URLs with spaces + if " " in iri: + return None + return rdflib.term.URIRef(iri) + elif t == "l": # Literal + value = term.get("v", "") + datatype = term.get("d") + language = term.get("l") + if language: + return rdflib.term.Literal(value, lang=language) + elif datatype: + return rdflib.term.Literal(value, datatype=rdflib.term.URIRef(datatype)) + else: + return rdflib.term.Literal(value) + elif t == "r": # Quoted triple (RDF-star) + triple = term.get("r", {}) + s_term = term_to_rdflib(triple.get("s")) + p_term = term_to_rdflib(triple.get("p")) + o_term = term_to_rdflib(triple.get("o")) + if s_term is None or p_term is None or o_term is None: + return None + try: + return rdflib.term.Triple((s_term, p_term, o_term)) + except AttributeError: + # Fallback for older rdflib versions + return rdflib.term.Literal(f"<<{s_term} {p_term} {o_term}>>") + else: + # Fallback + return rdflib.term.Literal(str(term)) + + +def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): + + socket = Api(url, token=token).socket() + flow = socket.flow(flow_id) g = rdflib.Graph() - for row in rows: + try: + for batch in flow.triples_query_stream( + s=None, p=None, o=None, + user=user, collection=collection, + limit=limit, + batch_size=batch_size, + ): + for triple in batch: + sv = term_to_rdflib(triple.get("s")) + pv = term_to_rdflib(triple.get("p")) + ov = term_to_rdflib(triple.get("o")) - sv = rdflib.term.URIRef(row.s) - pv = rdflib.term.URIRef(row.p) + if sv is None or pv is None or ov is None: + continue - if isinstance(row.o, Uri): - - # Skip malformed URLs with spaces in - if " " in row.o: - continue - - ov = rdflib.term.URIRef(row.o) - - else: - - ov = rdflib.term.Literal(row.o) - - g.add((sv, pv, ov)) - - g.serialize(destination="output.ttl", format="turtle") + g.add((sv, pv, ov)) + finally: + socket.close() buf = io.BytesIO() - g.serialize(destination=buf, format="turtle") - sys.stdout.write(buf.getvalue().decode("utf-8")) @@ -85,6 +120,26 @@ def main(): help=f'Collection ID (default: {default_collection})' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Maximum number of triples to return (default: 10000)', + ) + + parser.add_argument( + '-b', '--batch-size', + type=int, + default=20, + help='Triples per streaming batch (default: 20)', + ) + args = parser.parse_args() try: @@ -94,6 +149,9 @@ def main(): flow_id = args.flow_id, user = args.user, collection = args.collection, + limit = args.limit, + batch_size = args.batch_size, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 369fcdd4..9879025f 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -4,8 +4,19 @@ Uses the agent service to answer a question import argparse import os +import sys import textwrap -from trustgraph.api import Api +from trustgraph.api import ( + Api, + ExplainabilityClient, + ProvenanceEvent, + Question, + Analysis, + Conclusion, + AgentThought, + AgentObservation, + AgentAnswer, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -97,11 +108,149 @@ def output(text, prefix="> ", width=78): ) print(out) +def question_explainable( + url, question_text, flow_id, user, collection, + state=None, group=None, verbose=False, token=None, debug=False +): + """Execute agent with explainability - shows provenance events inline.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Track last chunk type for formatting + last_chunk_type = None + current_outputter = None + + # Stream agent with explainability - process events as they arrive + for item in flow.agent_explain( + question=question_text, + user=user, + collection=collection, + state=state, + group=group, + ): + if isinstance(item, AgentThought): + if last_chunk_type != "thought": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() # Blank line between message types + if verbose: + current_outputter = Outputter(width=78, prefix="\U0001f914 ") + current_outputter.__enter__() + last_chunk_type = "thought" + if current_outputter: + current_outputter.output(item.content) + if current_outputter.word_buffer: + print(current_outputter.word_buffer, end="", flush=True) + current_outputter.column += len(current_outputter.word_buffer) + current_outputter.word_buffer = "" + + elif isinstance(item, AgentObservation): + if last_chunk_type != "observation": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() + if verbose: + current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") + current_outputter.__enter__() + last_chunk_type = "observation" + if current_outputter: + current_outputter.output(item.content) + if current_outputter.word_buffer: + print(current_outputter.word_buffer, end="", flush=True) + current_outputter.column += len(current_outputter.word_buffer) + current_outputter.word_buffer = "" + + elif isinstance(item, AgentAnswer): + if last_chunk_type != "answer": + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + print() + last_chunk_type = "answer" + # Print answer content directly + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [session] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Analysis): + print(f"\n [iteration] {prov_id}", file=sys.stderr) + if entity.action: + print(f" Action: {entity.action}", file=sys.stderr) + if entity.thought: + print(f" Thought: {entity.thought}", file=sys.stderr) + if entity.observation: + print(f" Observation: {entity.observation}", file=sys.stderr) + + elif isinstance(entity, Conclusion): + print(f"\n [conclusion] {prov_id}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + # Close any remaining outputter + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + + # Final newline if we ended with answer + if last_chunk_type == "answer": + print() + + finally: + socket.close() + + def question( url, question, flow_id, user, collection, plan=None, state=None, group=None, verbose=False, streaming=True, - token=None + token=None, explainable=False, debug=False ): + # Explainable mode uses the API to capture and process provenance events + if explainable: + question_explainable( + url=url, + question_text=question, + flow_id=flow_id, + user=user, + collection=collection, + state=state, + group=group, + verbose=verbose, + token=token, + debug=debug + ) + return if verbose: output(wrap(question), "\U00002753 ") @@ -270,6 +419,18 @@ def main(): help=f'Disable streaming (use legacy mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events: Session, Iterations, Conclusion (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -286,6 +447,8 @@ def main(): verbose = args.verbose, streaming = not args.no_streaming, token = args.token, + explainable = args.explainable, + debug = args.debug, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py index b14397cb..43bcc985 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -1,6 +1,6 @@ """ Queries document chunks by text similarity using vector embeddings. -Returns a list of matching document chunks, truncated to the specified length. +Returns a list of matching chunk IDs. """ import argparse @@ -10,13 +10,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def truncate_chunk(chunk, max_length): - """Truncate a chunk to max_length characters, adding ellipsis if needed.""" - if len(chunk) <= max_length: - return chunk - return chunk[:max_length] + "..." - -def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None): +def query(url, flow_id, query_text, user, collection, limit, token=None): # Create API client api = Api(url=url, token=token) @@ -33,9 +27,13 @@ def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, t ) chunks = result.get("chunks", []) - for i, chunk in enumerate(chunks, 1): - truncated = truncate_chunk(chunk, max_chunk_length) - print(f"{i}. {truncated}") + if not chunks: + print("No matching chunks found.") + else: + for i, chunk in enumerate(chunks, 1): + chunk_id = chunk.get("chunk_id", "") + score = chunk.get("score", 0.0) + print(f"{i}. {chunk_id} (score: {score:.4f})") finally: # Clean up socket connection @@ -85,13 +83,6 @@ def main(): help='Maximum number of results (default: 10)', ) - parser.add_argument( - '--max-chunk-length', - type=int, - default=200, - help='Truncate chunks to N characters (default: 200)', - ) - parser.add_argument( 'query', nargs=1, @@ -109,7 +100,6 @@ def main(): user=args.user, collection=args.collection, limit=args.limit, - max_chunk_length=args.max_chunk_length, token=args.token, ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 7e88bdc4..7da9d779 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -4,7 +4,17 @@ Uses the DocumentRAG service to answer a question import argparse import os -from trustgraph.api import Api +import sys +from trustgraph.api import ( + Api, + ExplainabilityClient, + RAGChunk, + ProvenanceEvent, + Question, + Grounding, + Exploration, + Synthesis, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -12,7 +22,96 @@ default_user = 'trustgraph' default_collection = 'default' default_doc_limit = 10 -def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None): + +def question_explainable( + url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False +): + """Execute document RAG with explainability - shows provenance events inline.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Stream DocumentRAG with explainability - process events as they arrive + for item in flow.document_rag_explain( + query=question_text, + user=user, + collection=collection, + doc_limit=doc_limit, + ): + if isinstance(item, RAGChunk): + # Print response content + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [question] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Grounding): + print(f"\n [grounding] {prov_id}", file=sys.stderr) + if entity.concepts: + for concept in entity.concepts: + print(f" Concept: {concept}", file=sys.stderr) + + elif isinstance(entity, Exploration): + print(f"\n [exploration] {prov_id}", file=sys.stderr) + if entity.chunk_count: + print(f" Chunks retrieved: {entity.chunk_count}", file=sys.stderr) + + elif isinstance(entity, Synthesis): + print(f"\n [synthesis] {prov_id}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + print() # Final newline + + finally: + socket.close() + + +def question( + url, flow_id, question_text, user, collection, doc_limit, + streaming=True, token=None, explainable=False, debug=False +): + # Explainable mode uses the API to capture and process provenance events + if explainable: + question_explainable( + url=url, + flow_id=flow_id, + question_text=question_text, + user=user, + collection=collection, + doc_limit=doc_limit, + token=token, + debug=debug + ) + return # Create API client api = Api(url=url, token=token) @@ -24,7 +123,7 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True try: response = flow.document_rag( - query=question, + query=question_text, user=user, collection=collection, doc_limit=doc_limit, @@ -42,13 +141,14 @@ def question(url, flow_id, question, user, collection, doc_limit, streaming=True # Use REST API for non-streaming flow = api.flow().id(flow_id) resp = flow.document_rag( - query=question, + query=question_text, user=user, collection=collection, doc_limit=doc_limit, ) print(resp) + def main(): parser = argparse.ArgumentParser( @@ -105,6 +205,18 @@ def main(): help='Disable streaming (use non-streaming mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events: Question, Exploration, Synthesis (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -112,12 +224,14 @@ def main(): question( url=args.url, flow_id=args.flow_id, - question=args.question, + question_text=args.question, user=args.user, collection=args.collection, doc_limit=args.doc_limit, streaming=not args.no_streaming, token=args.token, + explainable=args.explainable, + debug=args.debug, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py index 71a88bd7..699a85cf 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -10,7 +10,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, text, token=None): +def query(url, flow_id, texts, token=None): # Create API client api = Api(url=url, token=token) @@ -19,9 +19,14 @@ def query(url, flow_id, text, token=None): try: # Call embeddings service - result = flow.embeddings(text=text) + result = flow.embeddings(texts=texts) vectors = result.get("vectors", []) - print(vectors) + # Print each text's vectors + for i, vecs in enumerate(vectors): + if len(texts) > 1: + print(f"Text {i + 1}: {vecs}") + else: + print(vecs) finally: # Clean up socket connection @@ -53,9 +58,9 @@ def main(): ) parser.add_argument( - 'text', - nargs=1, - help='Text to convert to embedding vector', + 'texts', + nargs='+', + help='Text(s) to convert to embedding vectors', ) args = parser.parse_args() @@ -65,7 +70,7 @@ def main(): query( url=args.url, flow_id=args.flow_id, - text=args.text[0], + texts=args.texts, token=args.token, ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py index ae195007..5b0f4c67 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py @@ -27,8 +27,23 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): ) entities = result.get("entities", []) - for entity in entities: - print(entity) + if not entities: + print("No matching entities found.") + else: + for i, match in enumerate(entities, 1): + entity = match.get("entity", {}) + score = match.get("score", 0.0) + # Format entity based on type (wire format uses compact keys) + term_type = entity.get("t", "") + if term_type == "i": # IRI + entity_str = entity.get("i", "") + elif term_type == "l": # Literal + entity_str = f'"{entity.get("v", "")}"' + elif term_type == "b": # Blank node + entity_str = f'_:{entity.get("d", "")}' + else: + entity_str = str(entity) + print(f"{i}. {entity_str} (score: {score:.4f})") finally: # Clean up socket connection diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 5fa359ab..1e530c03 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,8 +3,22 @@ Uses the GraphRAG service to answer a question """ import argparse +import json import os -from trustgraph.api import Api +import sys +import websockets +import asyncio +from trustgraph.api import ( + Api, + ExplainabilityClient, + RAGChunk, + ProvenanceEvent, + Question, + Grounding, + Exploration, + Focus, + Synthesis, +) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -15,11 +29,741 @@ default_triple_limit = 30 default_max_subgraph_size = 150 default_max_path_length = 2 +# Provenance predicates +TG = "https://trustgraph.ai/ns/" +TG_QUERY = TG + "query" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" +TG_EDGE_COUNT = TG + "edgeCount" +TG_SELECTED_EDGE = TG + "selectedEdge" +TG_EDGE = TG + "edge" +TG_REASONING = TG + "reasoning" +TG_DOCUMENT = TG + "document" +TG_CONTAINS = TG + "contains" +PROV = "http://www.w3.org/ns/prov#" +PROV_STARTED_AT_TIME = PROV + "startedAtTime" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + +def _get_event_type(prov_id): + """Extract event type from provenance_id""" + if "question" in prov_id: + return "question" + elif "grounding" in prov_id: + return "grounding" + elif "exploration" in prov_id: + return "exploration" + elif "focus" in prov_id: + return "focus" + elif "synthesis" in prov_id: + return "synthesis" + return "provenance" + + +def _format_provenance_details(event_type, triples): + """Format provenance details based on event type and triples""" + lines = [] + + if event_type == "question": + # Show query and timestamp + for s, p, o in triples: + if p == TG_QUERY: + lines.append(f" Query: {o}") + elif p == PROV_STARTED_AT_TIME: + lines.append(f" Time: {o}") + + elif event_type == "grounding": + # Show extracted concepts + concepts = [o for s, p, o in triples if p == TG_CONCEPT] + if concepts: + lines.append(f" Concepts: {len(concepts)}") + for concept in concepts: + lines.append(f" - {concept}") + + elif event_type == "exploration": + # Show edge count (seed entities resolved separately with labels) + for s, p, o in triples: + if p == TG_EDGE_COUNT: + lines.append(f" Edges explored: {o}") + + elif event_type == "focus": + # For focus, just count edge selection URIs + # The actual edge details are fetched separately via edge_selections parameter + edge_sel_uris = [] + for s, p, o in triples: + if p == TG_SELECTED_EDGE: + edge_sel_uris.append(o) + if edge_sel_uris: + lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") + + elif event_type == "synthesis": + # Show document reference (content already streamed) + for s, p, o in triples: + if p == TG_DOCUMENT: + lines.append(f" Document: {o}") + + return lines + + +async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=None, debug=False): + """Query triples for a provenance node (single attempt)""" + request = { + "id": "triples-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": prov_id}, + "user": user, + "collection": collection, + "limit": 100 + } + } + # Add graph filter if specified (for named graph queries) + if graph is not None: + request["request"]["g"] = graph + + if debug: + print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) + + triples = [] + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if debug: + print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) + + if response.get("id") != "triples-request": + continue + + if "error" in response: + if debug: + print(f" [debug] error: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + # Handle triples response + # Response format: {"response": [triples...]} + # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type + triple_list = resp.get("response", []) + for t in triple_list: + s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) + p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) + # Handle quoted triples (type "t") and regular values + o_term = t.get("o", {}) + if o_term.get("t") == "t": + # Quoted triple - extract s, p, o from nested structure + tr = o_term.get("tr", {}) + o = { + "s": tr.get("s", {}).get("i", ""), + "p": tr.get("p", {}).get("i", ""), + "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), + } + else: + o = o_term.get("i", o_term.get("v", "")) + triples.append((s, p, o)) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception: {e}", file=sys.stderr) + + if debug: + print(f" [debug] got {len(triples)} triples", file=sys.stderr) + + return triples + + +async def _query_triples(ws_url, flow_id, prov_id, user, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): + """Query triples for a provenance node with retries for race condition""" + for attempt in range(max_retries): + triples = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=graph, debug=debug) + if triples: + return triples + # Wait before retry if empty (triples may not be stored yet) + if attempt < max_retries - 1: + if debug: + print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) + await asyncio.sleep(retry_delay) + return [] + + +async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False): + """ + Query for provenance of an edge (s, p, o) in the knowledge graph. + + Finds subgraphs that contain the edge via tg:contains, then follows + prov:wasDerivedFrom to find source documents. + + Returns list of source URIs (chunks, pages, documents). + """ + # Query for subgraphs that contain this edge: ?subgraph tg:contains <> + request = { + "id": "edge-prov-request", + "service": "triples", + "flow": flow_id, + "request": { + "p": {"t": "i", "i": TG_CONTAINS}, + "o": { + "t": "t", # Quoted triple type + "tr": { + "s": {"t": "i", "i": edge_s}, + "p": {"t": "i", "i": edge_p}, + "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, + } + }, + "user": user, + "collection": collection, + "limit": 10 + } + } + + if debug: + print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) + + stmt_uris = [] + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "edge-prov-request": + continue + + if "error" in response: + if debug: + print(f" [debug] error: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + for t in triple_list: + s = t.get("s", {}).get("i", "") + if s: + stmt_uris.append(s) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) + + if debug: + print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) + + # For each statement, query wasDerivedFrom to find sources + sources = [] + for stmt_uri in stmt_uris: + # Query: stmt_uri prov:wasDerivedFrom ?source + request = { + "id": "derived-from-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": stmt_uri}, + "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, + "user": user, + "collection": collection, + "limit": 10 + } + } + + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "derived-from-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + for t in triple_list: + o = t.get("o", {}).get("i", "") + if o: + sources.append(o) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) + + if debug: + print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) + + return sources + + +async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False): + """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" + request = { + "id": "parent-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": uri}, + "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, + "user": user, + "collection": collection, + "limit": 1 + } + } + + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "parent-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + if triple_list: + return triple_list[0].get("o", {}).get("i", None) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying parent: {e}", file=sys.stderr) + + return None + + +async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False): + """ + Trace the full provenance chain from a source URI up to the root document. + Returns a list of (uri, label) tuples from leaf to root. + """ + chain = [] + current = source_uri + max_depth = 10 # Prevent infinite loops + + for _ in range(max_depth): + if not current: + break + + # Get label for current entity + label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug) + chain.append((current, label)) + + # Get parent + parent = await _query_derived_from(ws_url, flow_id, current, user, collection, debug) + if not parent or parent == current: + break + current = parent + + return chain + + +def _format_provenance_chain(chain): + """ + Format a provenance chain as a human-readable string. + Chain is [(uri, label), ...] from leaf to root. + """ + if not chain: + return "" + + # Show labels, from leaf to root + labels = [label for uri, label in chain] + return " → ".join(labels) + + +def _is_iri(value): + """Check if a value looks like an IRI.""" + if not isinstance(value, str): + return False + return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") + + +async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False): + """ + Query for the rdfs:label of an IRI. + Uses label_cache to avoid repeated queries. + Returns the label if found, otherwise returns the IRI. + """ + if not _is_iri(iri): + return iri + + # Check cache first + if iri in label_cache: + return label_cache[iri] + + request = { + "id": "label-request", + "service": "triples", + "flow": flow_id, + "request": { + "s": {"t": "i", "i": iri}, + "p": {"t": "i", "i": RDFS_LABEL}, + "user": user, + "collection": collection, + "limit": 1 + } + } + + label = iri # Default to IRI if no label found + try: + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "label-request": + continue + + if "error" in response: + break + + if "response" in response: + resp = response["response"] + triple_list = resp.get("response", []) + if triple_list: + # Get the label value + o = triple_list[0].get("o", {}) + label = o.get("v", o.get("i", iri)) + + if resp.get("complete") or response.get("complete"): + break + except Exception as e: + if debug: + print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) + + # Cache the result + label_cache[iri] = label + return label + + +async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False): + """ + Resolve labels for all IRI components of an edge triple. + Returns (s_label, p_label, o_label). + """ + s = edge_triple.get("s", "?") + p = edge_triple.get("p", "?") + o = edge_triple.get("o", "?") + + s_label = await _query_label(ws_url, flow_id, s, user, collection, label_cache, debug) + p_label = await _query_label(ws_url, flow_id, p, user, collection, label_cache, debug) + o_label = await _query_label(ws_url, flow_id, o, user, collection, label_cache, debug) + + return s_label, p_label, o_label + + +async def _question_explainable( + url, flow_id, question, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length, token=None, debug=False +): + """Execute graph RAG with explainability - shows provenance events with details""" + # Convert HTTP URL to WebSocket URL + if url.startswith("http://"): + ws_url = url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + ws_url = url.replace("https://", "wss://", 1) + else: + ws_url = f"ws://{url}" + + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + if token: + ws_url = f"{ws_url}?token={token}" + + # Cache for label lookups to avoid repeated queries + label_cache = {} + + request = { + "id": "cli-request", + "service": "graph-rag", + "flow": flow_id, + "request": { + "query": question, + "user": user, + "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, + "max-subgraph-size": max_subgraph_size, + "max-path-length": max_path_length, + "streaming": True + } + } + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: + await websocket.send(json.dumps(request)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != "cli-request": + continue + + if "error" in response: + print(f"\nError: {response['error']}", file=sys.stderr) + break + + if "response" in response: + resp = response["response"] + + # Check for errors in response + if "error" in resp and resp["error"]: + err = resp["error"] + print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr) + break + + message_type = resp.get("message_type", "") + + if debug: + print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) + + if message_type == "explain": + # Display explain event with details + explain_id = resp.get("explain_id", "") + explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval) + if explain_id: + event_type = _get_event_type(explain_id) + print(f"\n [{event_type}] {explain_id}", file=sys.stderr) + + # Query triples for this explain node (using named graph filter) + triples = await _query_triples( + ws_url, flow_id, explain_id, user, collection, graph=explain_graph, debug=debug + ) + + # Format and display details + details = _format_provenance_details(event_type, triples) + for line in details: + print(line, file=sys.stderr) + + # For exploration events, resolve entity labels + if event_type == "exploration": + entity_iris = [o for s, p, o in triples if p == TG_ENTITY] + if entity_iris: + print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) + for iri in entity_iris: + label = await _query_label( + ws_url, flow_id, iri, user, collection, + label_cache, debug=debug + ) + print(f" - {label}", file=sys.stderr) + + # For focus events, query each edge selection for details + if event_type == "focus": + for s, p, o in triples: + if debug: + print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) + if p == TG_SELECTED_EDGE and isinstance(o, str): + if debug: + print(f" [debug] querying edge selection: {o}", file=sys.stderr) + # Query the edge selection entity (using named graph filter) + edge_triples = await _query_triples( + ws_url, flow_id, o, user, collection, graph=explain_graph, debug=debug + ) + if debug: + print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) + # Extract edge and reasoning + edge_triple = None # Store the actual triple for provenance lookup + reasoning = None + for es, ep, eo in edge_triples: + if debug: + print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) + if ep == TG_EDGE and isinstance(eo, dict): + # eo is a quoted triple dict + edge_triple = eo + elif ep == TG_REASONING: + reasoning = eo + if edge_triple: + # Resolve labels for edge components + s_label, p_label, o_label = await _resolve_edge_labels( + ws_url, flow_id, edge_triple, user, collection, + label_cache, debug=debug + ) + print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) + if reasoning: + r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning + print(f" Reason: {r_short}", file=sys.stderr) + + # Trace edge provenance in the user's collection (not explainability) + if edge_triple: + sources = await _query_edge_provenance( + ws_url, flow_id, + edge_triple.get("s", ""), + edge_triple.get("p", ""), + edge_triple.get("o", ""), + user, collection, # Use the query collection, not explainability + debug=debug + ) + if sources: + for src in sources: + # Trace full chain from source to root document + chain = await _trace_provenance_chain( + ws_url, flow_id, src, user, collection, + label_cache, debug=debug + ) + chain_str = _format_provenance_chain(chain) + print(f" Source: {chain_str}", file=sys.stderr) + + elif message_type == "chunk" or not message_type: + # Display response chunk + chunk = resp.get("response", "") + if chunk: + print(chunk, end="", flush=True) + + # Check if session is complete + if resp.get("end_of_session"): + break + + print() # Final newline + + +def _question_explainable_api( + url, flow_id, question_text, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length, token=None, debug=False +): + """Execute graph RAG with explainability using the new API classes.""" + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) + + try: + # Stream GraphRAG with explainability - process events as they arrive + for item in flow.graph_rag_explain( + query=question_text, + user=user, + collection=collection, + max_subgraph_size=max_subgraph_size, + max_subgraph_count=5, + max_entity_distance=max_path_length, + ): + if isinstance(item, RAGChunk): + # Print response content + print(item.content, end="", flush=True) + + elif isinstance(item, ProvenanceEvent): + # Process provenance event immediately + prov_id = item.explain_id + explain_graph = item.explain_graph or "urn:graph:retrieval" + + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + + if entity is None: + if debug: + print(f"\n [warning] Could not fetch entity: {prov_id}", file=sys.stderr) + continue + + # Display based on entity type + if isinstance(entity, Question): + print(f"\n [question] {prov_id}", file=sys.stderr) + if entity.query: + print(f" Query: {entity.query}", file=sys.stderr) + if entity.timestamp: + print(f" Time: {entity.timestamp}", file=sys.stderr) + + elif isinstance(entity, Grounding): + print(f"\n [grounding] {prov_id}", file=sys.stderr) + if entity.concepts: + print(f" Concepts: {len(entity.concepts)}", file=sys.stderr) + for concept in entity.concepts: + print(f" - {concept}", file=sys.stderr) + + elif isinstance(entity, Exploration): + print(f"\n [exploration] {prov_id}", file=sys.stderr) + if entity.edge_count: + print(f" Edges explored: {entity.edge_count}", file=sys.stderr) + if entity.entities: + print(f" Seed entities: {len(entity.entities)}", file=sys.stderr) + for ent in entity.entities: + label = explain_client.resolve_label(ent, user, collection) + print(f" - {label}", file=sys.stderr) + + elif isinstance(entity, Focus): + print(f"\n [focus] {prov_id}", file=sys.stderr) + if entity.selected_edge_uris: + print(f" Focused on {len(entity.selected_edge_uris)} edge(s)", file=sys.stderr) + + # Fetch full focus with edge details + focus_full = explain_client.fetch_focus_with_edges( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) + if focus_full and focus_full.edge_selections: + for edge_sel in focus_full.edge_selections: + if edge_sel.edge: + # Resolve labels for edge components + s_label, p_label, o_label = explain_client.resolve_edge_labels( + edge_sel.edge, user, collection + ) + print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) + if edge_sel.reasoning: + r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning + print(f" Reason: {r_short}", file=sys.stderr) + + elif isinstance(entity, Synthesis): + print(f"\n [synthesis] {prov_id}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + + else: + if debug: + print(f"\n [unknown] {prov_id} (type: {entity.entity_type})", file=sys.stderr) + + print() # Final newline + + finally: + socket.close() + + def question( url, flow_id, question, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, streaming=True, token=None + max_subgraph_size, max_path_length, streaming=True, token=None, + explainable=False, debug=False ): + # Explainable mode uses the API to capture and process provenance events + if explainable: + _question_explainable_api( + url=url, + flow_id=flow_id, + question_text=question, + user=user, + collection=collection, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, + token=token, + debug=debug + ) + return + # Create API client api = Api(url=url, token=token) @@ -138,6 +882,18 @@ def main(): help='Disable streaming (use non-streaming mode)' ) + parser.add_argument( + '-x', '--explainable', + action='store_true', + help='Show provenance events: Question, Grounding, Exploration, Focus, Synthesis (implies streaming)' + ) + + parser.add_argument( + '--debug', + action='store_true', + help='Show debug output for troubleshooting' + ) + args = parser.parse_args() try: @@ -154,6 +910,8 @@ def main(): max_path_length=args.max_path_length, streaming=not args.no_streaming, token=args.token, + explainable=args.explainable, + debug=args.debug, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py new file mode 100644 index 00000000..f545c53f --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -0,0 +1,167 @@ +""" +List all explainability sessions (GraphRAG and Agent) in a collection. + +Queries for all questions stored in the retrieval graph and displays them +with their session IDs, type (GraphRAG or Agent), and timestamps. + +Examples: + tg-list-explain-traces -U trustgraph -C default + tg-list-explain-traces --limit 20 --format json +""" + +import argparse +import json +import os +import sys +from tabulate import tabulate +from trustgraph.api import Api, ExplainabilityClient + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_user = 'trustgraph' +default_collection = 'default' + +# Retrieval graph +RETRIEVAL_GRAPH = "urn:graph:retrieval" + + +def truncate_text(text, max_len=60): + """Truncate text to max length with ellipsis.""" + if not text: + return "" + if len(text) <= max_len: + return text + return text[:max_len - 3] + "..." + + +def print_table(sessions): + """Print sessions as a table.""" + if not sessions: + print("No explainability sessions found.") + return + + rows = [] + for session in sessions: + rows.append([ + session["id"], + session.get("type", "Unknown"), + truncate_text(session["question"], 45), + session.get("time", "") + ]) + + headers = ["Session ID", "Type", "Question", "Time"] + print(tabulate(rows, headers=headers, tablefmt="simple")) + + +def print_json(sessions): + """Print sessions as JSON.""" + print(json.dumps(sessions, indent=2)) + + +def main(): + parser = argparse.ArgumentParser( + prog='tg-list-explain-traces', + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Auth token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})', + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection (default: {default_collection})', + ) + + parser.add_argument( + '-f', '--flow-id', + default='default', + help='Flow ID (default: default)', + ) + + parser.add_argument( + '--limit', + type=int, + default=50, + help='Max results (default: 50)', + ) + + parser.add_argument( + '--format', + choices=['table', 'json'], + default='table', + help='Output format: table (default), json', + ) + + args = parser.parse_args() + + try: + api = Api(args.api_url, token=args.token) + socket = api.socket() + flow = socket.flow(args.flow_id) + explain_client = ExplainabilityClient(flow) + + try: + # List all sessions using the API + questions = explain_client.list_sessions( + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + limit=args.limit, + ) + + # Convert to output format + sessions = [] + for q in questions: + session_type = explain_client.detect_session_type( + q.uri, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection + ) + + # Map type names + type_display = { + "graphrag": "GraphRAG", + "docrag": "DocRAG", + "agent": "Agent", + }.get(session_type, session_type.title()) + + sessions.append({ + "id": q.uri, + "type": type_display, + "question": q.query, + "time": q.timestamp, + }) + + if args.format == 'json': + print_json(sessions) + else: + print_table(sessions) + + finally: + socket.close() + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py index 7e7f4865..20c78515 100644 --- a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py @@ -44,14 +44,14 @@ async def load_de(running, queue, url): msg = { "metadata": { - "id": msg["m"]["i"], + "id": msg["m"]["i"], "metadata": msg["m"]["m"], "user": msg["m"]["u"], "collection": msg["m"]["c"], }, "chunks": [ { - "chunk": chunk["c"], + "chunk_id": chunk["c"], "vectors": chunk["v"], } for chunk in msg["c"] diff --git a/trustgraph-cli/trustgraph/cli/load_pdf.py b/trustgraph-cli/trustgraph/cli/load_pdf.py deleted file mode 100644 index d305cb4b..00000000 --- a/trustgraph-cli/trustgraph/cli/load_pdf.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Loads a PDF document into TrustGraph processing by directing to -the pdf-decoder queue. -Consider using tg-add-library-document to load -a document, followed by tg-start-library-processing to initiate processing. -""" - -import hashlib -import argparse -import os -import time -import uuid - -from trustgraph.api import Api -from trustgraph.knowledge import hash, to_uri -from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG -from trustgraph.knowledge import Organization, PublicationEvent -from trustgraph.knowledge import DigitalDocument - -default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' -default_collection = 'default' - -class Loader: - - def __init__( - self, - url, - flow_id, - user, - collection, - metadata, - ): - - self.api = Api(url).flow().id(flow_id) - - self.user = user - self.collection = collection - self.metadata = metadata - - def load(self, files): - - for file in files: - self.load_file(file) - - def load_file(self, file): - - try: - - path = file - data = open(path, "rb").read() - - # Create a SHA256 hash from the data - id = hash(data) - - id = to_uri(PREF_DOC, id) - - self.metadata.id = id - - self.api.load_document( - document=data, id=id, metadata=self.metadata, - user=self.user, - collection=self.collection, - ) - - print(f"{file}: Loaded successfully.") - - except Exception as e: - print(f"{file}: Failed: {str(e)}", flush=True) - raise e - -def main(): - - parser = argparse.ArgumentParser( - prog='tg-load-pdf', - description=__doc__, - ) - - parser.add_argument( - '-u', '--url', - default=default_url, - help=f'API URL (default: {default_url})', - ) - - parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' - ) - - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - - parser.add_argument( - '-C', '--collection', - default=default_collection, - help=f'Collection ID (default: {default_collection})' - ) - - parser.add_argument( - '--name', help=f'Document name' - ) - - parser.add_argument( - '--description', help=f'Document description' - ) - - parser.add_argument( - '--copyright-notice', help=f'Copyright notice' - ) - - parser.add_argument( - '--copyright-holder', help=f'Copyright holder' - ) - - parser.add_argument( - '--copyright-year', help=f'Copyright year' - ) - - parser.add_argument( - '--license', help=f'Copyright license' - ) - - parser.add_argument( - '--publication-organization', help=f'Publication organization' - ) - - parser.add_argument( - '--publication-description', help=f'Publication description' - ) - - parser.add_argument( - '--publication-date', help=f'Publication date' - ) - - parser.add_argument( - '--document-url', help=f'Document URL' - ) - - parser.add_argument( - '--keyword', nargs='+', help=f'Keyword' - ) - - parser.add_argument( - '--identifier', '--id', help=f'Document ID' - ) - - parser.add_argument( - 'files', nargs='+', - help=f'File to load' - ) - - args = parser.parse_args() - - try: - - document = DigitalDocument( - id, - name=args.name, - description=args.description, - copyright_notice=args.copyright_notice, - copyright_holder=args.copyright_holder, - copyright_year=args.copyright_year, - license=args.license, - url=args.document_url, - keywords=args.keyword, - ) - - if args.publication_organization: - org = Organization( - id=to_uri(PREF_ORG, hash(args.publication_organization)), - name=args.publication_organization, - ) - document.publication = PublicationEvent( - id = to_uri(PREF_PUBEV, str(uuid.uuid4())), - organization=org, - description=args.publication_description, - start_date=args.publication_date, - end_date=args.publication_date, - ) - - p = Loader( - url=args.url, - flow_id = args.flow_id, - user=args.user, - collection=args.collection, - metadata=document, - ) - - p.load(args.files) - - except Exception as e: - - print("Exception:", e, flush=True) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/trustgraph-cli/trustgraph/cli/load_text.py b/trustgraph-cli/trustgraph/cli/load_text.py deleted file mode 100644 index 594d1c04..00000000 --- a/trustgraph-cli/trustgraph/cli/load_text.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Loads a text document into TrustGraph processing by directing to a text -loader queue. -Consider using tg-add-library-document to load -a document, followed by tg-start-library-processing to initiate processing. -""" - -import pulsar -from pulsar.schema import JsonSchema -import hashlib -import argparse -import os -import time -import uuid - -from trustgraph.api import Api -from trustgraph.knowledge import hash, to_uri -from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG -from trustgraph.knowledge import Organization, PublicationEvent -from trustgraph.knowledge import DigitalDocument - -default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' -default_collection = 'default' - -class Loader: - - def __init__( - self, - url, - flow_id, - user, - collection, - metadata, - ): - - self.api = Api(url).flow().id(flow_id) - - self.user = user - self.collection = collection - self.metadata = metadata - - def load(self, files): - - for file in files: - self.load_file(file) - - def load_file(self, file): - - try: - - path = file - data = open(path, "rb").read() - - # Create a SHA256 hash from the data - id = hash(data) - - id = to_uri(PREF_DOC, id) - - self.metadata.id = id - - self.api.load_text( - text=data, id=id, metadata=self.metadata, - user=self.user, - collection=self.collection, - ) - - print(f"{file}: Loaded successfully.") - - except Exception as e: - print(f"{file}: Failed: {str(e)}", flush=True) - raise e - -def main(): - - parser = argparse.ArgumentParser( - prog='tg-load-text', - description=__doc__, - ) - - parser.add_argument( - '-u', '--url', - default=default_url, - help=f'API URL (default: {default_url})', - ) - - parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' - ) - - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - - parser.add_argument( - '-C', '--collection', - default=default_collection, - help=f'Collection ID (default: {default_collection})' - ) - - parser.add_argument( - '--name', help=f'Document name' - ) - - parser.add_argument( - '--description', help=f'Document description' - ) - - parser.add_argument( - '--copyright-notice', help=f'Copyright notice' - ) - - parser.add_argument( - '--copyright-holder', help=f'Copyright holder' - ) - - parser.add_argument( - '--copyright-year', help=f'Copyright year' - ) - - parser.add_argument( - '--license', help=f'Copyright license' - ) - - parser.add_argument( - '--publication-organization', help=f'Publication organization' - ) - - parser.add_argument( - '--publication-description', help=f'Publication description' - ) - - parser.add_argument( - '--publication-date', help=f'Publication date' - ) - - parser.add_argument( - '--document-url', help=f'Document URL' - ) - - parser.add_argument( - '--keyword', nargs='+', help=f'Keyword' - ) - - parser.add_argument( - '--identifier', '--id', help=f'Document ID' - ) - - parser.add_argument( - 'files', nargs='+', - help=f'File to load' - ) - - args = parser.parse_args() - - - try: - - document = DigitalDocument( - id, - name=args.name, - description=args.description, - copyright_notice=args.copyright_notice, - copyright_holder=args.copyright_holder, - copyright_year=args.copyright_year, - license=args.license, - url=args.document_url, - keywords=args.keyword, - ) - - if args.publication_organization: - org = Organization( - id=to_uri(PREF_ORG, hash(args.publication_organization)), - name=args.publication_organization, - ) - document.publication = PublicationEvent( - id = to_uri(PREF_PUBEV, str(uuid.uuid4())), - organization=org, - description=args.publication_description, - start_date=args.publication_date, - end_date=args.publication_date, - ) - - p = Loader( - url = args.url, - flow_id = args.flow_id, - user = args.user, - collection = args.collection, - metadata = document, - ) - - p.load(args.files) - - print("All done.") - - except Exception as e: - - print("Exception:", e, flush=True) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/trustgraph-cli/trustgraph/cli/query_graph.py b/trustgraph-cli/trustgraph/cli/query_graph.py new file mode 100644 index 00000000..a2c38353 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/query_graph.py @@ -0,0 +1,578 @@ +""" +Query the triple store with pattern matching and configurable output formats. + +Unlike tg-show-graph which dumps the entire graph, this tool enables selective +queries by specifying any combination of subject, predicate, object, and graph. + +Auto-detection rules for values: + - Starts with http://, https://, urn:, or wrapped in <> -> IRI + - Starts with << -> quoted triple (Turtle-style) + - Anything else -> literal + +Examples: + tg-query-graph -s "http://example.org/entity" + tg-query-graph -p "http://www.w3.org/2000/01/rdf-schema#label" + tg-query-graph -o "Marie Curie" --object-language en + tg-query-graph -o "<>" +""" + +import argparse +import json +import os +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + + +def parse_inline_quoted_triple(value): + """Parse inline Turtle-style quoted triple: <> + + Args: + value: String in format "<>" + + Returns: + dict: Wire-format quoted triple term, or None if parsing fails + """ + # Strip << and >> markers + inner = value[2:-2].strip() + + # Split on whitespace, but respect quoted strings + # Simple approach: split and handle common cases + parts = [] + current = "" + in_quotes = False + quote_char = None + + for char in inner: + if char in ('"', "'") and not in_quotes: + in_quotes = True + quote_char = char + current += char + elif char == quote_char and in_quotes: + in_quotes = False + quote_char = None + current += char + elif char.isspace() and not in_quotes: + if current: + parts.append(current) + current = "" + else: + current += char + + if current: + parts.append(current) + + if len(parts) != 3: + raise ValueError( + f"Quoted triple must have exactly 3 parts (s p o), got {len(parts)}: {parts}" + ) + + s_val, p_val, o_val = parts + + # Build the inner triple terms + s_term = build_term(s_val) + p_term = build_term(p_val) + o_term = build_term(o_val) + + return { + "t": "t", + "tr": { + "s": s_term, + "p": p_term, + "o": o_term + } + } + + +def build_term(value, term_type=None, datatype=None, language=None): + """Build wire-format Term dict from CLI input. + + Auto-detection rules (when term_type is None): + - Starts with http://, https://, urn: -> IRI + - Wrapped in <> (e.g., ) -> IRI (angle brackets stripped) + - Starts with << and ends with >> -> quoted triple + - Anything else -> literal + + Args: + value: The term value + term_type: One of 'iri', 'literal', 'triple', or None for auto-detect + datatype: Datatype for literal objects (e.g., xsd:integer) + language: Language tag for literal objects (e.g., en) + + Returns: + dict: Wire-format Term dict, or None if value is None + """ + if value is None: + return None + + # Auto-detect type if not specified + if term_type is None: + if value.startswith("<<") and value.endswith(">>"): + term_type = "triple" + elif value.startswith("<") and value.endswith(">") and not value.startswith("<<"): + # Angle-bracket wrapped IRI: + value = value[1:-1] # Strip < and > + term_type = "iri" + elif value.startswith(("http://", "https://", "urn:")): + term_type = "iri" + else: + term_type = "literal" + + if term_type == "iri": + # Strip angle brackets if present + if value.startswith("<") and value.endswith(">"): + value = value[1:-1] + return {"t": "i", "i": value} + elif term_type == "literal": + result = {"t": "l", "v": value} + if datatype: + result["dt"] = datatype + if language: + result["ln"] = language + return result + elif term_type == "triple": + # Check if it's inline Turtle-style + if value.startswith("<<") and value.endswith(">>"): + return parse_inline_quoted_triple(value) + else: + # Assume it's raw JSON (legacy support) + triple_data = json.loads(value) + return {"t": "t", "tr": triple_data} + else: + raise ValueError(f"Unknown term type: {term_type}") + + +def build_quoted_triple_term(qt_subject, qt_subject_type, + qt_predicate, + qt_object, qt_object_type, + qt_object_datatype, qt_object_language): + """Build a quoted triple term from --qt-* arguments. + + Returns: + dict: Wire-format quoted triple term, or None if no qt args provided + """ + # Check if any qt args were provided + if not any([qt_subject, qt_predicate, qt_object]): + return None + + # Subject (IRI or nested triple) + s_term = build_term(qt_subject, term_type=qt_subject_type) + + # Predicate (always IRI) + p_term = build_term(qt_predicate, term_type='iri') + + # Object (IRI, literal, or nested triple) + o_term = build_term( + qt_object, + term_type=qt_object_type, + datatype=qt_object_datatype, + language=qt_object_language + ) + + return { + "t": "t", + "tr": { + "s": s_term, + "p": p_term, + "o": o_term + } + } + + +def format_term(term_dict): + """Format a term dict for display in space/pipe output formats. + + Handles multiple wire format styles: + - Short form (send): {"t": "i", "i": "..."}, {"t": "l", "v": "..."} + - Long form (receive): {"type": "i", "iri": "..."}, {"type": "l", "value": "..."} + - Raw quoted triple: {"s": {...}, "p": {...}, "o": {...}} (no type wrapper) + - Stringified quoted triple in IRI: {"t": "i", "i": "{\"s\":...}"} (backend quirk) + + Args: + term_dict: Wire-format term dict + + Returns: + str: Formatted string representation + """ + if not term_dict: + return "" + + # Get type - handle both short and long form + t = term_dict.get("t") or term_dict.get("type") + + if t == "i": + # IRI - handle both "i" and "iri" keys + iri_value = term_dict.get("i") or term_dict.get("iri", "") + # Check if IRI value is actually a stringified quoted triple (backend quirk) + if iri_value.startswith('{"s":') or iri_value.startswith("{\"s\":"): + try: + parsed = json.loads(iri_value) + if "s" in parsed and "p" in parsed and "o" in parsed: + # It's a stringified quoted triple - format it properly + s = format_term(parsed.get("s", {})) + p = format_term(parsed.get("p", {})) + o = format_term(parsed.get("o", {})) + return f"<<{s} {p} {o}>>" + except json.JSONDecodeError: + pass # Not valid JSON, treat as regular IRI + return iri_value + elif t == "l": + # Literal - handle both short and long form keys + value = term_dict.get("v") or term_dict.get("value", "") + result = f'"{value}"' + # Language tag + lang = term_dict.get("ln") or term_dict.get("language") + if lang: + result += f'@{lang}' + else: + # Datatype + dt = term_dict.get("dt") or term_dict.get("datatype") + if dt: + result += f'^^{dt}' + return result + elif t == "t": + # Quoted triple - handle both "tr" and "triple" keys + tr = term_dict.get("tr") or term_dict.get("triple", {}) + s = format_term(tr.get("s", {})) + p = format_term(tr.get("p", {})) + o = format_term(tr.get("o", {})) + return f"<<{s} {p} {o}>>" + elif t is None and "s" in term_dict and "p" in term_dict and "o" in term_dict: + # Raw quoted triple without type wrapper (has s, p, o keys directly) + s = format_term(term_dict.get("s", {})) + p = format_term(term_dict.get("p", {})) + o = format_term(term_dict.get("o", {})) + return f"<<{s} {p} {o}>>" + + return str(term_dict) + + +def output_space(triples, headers=False): + """Output triples in space-separated format.""" + if headers: + print("subject predicate object") + for triple in triples: + s = format_term(triple.get("s", {})) + p = format_term(triple.get("p", {})) + o = format_term(triple.get("o", {})) + print(s, p, o) + + +def output_pipe(triples, headers=False): + """Output triples in pipe-separated format.""" + if headers: + print("subject|predicate|object") + for triple in triples: + s = format_term(triple.get("s", {})) + p = format_term(triple.get("p", {})) + o = format_term(triple.get("o", {})) + print(f"{s}|{p}|{o}") + + +def output_json(triples): + """Output triples as a JSON array.""" + print(json.dumps(triples, indent=2)) + + +def output_jsonl(triples): + """Output triples as JSON Lines (one object per line).""" + for triple in triples: + print(json.dumps(triple)) + + +def query_graph( + url, flow_id, user, collection, limit, batch_size, + subject=None, predicate=None, obj=None, graph=None, + output_format="space", headers=False, token=None +): + """Query the triple store with pattern matching. + + Uses the API's triples_query_stream for efficient streaming delivery. + """ + socket = Api(url, token=token).socket() + flow = socket.flow(flow_id) + + all_triples = [] + + try: + # Use triples_query_stream - accepts Term dicts directly + for triples in flow.triples_query_stream( + s=subject, + p=predicate, + o=obj, + g=graph, + user=user, + collection=collection, + limit=limit, + batch_size=batch_size, + ): + if not isinstance(triples, list): + triples = [triples] if triples else [] + + if output_format in ("json",): + # Collect all triples for JSON array output + all_triples.extend(triples) + else: + # Stream output for other formats + if output_format == "space": + output_space(triples, headers=headers and not all_triples) + elif output_format == "pipe": + output_pipe(triples, headers=headers and not all_triples) + elif output_format == "jsonl": + output_jsonl(triples) + # Track that we've output something (for headers logic) + all_triples.extend([None] * len(triples)) + + # Output collected JSON array + if output_format == "json": + output_json(all_triples) + + finally: + socket.close() + + +def main(): + parser = argparse.ArgumentParser( + prog='tg-query-graph', + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Outer triple filters + outer_group = parser.add_argument_group('Outer triple filters') + + outer_group.add_argument( + '-s', '--subject', + metavar='VALUE', + help='Subject filter (auto-detected as IRI or literal)', + ) + + outer_group.add_argument( + '-p', '--predicate', + metavar='VALUE', + help='Predicate filter (auto-detected as IRI)', + ) + + outer_group.add_argument( + '-o', '--object', + dest='obj', + metavar='VALUE', + help='Object filter (IRI, literal, or <>)', + ) + + outer_group.add_argument( + '--object-type', + choices=['iri', 'literal', 'triple'], + metavar='TYPE', + help='Override object type detection: iri, literal, triple', + ) + + outer_group.add_argument( + '--object-datatype', + metavar='DATATYPE', + help='Datatype for literal object (e.g., xsd:integer)', + ) + + outer_group.add_argument( + '--object-language', + metavar='LANG', + help='Language tag for literal object (e.g., en)', + ) + + outer_group.add_argument( + '-g', '--graph', + metavar='VALUE', + help='Named graph filter', + ) + + # Quoted triple filters (alternative to inline <> syntax) + qt_group = parser.add_argument_group( + 'Quoted triple filters', + 'Build object as quoted triple using explicit fields (alternative to -o "<>")' + ) + + qt_group.add_argument( + '--qt-subject', + metavar='VALUE', + help='Quoted triple subject', + ) + + qt_group.add_argument( + '--qt-subject-type', + choices=['iri', 'triple'], + metavar='TYPE', + help='Override qt-subject type: iri, triple', + ) + + qt_group.add_argument( + '--qt-predicate', + metavar='VALUE', + help='Quoted triple predicate (always IRI)', + ) + + qt_group.add_argument( + '--qt-object', + metavar='VALUE', + help='Quoted triple object', + ) + + qt_group.add_argument( + '--qt-object-type', + choices=['iri', 'literal', 'triple'], + metavar='TYPE', + help='Override qt-object type: iri, literal, triple', + ) + + qt_group.add_argument( + '--qt-object-datatype', + metavar='DATATYPE', + help='Datatype for qt-object literal', + ) + + qt_group.add_argument( + '--qt-object-language', + metavar='LANG', + help='Language tag for qt-object literal', + ) + + # Standard parameters + std_group = parser.add_argument_group('Standard parameters') + + std_group.add_argument( + '-u', '--api-url', + default=default_url, + metavar='URL', + help=f'API URL (default: {default_url})', + ) + + std_group.add_argument( + '-f', '--flow-id', + default="default", + metavar='ID', + help='Flow ID (default: default)' + ) + + std_group.add_argument( + '-U', '--user', + default=default_user, + metavar='USER', + help=f'User/keyspace (default: {default_user})' + ) + + std_group.add_argument( + '-C', '--collection', + default=default_collection, + metavar='COLL', + help=f'Collection (default: {default_collection})' + ) + + std_group.add_argument( + '-t', '--token', + default=default_token, + metavar='TOKEN', + help='Auth token (default: $TRUSTGRAPH_TOKEN)', + ) + + std_group.add_argument( + '-l', '--limit', + type=int, + default=1000, + metavar='N', + help='Max results (default: 1000)', + ) + + std_group.add_argument( + '-b', '--batch-size', + type=int, + default=20, + metavar='N', + help='Streaming batch size (default: 20)', + ) + + # Output options + out_group = parser.add_argument_group('Output options') + + out_group.add_argument( + '--format', + choices=['space', 'pipe', 'json', 'jsonl'], + default='space', + metavar='FORMAT', + help='Output format: space, pipe, json, jsonl (default: space)', + ) + + out_group.add_argument( + '-H', '--headers', + action='store_true', + help='Show column headers (for space/pipe formats)', + ) + + args = parser.parse_args() + + try: + # Build term dicts from CLI arguments + subject_term = build_term(args.subject) if args.subject else None + predicate_term = build_term(args.predicate) if args.predicate else None + + # Check for --qt-* args to build quoted triple as object + qt_term = build_quoted_triple_term( + qt_subject=args.qt_subject, + qt_subject_type=args.qt_subject_type, + qt_predicate=args.qt_predicate, + qt_object=args.qt_object, + qt_object_type=args.qt_object_type, + qt_object_datatype=args.qt_object_datatype, + qt_object_language=args.qt_object_language, + ) + + # Object: use --qt-* args if provided, otherwise use -o + if qt_term is not None: + if args.obj: + parser.error("Cannot use both -o/--object and --qt-* arguments") + obj_term = qt_term + elif args.obj: + obj_term = build_term( + args.obj, + term_type=args.object_type, + datatype=args.object_datatype, + language=args.object_language + ) + else: + obj_term = None + + # Graph is a plain IRI string, not a Term + # None = all graphs, "" = default graph only, "uri" = specific graph + graph_value = args.graph + + query_graph( + url=args.api_url, + flow_id=args.flow_id, + user=args.user, + collection=args.collection, + limit=args.limit, + batch_size=args.batch_size, + subject=subject_term, + predicate=predicate_term, + obj=obj_term, + graph=graph_value, + output_format=args.format, + headers=args.headers, + token=args.token, + ) + + except json.JSONDecodeError as e: + print(f"Error parsing JSON: {e}", file=sys.stderr) + sys.exit(1) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"Exception: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py index 8fdd335d..ca8d25de 100644 --- a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py @@ -50,14 +50,14 @@ async def fetch_de(running, queue, user, collection, url): "de", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "c": [ { - "c": chunk["chunk"], + "c": chunk["chunk_id"], "v": chunk["vectors"], } for chunk in data["chunks"] diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py new file mode 100644 index 00000000..ba4f9c25 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -0,0 +1,582 @@ +""" +Show full explainability trace for a GraphRAG or Agent session. + +Given a question/session URI, displays the complete trace: +- GraphRAG: Question -> Exploration -> Focus (edge selection) -> Synthesis (answer) +- Agent: Session -> Iteration(s) (thought/action/observation) -> Final Answer + +The tool auto-detects the trace type based on rdf:type. + +Examples: + tg-show-explain-trace -U trustgraph -C default "urn:trustgraph:question:abc123" + tg-show-explain-trace -U trustgraph -C default "urn:trustgraph:agent:abc123" + tg-show-explain-trace --max-answer 1000 "urn:trustgraph:question:abc123" + tg-show-explain-trace --show-provenance "urn:trustgraph:question:abc123" +""" + +import argparse +import json +import os +import sys +from trustgraph.api import ( + Api, + ExplainabilityClient, + Question, + Exploration, + Focus, + Synthesis, + Analysis, + Conclusion, +) + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_user = 'trustgraph' +default_collection = 'default' + +# Graphs +RETRIEVAL_GRAPH = "urn:graph:retrieval" +SOURCE_GRAPH = "urn:graph:source" + +# Provenance predicates for edge tracing +TG = "https://trustgraph.ai/ns/" +TG_CONTAINS = TG + "contains" +PROV = "http://www.w3.org/ns/prov#" +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" + + +def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client): + """ + Trace an edge back to its source document via reification. + + Args: + flow: SocketFlowInstance + user: User identifier + collection: Collection identifier + edge: Dict with s, p, o keys + label_cache: Dict for caching labels + explain_client: ExplainabilityClient for label resolution + + Returns: + List of provenance chains, each chain is list of {uri, label} + """ + edge_s = edge.get("s", "") + edge_p = edge.get("p", "") + edge_o = edge.get("o", "") + + # Build quoted triple for lookup + def build_term(val): + if isinstance(val, str) and (val.startswith("http") or val.startswith("urn:")): + return {"t": "i", "i": val} + return {"t": "l", "v": str(val)} + + quoted_triple = { + "t": "t", + "tr": { + "s": build_term(edge_s), + "p": build_term(edge_p), + "o": build_term(edge_o), + } + } + + # Query: ?subgraph tg:contains <> + try: + results = flow.triples_query( + p=TG_CONTAINS, + o=quoted_triple, + g=SOURCE_GRAPH, + user=user, + collection=collection, + limit=10 + ) + except Exception: + return [] + + # Extract statement URIs + stmt_uris = [] + for t in results: + s_term = t.get("s", {}) + s_val = s_term.get("i") or s_term.get("v", "") + if s_val: + stmt_uris.append(s_val) + + # For each statement, trace wasDerivedFrom chain + provenance_chains = [] + for stmt_uri in stmt_uris: + chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client) + if chain: + provenance_chains.append(chain) + + return provenance_chains + + +def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10): + """Trace prov:wasDerivedFrom chain from start_uri to root.""" + chain = [] + current = start_uri + + for _ in range(max_depth): + if not current: + break + + # Get label + if current in label_cache: + label = label_cache[current] + else: + label = explain_client.resolve_label(current, user, collection) + label_cache[current] = label + + chain.append({"uri": current, "label": label}) + + # Get parent via wasDerivedFrom + try: + results = flow.triples_query( + s=current, + p=PROV_WAS_DERIVED_FROM, + g=SOURCE_GRAPH, + user=user, + collection=collection, + limit=1 + ) + except Exception: + break + + parent = None + for t in results: + o_term = t.get("o", {}) + parent = o_term.get("i") or o_term.get("v", "") + break + + if not parent or parent == current: + break + current = parent + + return chain + + +def format_provenance_chain(chain): + """Format a provenance chain for display.""" + if not chain: + return "" + labels = [item.get("label", item.get("uri", "?")) for item in chain] + return " -> ".join(labels) + + +def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, show_provenance=False): + """Print GraphRAG trace in text format.""" + question = trace.get("question") + + print(f"=== GraphRAG Session: {question.uri if question else 'Unknown'} ===") + print() + + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") + print() + + # Exploration + print("--- Exploration ---") + exploration = trace.get("exploration") + if exploration: + print(f"Retrieved {exploration.edge_count} edges from knowledge graph") + else: + print("No exploration data found") + print() + + # Focus + print("--- Focus (Edge Selection) ---") + focus = trace.get("focus") + if focus: + edges = focus.edge_selections + print(f"Selected {len(edges)} edges:") + print() + + label_cache = {} + + for i, edge_sel in enumerate(edges, 1): + if edge_sel.edge: + s_label, p_label, o_label = explain_client.resolve_edge_labels( + edge_sel.edge, user, collection + ) + print(f" {i}. ({s_label}, {p_label}, {o_label})") + + if edge_sel.reasoning: + r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning + print(f" Reasoning: {r_short}") + + if show_provenance and edge_sel.edge: + provenance = trace_edge_provenance( + flow, user, collection, edge_sel.edge, + label_cache, explain_client + ) + for chain in provenance: + chain_str = format_provenance_chain(chain) + if chain_str: + print(f" Source: {chain_str}") + + print() + else: + print("No focus data found") + print() + + # Synthesis + print("--- Synthesis ---") + synthesis = trace.get("synthesis") + if synthesis: + content = "" + if synthesis.document and api: + content = explain_client.fetch_document_content( + synthesis.document, api, user + ) + if content: + print("Answer:") + for line in content.split("\n"): + print(f" {line}") + elif synthesis.document: + print(f"Document: {synthesis.document}") + else: + print("No answer content found") + else: + print("No synthesis data found") + + +def print_docrag_text(trace, explain_client, api, user): + """Print DocRAG trace in text format.""" + question = trace.get("question") + + print(f"=== DocRAG Session: {question.uri if question else 'Unknown'} ===") + print() + + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") + print() + + # Grounding + grounding = trace.get("grounding") + if grounding: + print("--- Grounding ---") + print(f"Concepts: {', '.join(grounding.concepts)}") + print() + + # Exploration + print("--- Exploration ---") + exploration = trace.get("exploration") + if exploration: + print(f"Retrieved {exploration.chunk_count} chunks from document store") + else: + print("No exploration data found") + print() + + # Synthesis (no Focus step for DocRAG) + print("--- Synthesis ---") + synthesis = trace.get("synthesis") + if synthesis: + content = "" + if synthesis.document and api: + content = explain_client.fetch_document_content( + synthesis.document, api, user + ) + if content: + print("Answer:") + for line in content.split("\n"): + print(f" {line}") + elif synthesis.document: + print(f"Document: {synthesis.document}") + else: + print("No answer content found") + else: + print("No synthesis data found") + + +def print_agent_text(trace, explain_client, api, user): + """Print Agent trace in text format.""" + question = trace.get("question") + + print(f"=== Agent Session: {question.uri if question else 'Unknown'} ===") + print() + + if question: + print(f"Question: {question.query}") + if question.timestamp: + print(f"Time: {question.timestamp}") + print() + + # Analysis steps + print("--- Analysis ---") + iterations = trace.get("iterations", []) + if iterations: + for i, analysis in enumerate(iterations, 1): + print(f"Analysis {i}:") + print(f" Thought: {analysis.thought or 'N/A'}") + print(f" Action: {analysis.action or 'N/A'}") + + + if analysis.arguments: + # Try to pretty-print JSON arguments + try: + args_obj = json.loads(analysis.arguments) + args_str = json.dumps(args_obj, indent=4) + print(f" Arguments:") + for line in args_str.split('\n'): + print(f" {line}") + except Exception: + print(f" Arguments: {analysis.arguments}") + else: + print(f" Arguments: N/A") + + obs = analysis.observation or 'N/A' + if obs and len(obs) > 200: + obs = obs[:200] + "... [truncated]" + print(f" Observation: {obs}") + print() + else: + print("No analysis steps recorded") + print() + + # Conclusion + print("--- Conclusion ---") + conclusion = trace.get("conclusion") + if conclusion: + content = "" + if conclusion.document and api: + content = explain_client.fetch_document_content( + conclusion.document, api, user + ) + if content: + print("Answer:") + for line in content.split("\n"): + print(f" {line}") + elif conclusion.document: + print(f"Document: {conclusion.document}") + else: + print("No conclusion recorded") + else: + print("No conclusion recorded") + + +def trace_to_dict(trace, trace_type): + """Convert trace entities to JSON-serializable dict.""" + if trace_type == "agent": + question = trace.get("question") + return { + "type": "agent", + "session_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "iterations": [ + { + "id": a.uri, + "thought": a.thought, + "action": a.action, + "arguments": a.arguments, + "observation": a.observation, + } + for a in trace.get("iterations", []) + ], + "conclusion": { + "id": trace["conclusion"].uri, + "document": trace["conclusion"].document, + } if trace.get("conclusion") else None, + } + elif trace_type == "docrag": + question = trace.get("question") + grounding = trace.get("grounding") + exploration = trace.get("exploration") + synthesis = trace.get("synthesis") + + return { + "type": "docrag", + "question_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "grounding": { + "id": grounding.uri, + "concepts": grounding.concepts, + } if grounding else None, + "exploration": { + "id": exploration.uri, + "chunk_count": exploration.chunk_count, + } if exploration else None, + "synthesis": { + "id": synthesis.uri, + "document": synthesis.document, + } if synthesis else None, + } + else: + # graphrag + question = trace.get("question") + exploration = trace.get("exploration") + focus = trace.get("focus") + synthesis = trace.get("synthesis") + + return { + "type": "graphrag", + "question_id": question.uri if question else None, + "question": question.query if question else None, + "time": question.timestamp if question else None, + "exploration": { + "id": exploration.uri, + "edge_count": exploration.edge_count, + } if exploration else None, + "focus": { + "id": focus.uri, + "selected_edges": [ + { + "edge": edge_sel.edge, + "reasoning": edge_sel.reasoning, + } + for edge_sel in focus.edge_selections + ], + } if focus else None, + "synthesis": { + "id": synthesis.uri, + "document": synthesis.document, + } if synthesis else None, + } + + +def main(): + parser = argparse.ArgumentParser( + prog='tg-show-explain-trace', + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + 'question_id', + help='Question/session URI to show trace for', + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Auth token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})', + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection (default: {default_collection})', + ) + + parser.add_argument( + '-f', '--flow-id', + default='default', + help='Flow ID (default: default)', + ) + + parser.add_argument( + '--max-answer', + type=int, + default=500, + help='Max chars for answer display (default: 500)', + ) + + parser.add_argument( + '--show-provenance', + action='store_true', + help='Also trace edges back to source documents', + ) + + parser.add_argument( + '--format', + choices=['text', 'json'], + default='text', + help='Output format: text (default), json', + ) + + args = parser.parse_args() + + try: + api = Api(args.api_url, token=args.token) + socket = api.socket() + flow = socket.flow(args.flow_id) + explain_client = ExplainabilityClient(flow) + + try: + # Detect trace type + trace_type = explain_client.detect_session_type( + args.question_id, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + ) + + if trace_type == "agent": + # Fetch and display agent trace + trace = explain_client.fetch_agent_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + api=api, + max_content=args.max_answer, + ) + + if args.format == 'json': + print(json.dumps(trace_to_dict(trace, "agent"), indent=2)) + else: + print_agent_text(trace, explain_client, api, args.user) + + elif trace_type == "docrag": + # Fetch and display DocRAG trace + trace = explain_client.fetch_docrag_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + api=api, + max_content=args.max_answer, + ) + + if args.format == 'json': + print(json.dumps(trace_to_dict(trace, "docrag"), indent=2)) + else: + print_docrag_text(trace, explain_client, api, args.user) + + else: + # Fetch and display GraphRAG trace + trace = explain_client.fetch_graphrag_trace( + args.question_id, + graph=RETRIEVAL_GRAPH, + user=args.user, + collection=args.collection, + api=api, + max_content=args.max_answer, + ) + + if args.format == 'json': + print(json.dumps(trace_to_dict(trace, "graphrag"), indent=2)) + else: + print_graphrag_text( + trace, explain_client, flow, + args.user, args.collection, + api=api, + show_provenance=args.show_provenance + ) + + finally: + socket.close() + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py new file mode 100644 index 00000000..4f87712c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py @@ -0,0 +1,407 @@ +""" +Show extraction provenance: Document -> Pages -> Chunks -> Edges. + +Given a document ID, traverses and displays all derived entities +(pages, chunks, extracted edges) using prov:wasDerivedFrom relationships. + +Examples: + tg-show-extraction-provenance -U trustgraph -C default "urn:trustgraph:doc:abc123" + tg-show-extraction-provenance --show-content --max-content 500 "urn:trustgraph:doc:abc123" +""" + +import argparse +import json +import os +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_user = 'trustgraph' +default_collection = 'default' + +# Predicates +PROV_WAS_DERIVED_FROM = "http://www.w3.org/ns/prov#wasDerivedFrom" +RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" +RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" +TG = "https://trustgraph.ai/ns/" +TG_CONTAINS = TG + "contains" +TG_DOCUMENT_TYPE = TG + "Document" +TG_PAGE_TYPE = TG + "Page" +TG_CHUNK_TYPE = TG + "Chunk" +TG_SUBGRAPH_TYPE = TG + "Subgraph" +DC_TITLE = "http://purl.org/dc/terms/title" +DC_FORMAT = "http://purl.org/dc/terms/format" + +# Map TrustGraph type URIs to display names +TYPE_MAP = { + TG_DOCUMENT_TYPE: "document", + TG_PAGE_TYPE: "page", + TG_CHUNK_TYPE: "chunk", + TG_SUBGRAPH_TYPE: "subgraph", +} + +# Source graph +SOURCE_GRAPH = "urn:graph:source" + + +def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): + """Query triples using the socket API.""" + request = { + "user": user, + "collection": collection, + "limit": limit, + "streaming": False, + } + + if s is not None: + request["s"] = {"t": "i", "i": s} + if p is not None: + request["p"] = {"t": "i", "i": p} + if o is not None: + if isinstance(o, str): + if o.startswith("http://") or o.startswith("https://") or o.startswith("urn:"): + request["o"] = {"t": "i", "i": o} + else: + request["o"] = {"t": "l", "v": o} + elif isinstance(o, dict): + request["o"] = o + if g is not None: + request["g"] = g + + triples = [] + try: + for response in socket._send_request_sync("triples", flow_id, request, streaming_raw=True): + if isinstance(response, dict): + triple_list = response.get("response", response.get("triples", [])) + else: + triple_list = response + + if not isinstance(triple_list, list): + triple_list = [triple_list] if triple_list else [] + + for t in triple_list: + s_val = extract_value(t.get("s", {})) + p_val = extract_value(t.get("p", {})) + o_val = extract_value(t.get("o", {})) + triples.append((s_val, p_val, o_val)) + except Exception as e: + print(f"Error querying triples: {e}", file=sys.stderr) + + return triples + + +def extract_value(term): + """Extract value from a term dict.""" + if not term: + return "" + + t = term.get("t") or term.get("type") + + if t == "i": + return term.get("i") or term.get("iri", "") + elif t == "l": + return term.get("v") or term.get("value", "") + elif t == "t": + # Quoted triple + tr = term.get("tr") or term.get("triple", {}) + return { + "s": extract_value(tr.get("s", {})), + "p": extract_value(tr.get("p", {})), + "o": extract_value(tr.get("o", {})), + } + + # Fallback for raw values + if "i" in term: + return term["i"] + if "v" in term: + return term["v"] + + return str(term) + + +def get_node_metadata(socket, flow_id, user, collection, node_uri): + """Get metadata for a node (label, types, title, format).""" + triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=SOURCE_GRAPH) + + metadata = {"uri": node_uri, "types": []} + for s, p, o in triples: + if p == RDFS_LABEL: + metadata["label"] = o + elif p == RDF_TYPE: + metadata["types"].append(o) + elif p == DC_TITLE: + metadata["title"] = o + elif p == DC_FORMAT: + metadata["format"] = o + + return metadata + + +def classify_node(metadata): + """Classify a node based on its rdf:type values.""" + for type_uri in metadata.get("types", []): + if type_uri in TYPE_MAP: + return TYPE_MAP[type_uri] + return "unknown" + + +def get_children(socket, flow_id, user, collection, parent_uri): + """Get children of a node via prov:wasDerivedFrom.""" + triples = query_triples( + socket, flow_id, user, collection, + p=PROV_WAS_DERIVED_FROM, o=parent_uri, g=SOURCE_GRAPH + ) + return [s for s, p, o in triples] + + +def get_document_content(api, user, doc_id, max_content): + """Fetch document content from librarian API.""" + try: + library = api.library() + content = library.get_document_content(user=user, id=doc_id) + + # Try to decode as text + try: + text = content.decode('utf-8') + if len(text) > max_content: + return text[:max_content] + "... [truncated]" + return text + except UnicodeDecodeError: + return f"[Binary: {len(content)} bytes]" + except Exception as e: + return f"[Error fetching content: {e}]" + + +def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_content=False, max_content=200, visited=None): + """Build document hierarchy tree recursively.""" + if visited is None: + visited = set() + + if root_uri in visited: + return None + visited.add(root_uri) + + metadata = get_node_metadata(socket, flow_id, user, collection, root_uri) + node_type = classify_node(metadata) + + node = { + "uri": root_uri, + "type": node_type, + "metadata": metadata, + "children": [], + "edges": [], + } + + # Fetch content if requested + if show_content and api: + content = get_document_content(api, user, root_uri, max_content) + if content: + node["content"] = content + + # Get children + children_uris = get_children(socket, flow_id, user, collection, root_uri) + + for child_uri in children_uris: + child_metadata = get_node_metadata(socket, flow_id, user, collection, child_uri) + child_type = classify_node(child_metadata) + + if child_type == "subgraph": + # Subgraphs contain extracted edges — inline them + contains_triples = query_triples( + socket, flow_id, user, collection, + s=child_uri, p=TG_CONTAINS, g=SOURCE_GRAPH + ) + for _, _, edge in contains_triples: + if isinstance(edge, dict): + node["edges"].append(edge) + else: + # Recurse into pages, chunks, etc. + child_node = build_hierarchy( + socket, flow_id, user, collection, child_uri, + api=api, show_content=show_content, max_content=max_content, + visited=visited + ) + if child_node: + node["children"].append(child_node) + + # Sort children by URI for consistent output + node["children"].sort(key=lambda x: x.get("uri", "")) + + return node + + +def format_edge(edge): + """Format an edge (quoted triple) for display.""" + if isinstance(edge, dict): + s = edge.get("s", "?") + p = edge.get("p", "?") + o = edge.get("o", "?") + + # Shorten URIs for display + s_short = s.split("/")[-1] if "/" in str(s) else s + p_short = p.split("/")[-1] if "/" in str(p) else p + o_short = o.split("/")[-1] if "/" in str(o) else o + + return f"({s_short}, {p_short}, {o_short})" + return str(edge) + + +def print_tree(node, prefix="", is_last=True, show_content=False): + """Print node as indented tree.""" + connector = "└── " if is_last else "├── " + continuation = " " if is_last else "│ " + + # Format node header + uri = node.get("uri", "") + node_type = node.get("type", "unknown") + metadata = node.get("metadata", {}) + + label = metadata.get("label") or metadata.get("title") or uri.split("/")[-1] + type_str = node_type.capitalize() + + if prefix: + print(f"{prefix}{connector}{type_str}: {label}") + else: + print(f"{type_str}: {uri}") + if metadata.get("title"): + print(f" Title: \"{metadata['title']}\"") + if metadata.get("format"): + print(f" Type: {metadata['format']}") + + new_prefix = prefix + continuation if prefix else " " + + # Print content if available + if show_content and "content" in node: + content = node["content"] + content_lines = content.split("\n")[:3] # Show first 3 lines + for line in content_lines: + if line.strip(): + truncated = line[:80] + "..." if len(line) > 80 else line + print(f"{new_prefix}Content: \"{truncated}\"") + break + + # Print edges + edges = node.get("edges", []) + children = node.get("children", []) + + total_items = len(edges) + len(children) + current_item = 0 + + for edge in edges: + current_item += 1 + is_last_item = (current_item == total_items) + edge_connector = "└── " if is_last_item else "├── " + print(f"{new_prefix}{edge_connector}Edge: {format_edge(edge)}") + + # Print children recursively + for i, child in enumerate(children): + current_item += 1 + is_last_child = (i == len(children) - 1) + print_tree(child, new_prefix, is_last_child, show_content) + + +def print_json(node): + """Print node as JSON.""" + print(json.dumps(node, indent=2)) + + +def main(): + parser = argparse.ArgumentParser( + prog='tg-show-extraction-provenance', + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + 'document_id', + help='Document URI to show hierarchy for', + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Auth token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})', + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection (default: {default_collection})', + ) + + parser.add_argument( + '-f', '--flow-id', + default='default', + help='Flow ID (default: default)', + ) + + parser.add_argument( + '--show-content', + action='store_true', + help='Include blob/document content', + ) + + parser.add_argument( + '--max-content', + type=int, + default=200, + help='Max chars to display per blob (default: 200)', + ) + + parser.add_argument( + '--format', + choices=['tree', 'json'], + default='tree', + help='Output format: tree (default), json', + ) + + args = parser.parse_args() + + try: + api = Api(args.api_url, token=args.token) + socket = api.socket() + + try: + hierarchy = build_hierarchy( + socket=socket, + flow_id=args.flow_id, + user=args.user, + collection=args.collection, + root_uri=args.document_id, + api=api if args.show_content else None, + show_content=args.show_content, + max_content=args.max_content, + ) + + if hierarchy is None: + print(f"No data found for document: {args.document_id}", file=sys.stderr) + sys.exit(1) + + if args.format == 'json': + print_json(hierarchy) + else: + print_tree(hierarchy, show_content=args.show_content) + + finally: + socket.close() + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index b5b15e3c..8db4edf4 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -1,5 +1,11 @@ """ Connects to the graph query service and dumps all graph edges. +Uses streaming mode for lower time-to-first-result and reduced memory overhead. + +Named graphs: + - Default graph (empty): Core knowledge facts + - urn:graph:source: Extraction provenance (document/chunk sources) + - urn:graph:retrieval: Query-time explainability (question, exploration, focus, synthesis) """ import argparse @@ -11,17 +17,42 @@ default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_graph(url, flow_id, user, collection, token=None): +# Named graph constants for convenience +GRAPH_DEFAULT = "" +GRAPH_SOURCE = "urn:graph:source" +GRAPH_RETRIEVAL = "urn:graph:retrieval" - api = Api(url, token=token).flow().id(flow_id) - rows = api.triples_query( - user=user, collection=collection, - s=None, p=None, o=None, limit=10_000, - ) +def show_graph(url, flow_id, user, collection, limit, batch_size, graph=None, show_graph_column=False, token=None): - for row in rows: - print(row.s, row.p, row.o) + socket = Api(url, token=token).socket() + flow = socket.flow(flow_id) + + try: + for batch in flow.triples_query_stream( + user=user, + collection=collection, + s=None, p=None, o=None, + g=graph, # Filter by named graph (None = all graphs) + limit=limit, + batch_size=batch_size, + ): + for triple in batch: + s = triple.get("s", {}) + p = triple.get("p", {}) + o = triple.get("o", {}) + g = triple.get("g") # Named graph (None = default graph) + # Format terms for display + s_str = s.get("v", s.get("i", str(s))) + p_str = p.get("v", p.get("i", str(p))) + o_str = o.get("v", o.get("i", str(o))) + if show_graph_column: + g_str = g if g else "(default)" + print(f"[{g_str}]", s_str, p_str, o_str) + else: + print(s_str, p_str, o_str) + finally: + socket.close() def main(): @@ -60,8 +91,39 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Maximum number of triples to return (default: 10000)', + ) + + parser.add_argument( + '-b', '--batch-size', + type=int, + default=20, + help='Triples per streaming batch (default: 20)', + ) + + parser.add_argument( + '-g', '--graph', + default=None, + help='Filter by named graph (e.g., urn:graph:source, urn:graph:retrieval). Use "" for default graph only.', + ) + + parser.add_argument( + '--show-graph', + action='store_true', + help='Show graph column in output', + ) + args = parser.parse_args() + # Handle empty string for default graph filter + graph = args.graph + if graph == '""' or graph == "''": + graph = "" # Filter to default graph only + try: show_graph( @@ -69,6 +131,10 @@ def main(): flow_id = args.flow_id, user = args.user, collection = args.collection, + limit = args.limit, + batch_size = args.batch_size, + graph = graph, + show_graph_column = args.show_graph, token = args.token, ) diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 79e14540..9eaedf0a 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", - "trustgraph-flow>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", + "trustgraph-flow>=2.1,<2.2", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 31a22a2f..dc131a9e 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", "aiohttp", "anthropic", "scylla-driver", @@ -27,7 +27,7 @@ dependencies = [ "langchain-text-splitters", "mcp", "minio", - "mistralai", + "mistralai<2.0.0", "neo4j", "nltk", "ollama", @@ -122,6 +122,7 @@ triples-write-falkordb = "trustgraph.storage.triples.falkordb:run" triples-write-memgraph = "trustgraph.storage.triples.memgraph:run" triples-write-neo4j = "trustgraph.storage.triples.neo4j:run" wikipedia-lookup = "trustgraph.external.wikipedia:run" +joke-service = "trustgraph.tool_service.joke:run" [tool.setuptools.packages.find] include = ["trustgraph*"] diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 1a44ef9e..9a02e5c6 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -2,11 +2,15 @@ Simple agent infrastructure broadly implements the ReAct flow. """ +import asyncio +import base64 import json import re import sys import functools import logging +import uuid +from datetime import datetime # Module logger logger = logging.getLogger(__name__) @@ -14,10 +18,30 @@ logger = logging.getLogger(__name__) from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec +from ... base import ProducerSpec +from ... base import Consumer, Producer +from ... base import ConsumerMetrics, ProducerMetrics from ... schema import AgentRequest, AgentResponse, AgentStep, Error +from ... schema import Triples, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue -from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl +# Provenance imports for agent explainability +from trustgraph.provenance import ( + agent_session_uri, + agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, + agent_final_uri, + agent_session_triples, + agent_iteration_triples, + agent_final_triples, + set_graph, + GRAPH_RETRIEVAL, +) + +from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl from . agent_manager import AgentManager from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state @@ -25,6 +49,8 @@ from . types import Final, Action, Tool, Argument default_ident = "agent-manager" default_max_iterations = 10 +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(AgentService): @@ -51,6 +77,9 @@ class Processor(AgentService): additional_context="", ) + # Track active tool service clients for cleanup + self.tool_service_clients = {} + self.config_handlers.append(self.on_tools_config) self.register_specification( @@ -102,6 +131,123 @@ class Processor(AgentService): ) ) + # Explainability producer for agent provenance triples + self.register_specification( + ProducerSpec( + name = "explainability", + schema = Triples, + ) + ) + + # Librarian client for storing answer content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_librarian_requests = {} + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id in self.pending_librarian_requests: + future = self.pending_librarian_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. + + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "Agent Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_librarian_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_librarian_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_tools_config(self, config, version): logger.info(f"Loading configuration version {version}") @@ -110,6 +256,16 @@ class Processor(AgentService): tools = {} + # Load tool-service configurations first + tool_services = {} + if "tool-service" in config: + for service_id, service_value in config["tool-service"].items(): + service_data = json.loads(service_value) + tool_services[service_id] = service_data + logger.debug(f"Loaded tool-service config: {service_id}") + + logger.info(f"Loaded {len(tool_services)} tool-service configurations") + # Load tool configurations from the new location if "tool" in config: for tool_id, tool_value in config["tool"].items(): @@ -177,6 +333,59 @@ class Processor(AgentService): limit=int(data.get("limit", 10)) # Max results ) arguments = RowEmbeddingsQueryImpl.get_arguments() + elif impl_id == "tool-service": + # Dynamic tool service - look up the service config + service_ref = data.get("service") + if not service_ref: + raise RuntimeError( + f"Tool {name} has type 'tool-service' but no 'service' reference" + ) + if service_ref not in tool_services: + raise RuntimeError( + f"Tool {name} references unknown tool-service '{service_ref}'" + ) + + service_config = tool_services[service_ref] + request_queue = service_config.get("request-queue") + response_queue = service_config.get("response-queue") + if not request_queue or not response_queue: + raise RuntimeError( + f"Tool-service '{service_ref}' must define 'request-queue' and 'response-queue'" + ) + + # Build config values from tool config + # Extract any config params defined by the service + config_params = service_config.get("config-params", []) + config_values = {} + for param in config_params: + param_name = param.get("name") if isinstance(param, dict) else param + if param_name in data: + config_values[param_name] = data[param_name] + elif isinstance(param, dict) and param.get("required", False): + raise RuntimeError( + f"Tool {name} missing required config param '{param_name}'" + ) + + # Arguments come from tool config + config_args = data.get("arguments", []) + arguments = [ + Argument( + name=arg.get("name"), + type=arg.get("type"), + description=arg.get("description") + ) + for arg in config_args + ] + + # Store queues for the implementation + impl = functools.partial( + ToolServiceImpl, + request_queue=request_queue, + response_queue=response_queue, + config_values=config_values, + arguments=arguments, + processor=self, + ) else: raise RuntimeError( f"Tool type {impl_id} not known" @@ -219,6 +428,10 @@ class Processor(AgentService): # Check if streaming is enabled streaming = getattr(request, 'streaming', False) + # Generate or retrieve session ID for provenance tracking + session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) + collection = getattr(request, 'collection', 'default') + if request.history: history = [ Action( @@ -232,6 +445,36 @@ class Processor(AgentService): else: history = [] + # Calculate iteration number (1-based) + iteration_num = len(history) + 1 + session_uri = agent_session_uri(session_id) + + # On first iteration, emit session triples + if iteration_num == 1: + timestamp = datetime.utcnow().isoformat() + "Z" + triples = set_graph( + agent_session_triples(session_uri, request.question, timestamp), + GRAPH_RETRIEVAL + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=session_uri, + user=request.user, + collection=collection, + ), + triples=triples, + )) + logger.debug(f"Emitted session triples for {session_uri}") + + # Send explain event for session + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=session_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + logger.info(f"Question: {request.question}") if len(history) >= self.max_iterations: @@ -381,6 +624,60 @@ class Processor(AgentService): else: f = json.dumps(act.final) + # Emit final answer provenance triples + final_uri = agent_final_uri(session_id) + # No iterations: link to question; otherwise: link to last iteration + if iteration_num > 1: + final_question_uri = None + final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + else: + final_question_uri = session_uri + final_previous_uri = None + + # Save answer to librarian + answer_doc_id = None + if f: + answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer" + try: + await self.save_answer_content( + doc_id=answer_doc_id, + user=request.user, + content=f, + title=f"Agent Answer: {request.question[:50]}...", + ) + logger.debug(f"Saved answer to librarian: {answer_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + answer_doc_id = None # Fall back to inline content + + final_triples = set_graph( + agent_final_triples( + final_uri, + question_uri=final_question_uri, + previous_uri=final_previous_uri, + document_id=answer_doc_id, + ), + GRAPH_RETRIEVAL + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=final_uri, + user=request.user, + collection=collection, + ), + triples=final_triples, + )) + logger.debug(f"Emitted final triples for {final_uri}") + + # Send explain event for conclusion + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=final_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + if streaming: # Streaming format - send end-of-dialog marker # Answer chunks were already sent via answer() callback during parsing @@ -413,8 +710,86 @@ class Processor(AgentService): logger.debug("Send next...") + # Emit iteration provenance triples + iteration_uri = agent_iteration_uri(session_id, iteration_num) + # First iteration links to question, subsequent to previous + if iteration_num > 1: + iter_question_uri = None + iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + else: + iter_question_uri = session_uri + iter_previous_uri = None + + # Save thought to librarian + thought_doc_id = None + if act.thought: + thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + try: + await self.save_answer_content( + doc_id=thought_doc_id, + user=request.user, + content=act.thought, + title=f"Agent Thought: {act.name}", + ) + logger.debug(f"Saved thought to librarian: {thought_doc_id}") + except Exception as e: + logger.warning(f"Failed to save thought to librarian: {e}") + thought_doc_id = None + + # Save observation to librarian + observation_doc_id = None + if act.observation: + observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + try: + await self.save_answer_content( + doc_id=observation_doc_id, + user=request.user, + content=act.observation, + title=f"Agent Observation: {act.name}", + ) + logger.debug(f"Saved observation to librarian: {observation_doc_id}") + except Exception as e: + logger.warning(f"Failed to save observation to librarian: {e}") + observation_doc_id = None + + thought_entity_uri = agent_thought_uri(session_id, iteration_num) + observation_entity_uri = agent_observation_uri(session_id, iteration_num) + + iter_triples = set_graph( + agent_iteration_triples( + iteration_uri, + question_uri=iter_question_uri, + previous_uri=iter_previous_uri, + action=act.name, + arguments=act.arguments, + thought_uri=thought_entity_uri if thought_doc_id else None, + thought_document_id=thought_doc_id, + observation_uri=observation_entity_uri if observation_doc_id else None, + observation_document_id=observation_doc_id, + ), + GRAPH_RETRIEVAL + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=iteration_uri, + user=request.user, + collection=collection, + ), + triples=iter_triples, + )) + logger.debug(f"Emitted iteration triples for {iteration_uri}") + + # Send explain event for iteration + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=iteration_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + history.append(act) - + # Handle state transitions if tool execution was successful next_state = request.state if act.name in filtered_tools: @@ -435,7 +810,9 @@ class Processor(AgentService): for h in history ], user=request.user, + collection=collection, streaming=streaming, + session_id=session_id, # Pass session_id for provenance continuity ) await next(r) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 2b442a0d..441c8f38 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl: logger.debug("Getting embeddings for row query...") query_text = arguments.get("query") - vectors = await embeddings_client.embed(query_text) + all_vectors = await embeddings_client.embed([query_text]) + vector = all_vectors[0] if all_vectors else [] # Now query row embeddings client = self.context("row-embeddings-query-request") @@ -164,7 +165,7 @@ class RowEmbeddingsQueryImpl: user = getattr(client, '_current_user', self.user or "trustgraph") matches = await client.row_embeddings_query( - vectors=vectors, + vector=vector, schema_name=self.schema_name, user=user, collection=self.collection or "default", @@ -202,3 +203,116 @@ class PromptImpl: id=self.template_id, variables=arguments ) + + +# This tool implementation invokes a dynamically configured tool service +class ToolServiceImpl: + """ + Implementation for dynamically pluggable tool services. + + Tool services are external Pulsar services that can be invoked as agent tools. + The service is configured via a tool-service descriptor that defines the queues, + and a tool descriptor that provides config values and argument definitions. + """ + + def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None): + """ + Initialize a tool service implementation. + + Args: + context: The context function (provides user info) + request_queue: Full Pulsar topic for requests + response_queue: Full Pulsar topic for responses + config_values: Dict of config values (e.g., {"collection": "customers"}) + arguments: List of Argument objects defining the tool's parameters + processor: The Processor instance (for pubsub access) + """ + self.context = context + self.request_queue = request_queue + self.response_queue = response_queue + self.config_values = config_values or {} + self.arguments = arguments or [] + self.processor = processor + self._client = None + + def get_arguments(self): + return self.arguments + + async def _get_or_create_client(self): + """Get or create the tool service client.""" + if self._client is not None: + return self._client + + # Check if processor already has a client for this queue pair + client_key = f"{self.request_queue}|{self.response_queue}" + if client_key in self.processor.tool_service_clients: + self._client = self.processor.tool_service_clients[client_key] + return self._client + + # Import here to avoid circular imports + from trustgraph.base.tool_service_client import ToolServiceClient + from trustgraph.base.metrics import ProducerMetrics, SubscriberMetrics + from trustgraph.schema import ToolServiceRequest, ToolServiceResponse + import uuid + + request_metrics = ProducerMetrics( + processor=self.processor.id, + flow="tool-service", + name=self.request_queue + ) + response_metrics = SubscriberMetrics( + processor=self.processor.id, + flow="tool-service", + name=self.response_queue + ) + + # Create unique subscription for responses + subscription = f"{self.processor.id}--tool-service--{uuid.uuid4()}" + + self._client = ToolServiceClient( + backend=self.processor.pubsub, + subscription=subscription, + consumer_name=self.processor.id, + request_topic=self.request_queue, + request_schema=ToolServiceRequest, + request_metrics=request_metrics, + response_topic=self.response_queue, + response_schema=ToolServiceResponse, + response_metrics=response_metrics, + ) + + # Start the client + await self._client.start() + + # Register for cleanup + self.processor.tool_service_clients[client_key] = self._client + + logger.debug(f"Created tool service client for {self.request_queue}") + return self._client + + async def invoke(self, **arguments): + logger.debug(f"Tool service invocation: {self.request_queue}...") + logger.debug(f"Config: {self.config_values}") + logger.debug(f"Arguments: {arguments}") + + # Get user from context if available + user = "trustgraph" + if hasattr(self.context, '_user'): + user = self.context._user + + # Get or create the client + client = await self._get_or_create_client() + + # Call the tool service + response = await client.call( + user=user, + config=self.config_values, + arguments=arguments, + ) + + logger.debug(f"Tool service response: {response}") + + if isinstance(response, str): + return response + else: + return json.dumps(response) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index bc6d9cb9..fb84c356 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -8,14 +8,24 @@ import logging from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk +from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec +from ... provenance import ( + derived_entity_triples, + set_graph, GRAPH_SOURCE, +) + +# Component identification for provenance +COMPONENT_NAME = "chunker" +COMPONENT_VERSION = "1.0.0" + # Module logger logger = logging.getLogger(__name__) default_ident = "chunker" + class Processor(ChunkingService): def __init__(self, **params): @@ -23,7 +33,7 @@ class Processor(ChunkingService): id = params.get("id", default_ident) chunk_size = params.get("chunk_size", 2000) chunk_overlap = params.get("chunk_overlap", 100) - + super(Processor, self).__init__( **params | { "id": id } ) @@ -62,6 +72,13 @@ class Processor(ChunkingService): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + logger.info("Recursive chunker initialized") async def on_message(self, msg, consumer, flow): @@ -69,6 +86,9 @@ class Processor(ChunkingService): v = msg.value() logger.info(f"Chunking document {v.metadata.id}...") + # Get text content (fetches from librarian if needed) + text = await self.get_document_text(v) + # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( msg, consumer, flow, @@ -90,25 +110,84 @@ class Processor(ChunkingService): is_separator_regex=False, ) - texts = text_splitter.create_documents( - [v.text.decode("utf-8")] - ) + texts = text_splitter.create_documents([text]) + + # Get parent document ID for provenance linking + # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it + parent_doc_id = v.document_id or v.metadata.id + + # Track character offset for provenance + char_offset = 0 for ix, chunk in enumerate(texts): + chunk_index = ix + 1 # 1-indexed logger.debug(f"Created chunk of size {len(chunk.page_content)}") + # Generate chunk document ID by appending /c{index} to parent + # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) + chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" + chunk_uri = chunk_doc_id # URI is same as document ID + parent_uri = parent_doc_id + + chunk_content = chunk.page_content.encode("utf-8") + chunk_length = len(chunk.page_content) + + # Save chunk to librarian as child document + await self.save_child_document( + doc_id=chunk_doc_id, + parent_id=parent_doc_id, + user=v.metadata.user, + content=chunk_content, + document_type="chunk", + title=f"Chunk {chunk_index}", + ) + + # Emit provenance triples (stored in source graph for separation from core knowledge) + prov_triples = derived_entity_triples( + entity_uri=chunk_uri, + parent_uri=parent_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Chunk {chunk_index}", + chunk_index=chunk_index, + char_offset=char_offset, + char_length=chunk_length, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward chunk ID + content (post-chunker optimization) r = Chunk( - metadata=v.metadata, - chunk=chunk.page_content.encode("utf-8"), + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + chunk=chunk_content, + document_id=chunk_doc_id, ) __class__.chunk_metric.labels( id=consumer.id, flow=consumer.flow - ).observe(len(chunk.page_content)) + ).observe(chunk_length) await flow("output").send(r) + # Update character offset (approximate, doesn't account for overlap) + char_offset += chunk_length - chunk_overlap + logger.debug("Document chunking complete") @staticmethod @@ -133,4 +212,3 @@ class Processor(ChunkingService): def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 876cab07..909396c6 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -8,14 +8,24 @@ import logging from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk +from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec +from ... provenance import ( + derived_entity_triples, + set_graph, GRAPH_SOURCE, +) + +# Component identification for provenance +COMPONENT_NAME = "token-chunker" +COMPONENT_VERSION = "1.0.0" + # Module logger logger = logging.getLogger(__name__) default_ident = "chunker" + class Processor(ChunkingService): def __init__(self, **params): @@ -23,7 +33,7 @@ class Processor(ChunkingService): id = params.get("id", default_ident) chunk_size = params.get("chunk_size", 250) chunk_overlap = params.get("chunk_overlap", 15) - + super(Processor, self).__init__( **params | { "id": id } ) @@ -61,6 +71,13 @@ class Processor(ChunkingService): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + logger.info("Token chunker initialized") async def on_message(self, msg, consumer, flow): @@ -68,6 +85,9 @@ class Processor(ChunkingService): v = msg.value() logger.info(f"Chunking document {v.metadata.id}...") + # Get text content (fetches from librarian if needed) + text = await self.get_document_text(v) + # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( msg, consumer, flow, @@ -88,25 +108,84 @@ class Processor(ChunkingService): chunk_overlap=chunk_overlap, ) - texts = text_splitter.create_documents( - [v.text.decode("utf-8")] - ) + texts = text_splitter.create_documents([text]) + + # Get parent document ID for provenance linking + # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it + parent_doc_id = v.document_id or v.metadata.id + + # Track token offset for provenance (approximate) + token_offset = 0 for ix, chunk in enumerate(texts): + chunk_index = ix + 1 # 1-indexed logger.debug(f"Created chunk of size {len(chunk.page_content)}") + # Generate chunk document ID by appending /c{index} to parent + # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) + chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" + chunk_uri = chunk_doc_id # URI is same as document ID + parent_uri = parent_doc_id + + chunk_content = chunk.page_content.encode("utf-8") + chunk_length = len(chunk.page_content) + + # Save chunk to librarian as child document + await self.save_child_document( + doc_id=chunk_doc_id, + parent_id=parent_doc_id, + user=v.metadata.user, + content=chunk_content, + document_type="chunk", + title=f"Chunk {chunk_index}", + ) + + # Emit provenance triples (stored in source graph for separation from core knowledge) + prov_triples = derived_entity_triples( + entity_uri=chunk_uri, + parent_uri=parent_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Chunk {chunk_index}", + chunk_index=chunk_index, + char_offset=token_offset, # Note: this is token offset, not char offset + char_length=chunk_length, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward chunk ID + content (post-chunker optimization) r = Chunk( - metadata=v.metadata, - chunk=chunk.page_content.encode("utf-8"), + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + chunk=chunk_content, + document_id=chunk_doc_id, ) __class__.chunk_metric.labels( id=consumer.id, flow=consumer.flow - ).observe(len(chunk.page_content)) + ).observe(chunk_length) await flow("output").send(r) + # Update token offset (approximate, doesn't account for overlap) + token_offset += chunk_size - chunk_overlap + logger.debug("Document chunking complete") @staticmethod @@ -118,17 +197,16 @@ class Processor(ChunkingService): '-z', '--chunk-size', type=int, default=250, - help=f'Chunk size (default: 250)' + help=f'Chunk size in tokens (default: 250)' ) parser.add_argument( '-v', '--chunk-overlap', type=int, default=15, - help=f'Chunk overlap (default: 15)' + help=f'Chunk overlap in tokens (default: 15)' ) def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index bb641a26..550948fe 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -2,21 +2,44 @@ """ Simple decoder, accepts PDF documents on input, outputs pages from the PDF document as text as separate output objects. + +Supports both inline document data and fetching from librarian via Pulsar +for large documents. """ +import asyncio +import os import tempfile import base64 import logging +import uuid from langchain_community.document_loaders import PyPDFLoader from ... schema import Document, TextDocument, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue +from ... schema import Triples from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics + +from ... provenance import ( + document_uri, page_uri, derived_entity_triples, + set_graph, GRAPH_SOURCE, +) + +# Component identification for provenance +COMPONENT_NAME = "pdf-decoder" +COMPONENT_VERSION = "1.0.0" # Module logger logger = logging.getLogger(__name__) default_ident = "pdf-decoder" +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue + + class Processor(FlowProcessor): def __init__(self, **params): @@ -44,8 +67,164 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + + # Librarian client for fetching document content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor = id, flow = None, name = "librarian-request" + ) + + self.librarian_request_producer = Producer( + backend = self.pubsub, + topic = librarian_request_q, + schema = LibrarianRequest, + metrics = librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor = id, flow = None, name = "librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup = self.taskgroup, + backend = self.pubsub, + flow = None, + topic = librarian_response_q, + subscriber = f"{id}-librarian", + schema = LibrarianResponse, + handler = self.on_librarian_response, + metrics = librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + logger.info("PDF decoder initialized") + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_document_content(self, document_id, user, timeout=120): + """ + Fetch document content from librarian via Pulsar. + """ + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=document_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + return response.content + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching document {document_id}") + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="page", title=None, timeout=120): + """ + Save a child document to the librarian. + + Args: + doc_id: ID for the new child document + parent_id: ID of the parent document + user: User ID + content: Document content (bytes) + document_type: Type of document ("page", "chunk", etc.) + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + import base64 + + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + request = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving child document: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving child document {doc_id}") + async def on_message(self, msg, consumer, flow): logger.debug("PDF message received") @@ -54,26 +233,102 @@ class Processor(FlowProcessor): logger.info(f"Decoding PDF {v.metadata.id}...") - with tempfile.NamedTemporaryFile(delete_on_close=False) as fp: + with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp: + temp_path = fp.name - fp.write(base64.b64decode(v.data)) - fp.close() + # Check if we should fetch from librarian or use inline data + if v.document_id: + # Fetch from librarian via Pulsar + logger.info(f"Fetching document {v.document_id} from librarian...") + fp.close() - with open(fp.name, mode='rb') as f: + content = await self.fetch_document_content( + document_id=v.document_id, + user=v.metadata.user, + ) - loader = PyPDFLoader(fp.name) - pages = loader.load() + # Content is base64 encoded + if isinstance(content, str): + content = content.encode('utf-8') + decoded_content = base64.b64decode(content) - for ix, page in enumerate(pages): + with open(temp_path, 'wb') as f: + f.write(decoded_content) - logger.debug(f"Processing page {ix}") + logger.info(f"Fetched {len(decoded_content)} bytes from librarian") + else: + # Use inline data (backward compatibility) + fp.write(base64.b64decode(v.data)) + fp.close() - r = TextDocument( - metadata=v.metadata, - text=page.page_content.encode("utf-8"), - ) + loader = PyPDFLoader(temp_path) + pages = loader.load() - await flow("output").send(r) + # Get the source document ID + source_doc_id = v.document_id or v.metadata.id + + for ix, page in enumerate(pages): + page_num = ix + 1 # 1-indexed page numbers + + logger.debug(f"Processing page {page_num}") + + # Generate page document ID + page_doc_id = f"{source_doc_id}/p{page_num}" + page_content = page.page_content.encode("utf-8") + + # Save page as child document in librarian + await self.save_child_document( + doc_id=page_doc_id, + parent_id=source_doc_id, + user=v.metadata.user, + content=page_content, + document_type="page", + title=f"Page {page_num}", + ) + + # Emit provenance triples (stored in source graph for separation from core knowledge) + doc_uri = document_uri(source_doc_id) + pg_uri = page_uri(source_doc_id, page_num) + + prov_triples = derived_entity_triples( + entity_uri=pg_uri, + parent_uri=doc_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Page {page_num}", + page_number=page_num, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward page document ID to chunker + # Chunker will fetch content from librarian + r = TextDocument( + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + document_id=page_doc_id, + text=b"", # Empty, chunker will fetch from librarian + ) + + await flow("output").send(r) + + # Clean up temp file + try: + os.unlink(temp_path) + except OSError: + pass logger.debug("PDF decoding complete") @@ -81,7 +336,18 @@ class Processor(FlowProcessor): def add_args(parser): FlowProcessor.add_args(parser) + parser.add_argument( + '--librarian-request-queue', + default=default_librarian_request_queue, + help=f'Librarian request queue (default: {default_librarian_request_queue})', + ) + + parser.add_argument( + '--librarian-response-queue', + default=default_librarian_response_queue, + help=f'Librarian response queue (default: {default_librarian_response_queue})', + ) + def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 61639096..59d2a2a1 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -589,6 +589,8 @@ class EntityCentricKnowledgeGraph: # quads_by_entity: primary data table # Every entity has a partition containing all quads it participates in + # Clustering key includes dtype/lang to distinguish literals with same value + # but different datatype or language tag (e.g., "thing" vs "thing"@en) self.session.execute(f""" CREATE TABLE IF NOT EXISTS {self.entity_table} ( collection text, @@ -601,11 +603,13 @@ class EntityCentricKnowledgeGraph: d text, dtype text, lang text, - PRIMARY KEY ((collection, entity), role, p, otype, s, o, d) + PRIMARY KEY ((collection, entity), role, p, otype, s, o, d, dtype, lang) ); """) # quads_by_collection: manifest for collection-level queries and deletion + # Clustering key includes otype/dtype/lang to distinguish literals with same + # value but different metadata (e.g., "thing" vs "thing"@en vs "thing"^^xsd:string) self.session.execute(f""" CREATE TABLE IF NOT EXISTS {self.collection_table} ( collection text, @@ -616,7 +620,7 @@ class EntityCentricKnowledgeGraph: otype text, dtype text, lang text, - PRIMARY KEY (collection, d, s, p, o) + PRIMARY KEY (collection, d, s, p, o, otype, dtype, lang) ); """) @@ -718,7 +722,7 @@ class EntityCentricKnowledgeGraph: ) self.delete_collection_row_stmt = self.session.prepare( - f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ?" + f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ? AND otype = ? AND dtype = ? AND lang = ?" ) logger.info("Prepared statements initialized for entity-centric schema") @@ -797,7 +801,7 @@ class EntityCentricKnowledgeGraph: def get_s(self, collection, s, g=None, limit=10): """ Query by subject. Returns quads where s is the subject. - g=None: default graph, g='*': all graphs + g=None: all graphs, g='': default graph only, g='uri': specific graph """ rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit)) @@ -805,10 +809,7 @@ class EntityCentricKnowledgeGraph: for row in rows: d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH # Filter by graph if specified - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -819,16 +820,13 @@ class EntityCentricKnowledgeGraph: return results def get_p(self, collection, p, g=None, limit=10): - """Query by predicate""" + """Query by predicate. g=None: all graphs, g='': default graph only""" rows = self.session.execute(self.get_entity_as_p_stmt, (collection, p, limit)) results = [] for row in rows: d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -839,16 +837,13 @@ class EntityCentricKnowledgeGraph: return results def get_o(self, collection, o, g=None, limit=10): - """Query by object""" + """Query by object. g=None: all graphs, g='': default graph only""" rows = self.session.execute(self.get_entity_as_o_stmt, (collection, o, limit)) results = [] for row in rows: d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -859,16 +854,13 @@ class EntityCentricKnowledgeGraph: return results def get_sp(self, collection, s, p, g=None, limit=10): - """Query by subject and predicate""" + """Query by subject and predicate. g=None: all graphs, g='': default graph only""" rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit)) results = [] for row in rows: d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -879,16 +871,13 @@ class EntityCentricKnowledgeGraph: return results def get_po(self, collection, p, o, g=None, limit=10): - """Query by predicate and object""" + """Query by predicate and object. g=None: all graphs, g='': default graph only""" rows = self.session.execute(self.get_entity_as_o_p_stmt, (collection, o, p, limit)) results = [] for row in rows: d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -899,7 +888,7 @@ class EntityCentricKnowledgeGraph: return results def get_os(self, collection, o, s, g=None, limit=10): - """Query by object and subject""" + """Query by object and subject. g=None: all graphs, g='': default graph only""" # Use subject partition with role='S', filter by o rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit)) @@ -909,10 +898,7 @@ class EntityCentricKnowledgeGraph: continue d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -923,7 +909,7 @@ class EntityCentricKnowledgeGraph: return results def get_spo(self, collection, s, p, o, g=None, limit=10): - """Query by subject, predicate, object (find which graphs)""" + """Query by subject, predicate, object (find which graphs). g=None: all graphs, g='': default graph only""" rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit)) results = [] @@ -932,10 +918,7 @@ class EntityCentricKnowledgeGraph: continue d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH - if g is None or g == DEFAULT_GRAPH: - if d != DEFAULT_GRAPH: - continue - elif g != GRAPH_WILDCARD and d != g: + if g is not None and d != g: continue results.append(QuadResult( @@ -991,9 +974,9 @@ class EntityCentricKnowledgeGraph: 3. Delete entire entity partitions 4. Delete collection rows """ - # Read all quads from collection table + # Read all quads from collection table (including type metadata for delete) rows = self.session.execute( - f"SELECT d, s, p, o, otype FROM {self.collection_table} WHERE collection = %s", + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s", (collection,) ) @@ -1002,8 +985,11 @@ class EntityCentricKnowledgeGraph: quads = [] for row in rows: - d, s, p, o, otype = row.d, row.s, row.p, row.o, row.otype - quads.append((d, s, p, o)) + d, s, p, o = row.d, row.s, row.p, row.o + otype = row.otype + dtype = row.dtype if hasattr(row, 'dtype') else '' + lang = row.lang if hasattr(row, 'lang') else '' + quads.append((d, s, p, o, otype, dtype, lang)) # Subject and predicate are always entities entities.add(s) @@ -1038,8 +1024,8 @@ class EntityCentricKnowledgeGraph: batch = BatchStatement() count = 0 - for d, s, p, o in quads: - batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o)) + for d, s, p, o, otype, dtype, lang in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang)) count += 1 # Execute batch every 50 quads diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 4047a9e3..66bfe31f 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -84,14 +84,14 @@ class DocVectors: dim=dimension, ) - doc_field = FieldSchema( - name="doc", + chunk_id_field = FieldSchema( + name="chunk_id", dtype=DataType.VARCHAR, max_length=65535, ) schema = CollectionSchema( - fields = [pkey_field, vec_field, doc_field], + fields = [pkey_field, vec_field, chunk_id_field], description = "Document embedding schema", ) @@ -119,17 +119,17 @@ class DocVectors: self.collections[(dimension, user, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, doc, user, collection): + def insert(self, embeds, chunk_id, user, collection): dim = len(embeds) if (dim, user, collection) not in self.collections: self.init_collection(dim, user, collection) - + data = [ { "vector": embeds, - "doc": doc, + "chunk_id": chunk_id, } ] @@ -138,7 +138,7 @@ class DocVectors: data=data ) - def search(self, embeds, user, collection, fields=["doc"], limit=10): + def search(self, embeds, user, collection, fields=["chunk_id"], limit=10): dim = len(embeds) diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index 4a106f27..dcbf6734 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -90,8 +90,14 @@ class EntityVectors: max_length=65535, ) + chunk_id_field = FieldSchema( + name="chunk_id", + dtype=DataType.VARCHAR, + max_length=65535, + ) + schema = CollectionSchema( - fields = [pkey_field, vec_field, entity_field], + fields = [pkey_field, vec_field, entity_field, chunk_id_field], description = "Graph embedding schema", ) @@ -119,17 +125,18 @@ class EntityVectors: self.collections[(dimension, user, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, entity, user, collection): + def insert(self, embeds, entity, user, collection, chunk_id=""): dim = len(embeds) if (dim, user, collection) not in self.collections: self.init_collection(dim, user, collection) - + data = [ { "vector": embeds, "entity": entity, + "chunk_id": chunk_id, } ] diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index 602f7bb8..16ca1ad9 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -62,16 +62,17 @@ class Processor(FlowProcessor): resp = await flow("embeddings-request").request( EmbeddingsRequest( - text = v.chunk + texts=[v.chunk] ) ) - vectors = resp.vectors + # vectors[0] is the vector for the first (only) text + vector = resp.vectors[0] if resp.vectors else [] embeds = [ ChunkEmbeddings( - chunk=v.chunk, - vectors=vectors, + chunk_id=v.document_id, + vector=vector, ) ] diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index d1ce93ca..1a03ac9f 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -46,19 +46,21 @@ class Processor(EmbeddingsService): else: logger.debug(f"Using cached model: {model_name}") - async def on_embeddings(self, text, model=None): + async def on_embeddings(self, texts, model=None): + + if not texts: + return [] use_model = model or self.default_model # Reload model if it has changed self._load_model(use_model) - vecs = self.embeddings.embed([text]) + # FastEmbed processes the full batch efficiently + vecs = list(self.embeddings.embed(texts)) - return [ - v.tolist() - for v in vecs - ] + # Return list of vectors, one per input text + return [v.tolist() for v in vecs] @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index 1b63774d..3b441bd6 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -58,22 +58,25 @@ class Processor(FlowProcessor): v = msg.value() logger.info(f"Indexing {v.metadata.id}...") - entities = [] - try: - for entity in v.entities: + # Collect all contexts for batch embedding + contexts = [entity.context for entity in v.entities] - vectors = await flow("embeddings-request").embed( - text = entity.context - ) + # Single batch embedding call + all_vectors = await flow("embeddings-request").embed( + texts=contexts + ) - entities.append( - EntityEmbeddings( - entity=entity.entity, - vectors=vectors - ) + # Pair results with entities + entities = [ + EntityEmbeddings( + entity=entity.entity, + vector=vector, + chunk_id=entity.chunk_id, # Provenance: source chunk ) + for entity, vector in zip(v.entities, all_vectors) + ] # Send in batches to avoid oversized messages for i in range(0, len(entities), self.batch_size): diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c951252e..a65b4ff7 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -30,16 +30,21 @@ class Processor(EmbeddingsService): self.client = Client(host=ollama) self.default_model = model - async def on_embeddings(self, text, model=None): + async def on_embeddings(self, texts, model=None): + + if not texts: + return [] use_model = model or self.default_model + # Ollama handles batch input efficiently embeds = self.client.embed( model = use_model, - input = text + input = texts ) - return embeds.embeddings + # Return list of vectors, one per input text + return list(embeds.embeddings) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 84c41ff3..1365cb14 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor): embeddings_list = [] try: - for text, (index_name, index_value) in texts_to_embed.items(): - vectors = await flow("embeddings-request").embed(text=text) + # Collect texts and metadata for batch embedding + texts = list(texts_to_embed.keys()) + metadata = list(texts_to_embed.values()) + # Single batch embedding call + all_vectors = await flow("embeddings-request").embed(texts=texts) + + # Pair results with metadata + for text, (index_name, index_value), vector in zip( + texts, metadata, all_vectors + ): embeddings_list.append( RowIndexEmbedding( index_name=index_name, index_value=index_value, text=text, - vectors=vectors + vector=vector ) ) diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index d9057909..5ce343c6 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -6,11 +6,13 @@ import logging from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from ....schema import EntityContext, EntityContexts -from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION +from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, DEFINITION from ....base import FlowProcessor, ConsumerSpec, ProducerSpec from ....base import AgentClientSpec +from ....provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE +from ....flow_version import __version__ as COMPONENT_VERSION from ....template import PromptManager # Module logger @@ -104,7 +106,7 @@ class Processor(FlowProcessor): tpls = Triples( metadata = Metadata( id = metadata.id, - metadata = [], + root = metadata.root, user = metadata.user, collection = metadata.collection, ), @@ -117,7 +119,7 @@ class Processor(FlowProcessor): ecs = EntityContexts( metadata = Metadata( id = metadata.id, - metadata = [], + root = metadata.root, user = metadata.user, collection = metadata.collection, ), @@ -183,24 +185,8 @@ class Processor(FlowProcessor): logger.debug(f"Agent prompt: {prompt}") - async def handle(response): - - logger.debug(f"Agent response: {response}") - - if response.error is not None: - if response.error.message: - raise RuntimeError(str(response.error.message)) - else: - raise RuntimeError(str(response.error)) - - if response.answer is not None: - return True - else: - return False - # Send to agent API agent_response = await flow("agent-request").invoke( - recipient = handle, question = prompt ) @@ -212,14 +198,22 @@ class Processor(FlowProcessor): return # Process extraction data - triples, entity_contexts = self.process_extraction_data( - extraction_data, v.metadata - ) + triples, entity_contexts, extracted_triples = \ + self.process_extraction_data(extraction_data, v.metadata) + + # Generate subgraph provenance for extracted triples + if extracted_triples: + chunk_uri = v.metadata.id + sg_uri = subgraph_uri() + prov_triples = subgraph_provenance_triples( + subgraph_uri=sg_uri, + extracted_triples=extracted_triples, + chunk_uri=chunk_uri, + component_name=default_ident, + component_version=COMPONENT_VERSION, + ) + triples.extend(set_graph(prov_triples, GRAPH_SOURCE)) - # Put document metadata into triples - for t in v.metadata.metadata: - triples.append(t) - # Emit outputs if triples: await self.emit_triples(flow("triples"), v.metadata, triples) @@ -241,8 +235,13 @@ class Processor(FlowProcessor): Data is a flat list of objects with 'type' discriminator field: - {"type": "definition", "entity": "...", "definition": "..."} - {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool} + + Returns: + Tuple of (all_triples, entity_contexts, extracted_triples) where + extracted_triples contains only the core knowledge facts (for provenance). """ triples = [] + extracted_triples = [] entity_contexts = [] # Categorize items by type @@ -262,26 +261,20 @@ class Processor(FlowProcessor): )) # Add definition - triples.append(Triple( + definition_triple = Triple( s = Term(type=IRI, iri=entity_uri), p = Term(type=IRI, iri=DEFINITION), o = Term(type=LITERAL, value=defn["definition"]), - )) - - # Add subject-of relationship to document - if metadata.id: - triples.append(Triple( - s = Term(type=IRI, iri=entity_uri), - p = Term(type=IRI, iri=SUBJECT_OF), - o = Term(type=IRI, iri=metadata.id), - )) + ) + triples.append(definition_triple) + extracted_triples.append(definition_triple) # Create entity context for embeddings entity_contexts.append(EntityContext( entity=Term(type=IRI, iri=entity_uri), context=defn["definition"] )) - + # Process relationships for rel in relationships: @@ -318,34 +311,15 @@ class Processor(FlowProcessor): )) # Add the main relationship triple - triples.append(Triple( + relationship_triple = Triple( s = subject_value, p = predicate_value, o = object_value - )) + ) + triples.append(relationship_triple) + extracted_triples.append(relationship_triple) - # Add subject-of relationships to document - if metadata.id: - triples.append(Triple( - s = subject_value, - p = Term(type=IRI, iri=SUBJECT_OF), - o = Term(type=IRI, iri=metadata.id), - )) - - triples.append(Triple( - s = predicate_value, - p = Term(type=IRI, iri=SUBJECT_OF), - o = Term(type=IRI, iri=metadata.id), - )) - - if rel.get("object-entity", True): - triples.append(Triple( - s = object_value, - p = Term(type=IRI, iri=SUBJECT_OF), - o = Term(type=IRI, iri=metadata.id), - )) - - return triples, entity_contexts + return triples, entity_contexts, extracted_triples @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 72275a8c..2bb88c8a 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -15,15 +15,16 @@ from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL logger = logging.getLogger(__name__) from .... schema import EntityContext, EntityContexts from .... schema import PromptRequest, PromptResponse -from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF +from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base import PromptClientSpec +from .... base import PromptClientSpec, ParameterSpec + +from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE +from .... flow_version import __version__ as COMPONENT_VERSION DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION) RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL) -SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF) - default_ident = "kg-extract-definitions" default_concurrency = 1 default_triples_batch_size = 50 @@ -75,6 +76,10 @@ class Processor(FlowProcessor): ) ) + # Optional flow parameters for provenance + self.register_specification(ParameterSpec("llm-model")) + self.register_specification(ParameterSpec("ontology")) + def to_uri(self, text): part = text.replace(" ", "-").lower().encode("utf-8") @@ -126,12 +131,19 @@ class Processor(FlowProcessor): raise e triples = [] + extracted_triples = [] entities = [] - # FIXME: Putting metadata into triples store is duplicated in - # relationships extractor too - for t in v.metadata.metadata: - triples.append(t) + # Get chunk document ID for provenance linking + chunk_doc_id = v.document_id if v.document_id else v.metadata.id + chunk_uri = v.metadata.id # The URI form for the chunk + + # Get optional provenance parameters + llm_model = flow("llm-model") + ontology_uri = flow("ontology") + + # Note: Document metadata is now emitted once by librarian at processing + # initiation, so we don't need to duplicate it here. for defn in defs: @@ -155,28 +167,43 @@ class Processor(FlowProcessor): o=Term(type=LITERAL, value=s), )) - triples.append(Triple( + # The definition triple - this is the main extracted fact + definition_triple = Triple( s=s_value, p=DEFINITION_VALUE, o=o_value - )) - - triples.append(Triple( - s=s_value, - p=SUBJECT_OF_VALUE, - o=Term(type=IRI, iri=v.metadata.id) - )) + ) + triples.append(definition_triple) + extracted_triples.append(definition_triple) # Output entity name as context for direct name matching + # Include chunk_id for embedding provenance entities.append(EntityContext( entity=s_value, context=s, + chunk_id=chunk_doc_id, )) # Output definition as context for semantic matching + # Include chunk_id for embedding provenance entities.append(EntityContext( entity=s_value, context=defn["definition"], + chunk_id=chunk_doc_id, )) + # Generate subgraph provenance once for all extracted triples + if extracted_triples: + sg_uri = subgraph_uri() + prov_triples = subgraph_provenance_triples( + subgraph_uri=sg_uri, + extracted_triples=extracted_triples, + chunk_uri=chunk_uri, + component_name=default_ident, + component_version=COMPONENT_VERSION, + llm_model=llm_model, + ontology_uri=ontology_uri, + ) + triples.extend(set_graph(prov_triples, GRAPH_SOURCE)) + # Send triples in batches for i in range(0, len(triples), self.triples_batch_size): batch = triples[i:i + self.triples_batch_size] @@ -184,7 +211,7 @@ class Processor(FlowProcessor): flow("triples"), Metadata( id=v.metadata.id, - metadata=[], + root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, ), @@ -198,7 +225,7 @@ class Processor(FlowProcessor): flow("entity-contexts"), Metadata( id=v.metadata.id, - metadata=[], + root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, ), diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index a0d9a3fe..5078d817 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -23,6 +23,9 @@ from .ontology_selector import OntologySelector, OntologySubset from .simplified_parser import parse_extraction_response from .triple_converter import TripleConverter +from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE +from .... flow_version import __version__ as COMPONENT_VERSION + logger = logging.getLogger(__name__) default_ident = "kg-extract-ontology" @@ -148,8 +151,8 @@ class Processor(FlowProcessor): # Detect embedding dimension by embedding a test string logger.info("Detecting embedding dimension from embeddings service...") - test_embedding_response = await embeddings_client.embed("test") - test_embedding = test_embedding_response[0] # Extract from [[vector]] + test_embedding_response = await embeddings_client.embed(["test"]) + test_embedding = test_embedding_response[0] # Extract first vector dimension = len(test_embedding) logger.info(f"Detected embedding dimension: {dimension}") @@ -306,15 +309,25 @@ class Processor(FlowProcessor): flow, chunk, ontology_subset, prompt_variables ) - # Add metadata triples - for t in v.metadata.metadata: - triples.append(t) + # Generate subgraph provenance for extracted triples + if triples: + chunk_uri = v.metadata.id + sg_uri = subgraph_uri() + prov_triples = subgraph_provenance_triples( + subgraph_uri=sg_uri, + extracted_triples=triples, + chunk_uri=chunk_uri, + component_name=default_ident, + component_version=COMPONENT_VERSION, + ) # Generate ontology definition triples ontology_triples = self.build_ontology_triples(ontology_subset) - # Combine extracted triples with ontology triples + # Combine extracted triples with ontology triples and provenance all_triples = triples + ontology_triples + if triples: + all_triples.extend(set_graph(prov_triples, GRAPH_SOURCE)) # Build entity contexts from all triples (including ontology elements) entity_contexts = self.build_entity_contexts(all_triples) @@ -558,7 +571,7 @@ class Processor(FlowProcessor): t = Triples( metadata=Metadata( id=metadata.id, - metadata=[], + root=metadata.root, user=metadata.user, collection=metadata.collection, ), @@ -571,7 +584,7 @@ class Processor(FlowProcessor): ec = EntityContexts( metadata=Metadata( id=metadata.id, - metadata=[], + root=metadata.root, user=metadata.user, collection=metadata.collection, ), diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py index 8eee76b4..64127487 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py @@ -153,16 +153,11 @@ class OntologyEmbedder: # Get embeddings for batch texts = [elem['text'] for elem in batch] try: - # Call embedding service for each text - # Note: embed() returns 2D array [[vector]], so extract first element - embedding_tasks = [self.embedding_service.embed(text) for text in texts] - embeddings_responses = await asyncio.gather(*embedding_tasks) - - # Extract vectors from responses (each is [[vector]]) - embeddings_list = [resp[0] for resp in embeddings_responses] + # Single batch embedding call - returns list of vectors + embeddings_response = await self.embedding_service.embed(texts) # Convert to numpy array - embeddings = np.array(embeddings_list) + embeddings = np.array(embeddings_response) # Log embedding shape for debugging logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})") @@ -218,8 +213,8 @@ class OntologyEmbedder: return None try: - # embed() returns 2D array [[vector]], extract first element - embedding_response = await self.embedding_service.embed(text) + # embed() with single text, extract first vector + embedding_response = await self.embedding_service.embed([text]) return np.array(embedding_response[0]) except Exception as e: logger.error(f"Failed to embed text: {e}") @@ -239,12 +234,9 @@ class OntologyEmbedder: return None try: - # Call embed() for each text (returns [[vector]] per call) - embedding_tasks = [self.embedding_service.embed(text) for text in texts] - embeddings_responses = await asyncio.gather(*embedding_tasks) - # Extract first vector from each response - embeddings_list = [resp[0] for resp in embeddings_responses] - return np.array(embeddings_list) + # Single batch embedding call - returns list of vectors + embeddings_response = await self.embedding_service.embed(texts) + return np.array(embeddings_response) except Exception as e: logger.error(f"Failed to embed texts: {e}") return None diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 7ab51555..b557ec32 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -15,13 +15,15 @@ logger = logging.getLogger(__name__) from .... schema import Chunk, Triple, Triples from .... schema import Metadata, Term, IRI, LITERAL from .... schema import PromptRequest, PromptResponse -from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF +from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base import PromptClientSpec +from .... base import PromptClientSpec, ParameterSpec + +from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE +from .... flow_version import __version__ as COMPONENT_VERSION RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL) -SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF) default_ident = "kg-extract-relationships" default_concurrency = 1 @@ -65,6 +67,10 @@ class Processor(FlowProcessor): ) ) + # Optional flow parameters for provenance + self.register_specification(ParameterSpec("llm-model")) + self.register_specification(ParameterSpec("ontology")) + def to_uri(self, text): part = text.replace(" ", "-").lower().encode("utf-8") @@ -108,11 +114,18 @@ class Processor(FlowProcessor): raise e triples = [] + extracted_triples = [] - # FIXME: Putting metadata into triples store is duplicated in - # relationships extractor too - for t in v.metadata.metadata: - triples.append(t) + # Get chunk document ID for provenance linking + chunk_doc_id = v.document_id if v.document_id else v.metadata.id + chunk_uri = v.metadata.id # The URI form for the chunk + + # Get optional provenance parameters + llm_model = flow("llm-model") + ontology_uri = flow("ontology") + + # Note: Document metadata is now emitted once by librarian at processing + # initiation, so we don't need to duplicate it here. for rel in rels: @@ -140,11 +153,14 @@ class Processor(FlowProcessor): else: o_value = Term(type=LITERAL, value=str(o)) - triples.append(Triple( + # The relationship triple - this is the main extracted fact + relationship_triple = Triple( s=s_value, p=p_value, o=o_value - )) + ) + triples.append(relationship_triple) + extracted_triples.append(relationship_triple) # Label for s triples.append(Triple( @@ -168,20 +184,19 @@ class Processor(FlowProcessor): o=Term(type=LITERAL, value=str(o)) )) - # 'Subject of' for s - triples.append(Triple( - s=s_value, - p=SUBJECT_OF_VALUE, - o=Term(type=IRI, iri=v.metadata.id) - )) - - if rel["object-entity"]: - # 'Subject of' for o - triples.append(Triple( - s=o_value, - p=SUBJECT_OF_VALUE, - o=Term(type=IRI, iri=v.metadata.id) - )) + # Generate subgraph provenance once for all extracted triples + if extracted_triples: + sg_uri = subgraph_uri() + prov_triples = subgraph_provenance_triples( + subgraph_uri=sg_uri, + extracted_triples=extracted_triples, + chunk_uri=chunk_uri, + component_name=default_ident, + component_version=COMPONENT_VERSION, + llm_model=llm_model, + ontology_uri=ontology_uri, + ) + triples.extend(set_graph(prov_triples, GRAPH_SOURCE)) # Send triples in batches for i in range(0, len(triples), self.triples_batch_size): @@ -190,7 +205,7 @@ class Processor(FlowProcessor): flow("triples"), Metadata( id=v.metadata.id, - metadata=[], + root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, ), diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index bd7bc802..88e29116 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -272,7 +272,7 @@ class Processor(FlowProcessor): extracted = ExtractedObject( metadata=Metadata( id=f"{v.metadata.id}:{schema_name}", - metadata=[], + root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, ), diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py new file mode 100644 index 00000000..e70bf6de --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py @@ -0,0 +1,65 @@ + +import asyncio +import uuid +import logging +from . librarian import LibrarianRequestor + +# Module logger +logger = logging.getLogger(__name__) + +class DocumentStreamExport: + + def __init__(self, backend): + self.backend = backend + + async def process(self, data, error, ok, request): + + user = request.query.get("user") + document_id = request.query.get("document-id") + chunk_size = int(request.query.get("chunk-size", 1024 * 1024)) + + if not user or not document_id: + return await error("Missing required parameters: user, document-id") + + response = await ok() + + lr = LibrarianRequestor( + backend=self.backend, + consumer="api-gateway-doc-stream-" + str(uuid.uuid4()), + subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()), + ) + + try: + + await lr.start() + + async def responder(resp, fin): + if "content" in resp: + content = resp["content"] + # Content is base64 encoded, write as-is for client to decode + # Or decode here and write raw bytes + import base64 + chunk_data = base64.b64decode(content) + await response.write(chunk_data) + + await lr.process( + { + "operation": "stream-document", + "user": user, + "document-id": document_id, + "chunk-size": chunk_size, + }, + responder + ) + + except Exception as e: + + logger.error(f"Document stream exception: {e}", exc_info=True) + + finally: + + await lr.stop() + + await response.write_eof() + + return response diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 35edad76..d068ecef 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -45,6 +45,7 @@ from . rows_import import RowsImport from . core_export import CoreExport from . core_import import CoreImport +from . document_stream import DocumentStreamExport from . mux import Mux @@ -135,6 +136,14 @@ class DispatcherManager: def dispatch_core_import(self): return DispatcherWrapper(self.process_core_import) + def dispatch_document_stream(self): + return DispatcherWrapper(self.process_document_stream) + + async def process_document_stream(self, data, error, ok, request): + + ds = DocumentStreamExport(self.backend) + return await ds.process(data, error, ok, request) + async def process_core_import(self, data, error, ok, request): ci = CoreImport(self.backend) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index 6606dc1a..ad634cab 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -53,7 +53,6 @@ class RowsImport: elt = ExtractedObject( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"].get("metadata", [])), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 8f1cdece..f42eee02 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -37,35 +37,37 @@ def serialize_triples(message): return { "metadata": { "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), + "root": message.metadata.root, "user": message.metadata.user, "collection": message.metadata.collection, }, "triples": serialize_subgraph(message.triples), } - + + def serialize_graph_embeddings(message): return { "metadata": { "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), + "root": message.metadata.root, "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ { - "vectors": entity.vectors, + "vector": entity.vector, "entity": serialize_value(entity.entity), } for entity in message.entities ], } + def serialize_entity_contexts(message): return { "metadata": { "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), + "root": message.metadata.root, "user": message.metadata.user, "collection": message.metadata.collection, }, @@ -78,18 +80,19 @@ def serialize_entity_contexts(message): ], } + def serialize_document_embeddings(message): return { "metadata": { "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), + "root": message.metadata.root, "user": message.metadata.user, "collection": message.metadata.collection, }, "chunks": [ { - "vectors": chunk.vectors, - "chunk": chunk.chunk.decode("utf-8"), + "vector": chunk.vector, + "chunk_id": chunk.chunk_id, } for chunk in message.chunks ], diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 6bb46975..37f123fa 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -48,7 +48,7 @@ class TriplesImport: elt = Triples( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), + root=data["metadata"].get("root", ""), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index f9616d9a..ff92dca2 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -64,6 +64,12 @@ class EndpointManager: method = "GET", dispatcher = dispatcher_manager.dispatch_core_export(), ), + StreamEndpoint( + endpoint_path = "/api/v1/document-stream", + auth = auth, + method = "GET", + dispatcher = dispatcher_manager.dispatch_document_stream(), + ), ] def add_routes(self, app): diff --git a/trustgraph-flow/trustgraph/librarian/blob_store.py b/trustgraph-flow/trustgraph/librarian/blob_store.py index 436e2718..d75a7af9 100644 --- a/trustgraph-flow/trustgraph/librarian/blob_store.py +++ b/trustgraph-flow/trustgraph/librarian/blob_store.py @@ -3,9 +3,12 @@ from .. knowledge import hash from .. exceptions import RequestError from minio import Minio +from minio.datatypes import Part import time import io import logging +from typing import Iterator, List, Tuple +from uuid import UUID # Module logger logger = logging.getLogger(__name__) @@ -78,3 +81,163 @@ class BlobStore: return resp.read() + async def get_range(self, object_id, offset: int, length: int) -> bytes: + """Fetch a specific byte range from an object.""" + resp = self.client.get_object( + bucket_name=self.bucket_name, + object_name="doc/" + str(object_id), + offset=offset, + length=length, + ) + try: + return resp.read() + finally: + resp.close() + resp.release_conn() + + async def get_size(self, object_id) -> int: + """Get the size of an object without downloading it.""" + stat = self.client.stat_object( + bucket_name=self.bucket_name, + object_name="doc/" + str(object_id), + ) + return stat.size + + def get_stream(self, object_id, chunk_size: int = 1024 * 1024) -> Iterator[bytes]: + """ + Stream document content in chunks. + + Yields chunks of the document, allowing processing without loading + the entire document into memory. + + Args: + object_id: The UUID of the document object + chunk_size: Size of each chunk in bytes (default 1MB) + + Yields: + Chunks of document content as bytes + """ + resp = self.client.get_object( + bucket_name=self.bucket_name, + object_name="doc/" + str(object_id), + ) + + try: + while True: + chunk = resp.read(chunk_size) + if not chunk: + break + yield chunk + finally: + resp.close() + resp.release_conn() + + logger.debug("Stream complete") + + def create_multipart_upload(self, object_id: UUID, kind: str) -> str: + """ + Initialize a multipart upload. + + Args: + object_id: The UUID for the new object + kind: MIME type of the document + + Returns: + The S3 upload_id for this multipart upload session + """ + object_name = "doc/" + str(object_id) + + # Use minio's internal method to create multipart upload + upload_id = self.client._create_multipart_upload( + bucket_name=self.bucket_name, + object_name=object_name, + headers={"Content-Type": kind}, + ) + + logger.info(f"Created multipart upload {upload_id} for {object_id}") + return upload_id + + def upload_part( + self, + object_id: UUID, + upload_id: str, + part_number: int, + data: bytes + ) -> str: + """ + Upload a single part of a multipart upload. + + Args: + object_id: The UUID of the object being uploaded + upload_id: The S3 upload_id from create_multipart_upload + part_number: Part number (1-indexed, as per S3 spec) + data: The chunk data to upload + + Returns: + The ETag for this part (needed for complete_multipart_upload) + """ + object_name = "doc/" + str(object_id) + + etag = self.client._upload_part( + bucket_name=self.bucket_name, + object_name=object_name, + data=data, + headers={"Content-Length": str(len(data))}, + upload_id=upload_id, + part_number=part_number, + ) + + logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}") + return etag + + def complete_multipart_upload( + self, + object_id: UUID, + upload_id: str, + parts: List[Tuple[int, str]] + ) -> None: + """ + Complete a multipart upload, assembling all parts into the final object. + + S3 coalesces the parts server-side - no data transfer through this client. + + Args: + object_id: The UUID of the object + upload_id: The S3 upload_id from create_multipart_upload + parts: List of (part_number, etag) tuples in order + """ + object_name = "doc/" + str(object_id) + + # Convert to Part objects as expected by minio + part_objects = [ + Part(part_number, etag) + for part_number, etag in parts + ] + + self.client._complete_multipart_upload( + bucket_name=self.bucket_name, + object_name=object_name, + upload_id=upload_id, + parts=part_objects, + ) + + logger.info(f"Completed multipart upload for {object_id}") + + def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None: + """ + Abort a multipart upload, cleaning up any uploaded parts. + + Args: + object_id: The UUID of the object + upload_id: The S3 upload_id from create_multipart_upload + """ + object_name = "doc/" + str(object_id) + + self.client._abort_multipart_upload( + bucket_name=self.bucket_name, + object_name=object_name, + upload_id=upload_id, + ) + + logger.info(f"Aborted multipart upload {upload_id} for {object_id}") + diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 8835cc73..829e63ba 100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -1,17 +1,24 @@ from .. schema import LibrarianRequest, LibrarianResponse, Error, Triple +from .. schema import UploadSession from .. knowledge import hash from .. exceptions import RequestError from .. tables.library import LibraryTableStore from . blob_store import BlobStore import base64 +import json import logging +import math +import time import uuid # Module logger logger = logging.getLogger(__name__) +# Default chunk size for multipart uploads +DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB default + class Librarian: def __init__( @@ -20,6 +27,7 @@ class Librarian: object_store_endpoint, object_store_access_key, object_store_secret_key, bucket_name, keyspace, load_document, object_store_use_ssl=False, object_store_region=None, + min_chunk_size=1, # Default: no minimum (for Garage) ): self.blob_store = BlobStore( @@ -32,6 +40,7 @@ class Librarian: ) self.load_document = load_document + self.min_chunk_size = min_chunk_size async def add_document(self, request): @@ -66,13 +75,7 @@ class Librarian: logger.debug("Add complete") - return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, - processing_metadatas = None, - ) + return LibrarianResponse() async def remove_document(self, request): @@ -84,6 +87,21 @@ class Librarian: ): raise RuntimeError("Document does not exist") + # First, cascade delete all child documents + children = await self.table_store.list_children(request.document_id) + for child in children: + logger.debug(f"Cascade deleting child document {child.id}") + try: + child_object_id = await self.table_store.get_document_object_id( + child.user, + child.id + ) + await self.blob_store.remove(child_object_id) + await self.table_store.remove_document(child.user, child.id) + except Exception as e: + logger.warning(f"Failed to delete child document {child.id}: {e}") + + # Now remove the parent document object_id = await self.table_store.get_document_object_id( request.user, request.document_id @@ -100,13 +118,7 @@ class Librarian: logger.debug("Remove complete") - return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, - processing_metadatas = None, - ) + return LibrarianResponse() async def update_document(self, request): @@ -124,13 +136,7 @@ class Librarian: logger.debug("Update complete") - return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, - processing_metadatas = None, - ) + return LibrarianResponse() async def get_document_metadata(self, request): @@ -147,8 +153,6 @@ class Librarian: error = None, document_metadata = doc, content = None, - document_metadatas = None, - processing_metadatas = None, ) async def get_document_content(self, request): @@ -170,8 +174,6 @@ class Librarian: error = None, document_metadata = None, content = base64.b64encode(content), - document_metadatas = None, - processing_metadatas = None, ) async def add_processing(self, request): @@ -217,13 +219,7 @@ class Librarian: logger.debug("Add complete") - return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, - processing_metadatas = None, - ) + return LibrarianResponse() async def remove_processing(self, request): @@ -243,24 +239,23 @@ class Librarian: logger.debug("Remove complete") - return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, - processing_metadatas = None, - ) + return LibrarianResponse() async def list_documents(self, request): docs = await self.table_store.list_documents(request.user) + # Filter out child documents and answer documents by default + include_children = getattr(request, 'include_children', False) + if not include_children: + docs = [ + doc for doc in docs + if not doc.parent_id # Only include top-level documents + and doc.document_type != "answer" # Exclude GraphRAG answers + ] + return LibrarianResponse( - error = None, - document_metadata = None, - content = None, document_metadatas = docs, - processing_metadatas = None, ) async def list_processing(self, request): @@ -268,10 +263,444 @@ class Librarian: procs = await self.table_store.list_processing(request.user) return LibrarianResponse( - error = None, - document_metadata = None, - content = None, - document_metadatas = None, processing_metadatas = procs, ) + # Chunked upload operations + + async def begin_upload(self, request): + """ + Initialize a chunked upload session. + + Creates an S3 multipart upload and stores session state in Cassandra. + """ + logger.info(f"Beginning chunked upload for document {request.document_metadata.id}") + + if request.document_metadata.kind not in ("text/plain", "application/pdf"): + raise RequestError( + "Invalid document kind: " + request.document_metadata.kind + ) + + if await self.table_store.document_exists( + request.document_metadata.user, + request.document_metadata.id + ): + raise RequestError("Document already exists") + + # Validate sizes + total_size = request.total_size + if total_size <= 0: + raise RequestError("total_size must be positive") + + # Use provided chunk size or default + chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE + if chunk_size < self.min_chunk_size: + raise RequestError( + f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}" + ) + + # Calculate total chunks + total_chunks = math.ceil(total_size / chunk_size) + + # Generate IDs + upload_id = str(uuid.uuid4()) + object_id = uuid.uuid4() + + # Create S3 multipart upload + s3_upload_id = self.blob_store.create_multipart_upload( + object_id, request.document_metadata.kind + ) + + # Serialize document metadata for storage + doc_meta_json = json.dumps({ + "id": request.document_metadata.id, + "time": request.document_metadata.time, + "kind": request.document_metadata.kind, + "title": request.document_metadata.title, + "comments": request.document_metadata.comments, + "user": request.document_metadata.user, + "tags": request.document_metadata.tags, + }) + + # Store session in Cassandra + await self.table_store.create_upload_session( + upload_id=upload_id, + user=request.document_metadata.user, + document_id=request.document_metadata.id, + document_metadata=doc_meta_json, + s3_upload_id=s3_upload_id, + object_id=object_id, + total_size=total_size, + chunk_size=chunk_size, + total_chunks=total_chunks, + ) + + logger.info(f"Created upload session {upload_id} with {total_chunks} chunks") + + return LibrarianResponse( + error=None, + upload_id=upload_id, + chunk_size=chunk_size, + total_chunks=total_chunks, + ) + + async def upload_chunk(self, request): + """ + Upload a single chunk of a document. + + Forwards the chunk to S3 and updates session state. + """ + logger.debug(f"Uploading chunk {request.chunk_index} for upload {request.upload_id}") + + # Get session + session = await self.table_store.get_upload_session(request.upload_id) + if session is None: + raise RequestError("Upload session not found or expired") + + # Validate ownership + if session["user"] != request.user: + raise RequestError("Not authorized to upload to this session") + + # Validate chunk index + if request.chunk_index < 0 or request.chunk_index >= session["total_chunks"]: + raise RequestError( + f"Invalid chunk index {request.chunk_index}, " + f"must be 0-{session['total_chunks']-1}" + ) + + # Decode content + content = base64.b64decode(request.content) + + # Upload to S3 (part numbers are 1-indexed in S3) + part_number = request.chunk_index + 1 + etag = self.blob_store.upload_part( + object_id=session["object_id"], + upload_id=session["s3_upload_id"], + part_number=part_number, + data=content, + ) + + # Update session with chunk info + await self.table_store.update_upload_session_chunk( + upload_id=request.upload_id, + chunk_index=request.chunk_index, + etag=etag, + ) + + # Calculate progress + chunks_received = session["chunks_received"] + # Add this chunk if not already present + if request.chunk_index not in chunks_received: + chunks_received[request.chunk_index] = etag + + num_chunks_received = len(chunks_received) + 1 # +1 for this chunk + bytes_received = num_chunks_received * session["chunk_size"] + # Adjust for last chunk potentially being smaller + if bytes_received > session["total_size"]: + bytes_received = session["total_size"] + + logger.debug(f"Chunk {request.chunk_index} uploaded, {num_chunks_received}/{session['total_chunks']} complete") + + return LibrarianResponse( + error=None, + upload_id=request.upload_id, + chunk_index=request.chunk_index, + chunks_received=num_chunks_received, + total_chunks=session["total_chunks"], + bytes_received=bytes_received, + total_bytes=session["total_size"], + ) + + async def complete_upload(self, request): + """ + Finalize a chunked upload and create the document. + + Completes the S3 multipart upload and creates the document metadata. + """ + logger.info(f"Completing upload {request.upload_id}") + + # Get session + session = await self.table_store.get_upload_session(request.upload_id) + if session is None: + raise RequestError("Upload session not found or expired") + + # Validate ownership + if session["user"] != request.user: + raise RequestError("Not authorized to complete this upload") + + # Verify all chunks received + chunks_received = session["chunks_received"] + if len(chunks_received) != session["total_chunks"]: + missing = [ + i for i in range(session["total_chunks"]) + if i not in chunks_received + ] + raise RequestError( + f"Missing chunks: {missing[:10]}{'...' if len(missing) > 10 else ''}" + ) + + # Build parts list for S3 (sorted by part number) + parts = [ + (chunk_index + 1, etag) # S3 part numbers are 1-indexed + for chunk_index, etag in sorted(chunks_received.items()) + ] + + # Complete S3 multipart upload + self.blob_store.complete_multipart_upload( + object_id=session["object_id"], + upload_id=session["s3_upload_id"], + parts=parts, + ) + + # Parse document metadata from session + doc_meta_dict = json.loads(session["document_metadata"]) + + # Create DocumentMetadata object + from .. schema import DocumentMetadata + doc_metadata = DocumentMetadata( + id=doc_meta_dict["id"], + time=doc_meta_dict.get("time", int(time.time())), + kind=doc_meta_dict["kind"], + title=doc_meta_dict.get("title", ""), + comments=doc_meta_dict.get("comments", ""), + user=doc_meta_dict["user"], + tags=doc_meta_dict.get("tags", []), + metadata=[], # Triples not supported in chunked upload yet + ) + + # Add document to table + await self.table_store.add_document(doc_metadata, session["object_id"]) + + # Delete upload session + await self.table_store.delete_upload_session(request.upload_id) + + logger.info(f"Upload {request.upload_id} completed, document {doc_metadata.id} created") + + return LibrarianResponse( + error=None, + document_id=doc_metadata.id, + object_id=str(session["object_id"]), + ) + + async def abort_upload(self, request): + """ + Cancel a chunked upload and clean up resources. + """ + logger.info(f"Aborting upload {request.upload_id}") + + # Get session + session = await self.table_store.get_upload_session(request.upload_id) + if session is None: + raise RequestError("Upload session not found or expired") + + # Validate ownership + if session["user"] != request.user: + raise RequestError("Not authorized to abort this upload") + + # Abort S3 multipart upload + self.blob_store.abort_multipart_upload( + object_id=session["object_id"], + upload_id=session["s3_upload_id"], + ) + + # Delete session from Cassandra + await self.table_store.delete_upload_session(request.upload_id) + + logger.info(f"Upload {request.upload_id} aborted") + + return LibrarianResponse(error=None) + + async def get_upload_status(self, request): + """ + Get the status of an in-progress upload. + """ + logger.debug(f"Getting status for upload {request.upload_id}") + + # Get session + session = await self.table_store.get_upload_session(request.upload_id) + if session is None: + return LibrarianResponse( + error=None, + upload_id=request.upload_id, + upload_state="expired", + ) + + # Validate ownership + if session["user"] != request.user: + raise RequestError("Not authorized to view this upload") + + chunks_received = session["chunks_received"] + received_list = sorted(chunks_received.keys()) + missing_list = [ + i for i in range(session["total_chunks"]) + if i not in chunks_received + ] + + bytes_received = len(chunks_received) * session["chunk_size"] + if bytes_received > session["total_size"]: + bytes_received = session["total_size"] + + return LibrarianResponse( + error=None, + upload_id=request.upload_id, + upload_state="in-progress", + received_chunks=received_list, + missing_chunks=missing_list, + chunks_received=len(chunks_received), + total_chunks=session["total_chunks"], + bytes_received=bytes_received, + total_bytes=session["total_size"], + ) + + async def list_uploads(self, request): + """ + List all in-progress uploads for a user. + """ + logger.debug(f"Listing uploads for user {request.user}") + + sessions = await self.table_store.list_upload_sessions(request.user) + + upload_sessions = [ + UploadSession( + upload_id=s["upload_id"], + document_id=s["document_id"], + document_metadata_json=s.get("document_metadata", ""), + total_size=s["total_size"], + chunk_size=s["chunk_size"], + total_chunks=s["total_chunks"], + chunks_received=s["chunks_received"], + created_at=str(s.get("created_at", "")), + ) + for s in sessions + ] + + return LibrarianResponse( + error=None, + upload_sessions=upload_sessions, + ) + + # Child document operations + + async def add_child_document(self, request): + """ + Add a child document linked to a parent document. + + Child documents are typically extracted content (e.g., pages from a PDF). + They have a parent_id pointing to the source document and document_type + set to "extracted". + """ + logger.info(f"Adding child document {request.document_metadata.id} " + f"for parent {request.document_metadata.parent_id}") + + if not request.document_metadata.parent_id: + raise RequestError("parent_id is required for child documents") + + # Verify parent exists + if not await self.table_store.document_exists( + request.document_metadata.user, + request.document_metadata.parent_id + ): + raise RequestError( + f"Parent document {request.document_metadata.parent_id} does not exist" + ) + + if await self.table_store.document_exists( + request.document_metadata.user, + request.document_metadata.id + ): + raise RequestError("Document already exists") + + # Set document_type if not specified by caller + # Valid types: "page", "chunk", or "extracted" (legacy) + if not request.document_metadata.document_type or request.document_metadata.document_type == "source": + request.document_metadata.document_type = "extracted" + + # Create object ID for blob + object_id = uuid.uuid4() + + logger.debug("Adding blob...") + + await self.blob_store.add( + object_id, base64.b64decode(request.content), + request.document_metadata.kind + ) + + logger.debug("Adding to table...") + + await self.table_store.add_document( + request.document_metadata, object_id + ) + + logger.debug("Add child document complete") + + return LibrarianResponse( + error=None, + document_id=request.document_metadata.id, + ) + + async def list_children(self, request): + """ + List all child documents for a given parent document. + """ + logger.debug(f"Listing children for parent {request.document_id}") + + children = await self.table_store.list_children(request.document_id) + + return LibrarianResponse( + error=None, + document_metadatas=children, + ) + + async def stream_document(self, request): + """ + Stream document content in chunks. + + This is an async generator that yields document content in smaller chunks, + allowing memory-efficient processing of large documents. Each yielded + response includes chunk_index and total_chunks for tracking progress. + Completion is determined by chunk_index reaching total_chunks - 1. + """ + logger.debug(f"Streaming document {request.document_id}") + + DEFAULT_CHUNK_SIZE = 1024 * 1024 # 1MB default + + chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE + if chunk_size < self.min_chunk_size: + raise RequestError( + f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}" + ) + + object_id = await self.table_store.get_document_object_id( + request.user, + request.document_id + ) + + # Get size via stat (no content download) + total_size = await self.blob_store.get_size(object_id) + total_chunks = math.ceil(total_size / chunk_size) + + # Stream all chunks + for chunk_index in range(total_chunks): + # Calculate byte range + offset = chunk_index * chunk_size + length = min(chunk_size, total_size - offset) + + # Fetch only the requested range + chunk_content = await self.blob_store.get_range(object_id, offset, length) + + is_last = (chunk_index == total_chunks - 1) + + logger.debug(f"Streaming chunk {chunk_index + 1}/{total_chunks}, " + f"bytes {offset}-{offset + length} of {total_size}") + + yield LibrarianResponse( + error=None, + content=base64.b64encode(chunk_content), + chunk_index=chunk_index, + chunks_received=chunk_index + 1, + total_chunks=total_chunks, + bytes_received=offset + length, + total_bytes=total_size, + is_final=is_last, + ) + diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index 7c1e428c..e017a99d 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -23,9 +23,14 @@ from .. schema import config_request_queue, config_response_queue from .. schema import Document, Metadata from .. schema import TextDocument, Metadata +from .. schema import Triples from .. exceptions import RequestError +from .. provenance import ( + document_uri, document_triples, get_vocabulary_triples, +) + from . librarian import Librarian from . collection_manager import CollectionManager @@ -47,6 +52,7 @@ default_object_store_secret_key = "object-password" default_object_store_use_ssl = False default_object_store_region = None default_cassandra_host = "cassandra" +default_min_chunk_size = 1 # No minimum by default (for Garage) bucket_name = "library" @@ -100,6 +106,11 @@ class Processor(AsyncProcessor): default_object_store_region ) + min_chunk_size = params.get( + "min_chunk_size", + default_min_chunk_size + ) + cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") @@ -226,6 +237,7 @@ class Processor(AsyncProcessor): load_document = self.load_document, object_store_use_ssl = object_store_use_ssl, object_store_region = object_store_region, + min_chunk_size = min_chunk_size, ) self.collection_manager = CollectionManager( @@ -271,6 +283,70 @@ class Processor(AsyncProcessor): pass + # Threshold for sending document_id instead of inline content (2MB) + STREAMING_THRESHOLD = 2 * 1024 * 1024 + + async def emit_document_provenance(self, document, processing, triples_queue): + """ + Emit document provenance metadata to the knowledge graph. + + This emits: + 1. Vocabulary bootstrap triples (idempotent, safe to re-emit) + 2. Document metadata as PROV-O triples + """ + logger.debug(f"Emitting document provenance for {document.id}") + + # Build document URI and provenance triples + doc_uri = document_uri(document.id) + + # Get page count for PDFs (if available from document metadata) + page_count = None + if document.kind == "application/pdf": + # Page count might be in document metadata triples + # For now, we don't have it at this point - it gets determined during extraction + pass + + # Build document metadata triples + prov_triples = document_triples( + doc_uri=doc_uri, + title=document.title if document.title else None, + mime_type=document.kind, + ) + + # Include any existing metadata triples from the document + if document.metadata: + prov_triples.extend(document.metadata) + + # Get vocabulary bootstrap triples (idempotent) + vocab_triples = get_vocabulary_triples() + + # Combine all triples + all_triples = vocab_triples + prov_triples + + # Create publisher and emit + triples_pub = Publisher( + self.pubsub, triples_queue, schema=Triples + ) + + try: + await triples_pub.start() + + triples_msg = Triples( + metadata=Metadata( + id=doc_uri, + root=document.id, + user=processing.user, + collection=processing.collection, + ), + triples=all_triples, + ) + + await triples_pub.send(None, triples_msg) + logger.debug(f"Emitted {len(all_triples)} provenance triples for {document.id}") + + finally: + await triples_pub.stop() + async def load_document(self, document, processing, content): logger.debug("Ready for document processing...") @@ -291,27 +367,64 @@ class Processor(AsyncProcessor): q = flow["interfaces"][kind] - if kind == "text-load": - doc = TextDocument( - metadata = Metadata( - id = document.id, - metadata = document.metadata, - user = processing.user, - collection = processing.collection - ), - text = content, + # Emit document provenance to knowledge graph + if "triples-store" in flow["interfaces"]: + await self.emit_document_provenance( + document, processing, flow["interfaces"]["triples-store"] ) + + if kind == "text-load": + # For large text documents, send document_id for streaming retrieval + if len(content) >= self.STREAMING_THRESHOLD: + logger.info(f"Text document {document.id} is large ({len(content)} bytes), " + f"sending document_id for streaming retrieval") + doc = TextDocument( + metadata = Metadata( + id = document.id, + root = document.id, + user = processing.user, + collection = processing.collection + ), + document_id = document.id, + text = b"", # Empty, receiver will fetch via librarian + ) + else: + doc = TextDocument( + metadata = Metadata( + id = document.id, + root = document.id, + user = processing.user, + collection = processing.collection + ), + text = content, + ) schema = TextDocument else: - doc = Document( - metadata = Metadata( - id = document.id, - metadata = document.metadata, - user = processing.user, - collection = processing.collection - ), - data = base64.b64encode(content).decode("utf-8") - ) + # For large PDF documents, send document_id for streaming retrieval + # instead of embedding the entire content in the message + if len(content) >= self.STREAMING_THRESHOLD: + logger.info(f"Document {document.id} is large ({len(content)} bytes), " + f"sending document_id for streaming retrieval") + doc = Document( + metadata = Metadata( + id = document.id, + root = document.id, + user = processing.user, + collection = processing.collection + ), + document_id = document.id, + data = b"", # Empty data, receiver will fetch via API + ) + else: + doc = Document( + metadata = Metadata( + id = document.id, + root = document.id, + user = processing.user, + collection = processing.collection + ), + data = base64.b64encode(content).decode("utf-8") + ) schema = Document logger.debug(f"Submitting to queue {q}...") @@ -361,6 +474,17 @@ class Processor(AsyncProcessor): "remove-processing": self.librarian.remove_processing, "list-documents": self.librarian.list_documents, "list-processing": self.librarian.list_processing, + # Chunked upload operations + "begin-upload": self.librarian.begin_upload, + "upload-chunk": self.librarian.upload_chunk, + "complete-upload": self.librarian.complete_upload, + "abort-upload": self.librarian.abort_upload, + "get-upload-status": self.librarian.get_upload_status, + "list-uploads": self.librarian.list_uploads, + # Child document and streaming operations + "add-child-document": self.librarian.add_child_document, + "list-children": self.librarian.list_children, + "stream-document": self.librarian.stream_document, } if v.operation not in impls: @@ -380,6 +504,15 @@ class Processor(AsyncProcessor): try: + # Handle streaming operations specially + if v.operation == "stream-document": + async for resp in self.librarian.stream_document(v): + await self.librarian_response_producer.send( + resp, properties={"id": id} + ) + return + + # Non-streaming operations resp = await self.process_request(v) await self.librarian_response_producer.send( @@ -393,7 +526,7 @@ class Processor(AsyncProcessor): error = Error( type = "request-error", message = str(e), - ) + ), ) await self.librarian_response_producer.send( @@ -406,7 +539,7 @@ class Processor(AsyncProcessor): error = Error( type = "unexpected-error", message = str(e), - ) + ), ) await self.librarian_response_producer.send( @@ -538,6 +671,14 @@ class Processor(AsyncProcessor): help='Object storage region (optional)', ) + parser.add_argument( + '--min-chunk-size', + type=int, + default=default_min_chunk_size, + help=f'Minimum chunk size in bytes for uploads/downloads ' + f'(default: {default_min_chunk_size})', + ) + add_cassandra_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py index 4e3db7f9..87d23621 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py @@ -55,11 +55,13 @@ class Processor(LlmService): self.max_output = max_output self.default_model = model - def build_prompt(self, system, content, temperature=None, stream=False): + def build_prompt(self, system, content, temperature=None, stream=False, model=None): # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature + model_name = model or self.default_model data = { + "model": model_name, "messages": [ { "role": "system", "content": system @@ -100,7 +102,8 @@ class Processor(LlmService): raise TooManyRequests() if resp.status_code != 200: - raise RuntimeError("LLM failure") + logger.error(f"Azure API error: status={resp.status_code}, body={resp.text}") + raise RuntimeError(f"LLM failure: HTTP {resp.status_code}") result = resp.json() @@ -121,7 +124,8 @@ class Processor(LlmService): prompt = self.build_prompt( system, prompt, - effective_temperature + effective_temperature, + model=model_name ) response = self.call_llm(prompt) @@ -174,7 +178,7 @@ class Processor(LlmService): logger.debug(f"Using temperature: {effective_temperature}") try: - body = self.build_prompt(system, prompt, effective_temperature, stream=True) + body = self.build_prompt(system, prompt, effective_temperature, stream=True, model=model_name) url = self.endpoint api_key = self.token @@ -190,7 +194,11 @@ class Processor(LlmService): raise TooManyRequests() if response.status_code != 200: - raise RuntimeError("LLM failure") + logger.error(f"Azure API error: status={response.status_code}, body={response.text}") + raise RuntimeError(f"LLM failure: HTTP {response.status_code}") + + total_input_tokens = 0 + total_output_tokens = 0 total_input_tokens = 0 total_output_tokens = 0 @@ -279,6 +287,12 @@ class Processor(LlmService): help=f'LLM max output tokens (default: {default_max_output})' ) + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model name (default: {default_model})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 03c98ad3..98350961 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -1,13 +1,13 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks +of chunk_ids """ import logging from .... direct.milvus_doc_embeddings import DocVectors -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -35,24 +35,31 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit + ) + chunks = [] - - for vec in msg.vectors: - - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit - ) - - for r in resp: - chunk = r["entity"]["doc"] - chunks.append(chunk) + for r in resp: + chunk_id = r["entity"]["chunk_id"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) return chunks diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 1c3f8d1b..406f979c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -1,7 +1,7 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks. Pinecone implementation. +of chunk_ids. Pinecone implementation. """ import logging @@ -11,6 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig +from .... schema import ChunkMatch from .... base import DocumentEmbeddingsQueryService # Module logger @@ -51,36 +52,41 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] + dim = len(vec) + + # Use dimension suffix in index name + index_name = f"d-{msg.user}-{msg.collection}-{dim}" + + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] + + index = self.pinecone.Index(index_name) + + results = index.query( + vector=vec, + top_k=msg.limit, + include_values=False, + include_metadata=True + ) + chunks = [] - - for vec in msg.vectors: - - dim = len(vec) - - # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" - - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue - - index = self.pinecone.Index(index_name) - - results = index.query( - vector=vec, - top_k=msg.limit, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - doc = r.metadata["doc"] - chunks.append(doc) + for r in results.matches: + chunk_id = r.metadata["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) return chunks diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index e84372cb..f056b1c1 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -1,7 +1,7 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks +of chunk_ids """ import logging @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -69,29 +69,34 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + + # Use dimension suffix in collection name + dim = len(vec) + collection = f"d_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist, returning empty results") + return [] + + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit, + with_payload=True, + ).points + chunks = [] - - for vec in msg.vectors: - - # Use dimension suffix in collection name - dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, returning empty results") - continue - - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit, - with_payload=True, - ).points - - for r in search_result: - ent = r.payload["doc"] - chunks.append(ent) + for r in search_result: + chunk_id = r.payload["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) return chunks diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index c5cdb6d8..94eee387 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -7,7 +7,7 @@ entities import logging from .... direct.milvus_graph_embeddings import EntityVectors -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService): try: - entity_set = set() - entities = [] + vec = msg.vector + if not vec: + return [] # Handle zero limit case if msg.limit <= 0: return [] - for vec in msg.vectors: + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit * 2 + ) - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit * 2 - ) + entity_set = set() + entities = [] - for r in resp: - ent = r["entity"]["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) + for r in resp: + ent = r["entity"]["entity"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 5882f21c..ca443a6f 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -11,7 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] + dim = len(vec) + + # Use dimension suffix in index name + index_name = f"t-{msg.user}-{msg.collection}-{dim}" + + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] + + index = self.pinecone.Index(index_name) + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + results = index.query( + vector=vec, + top_k=msg.limit * 2, + include_values=False, + include_metadata=True + ) + entity_set = set() entities = [] - for vec in msg.vectors: + for r in results.matches: + ent = r.metadata["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - dim = len(vec) - - # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" - - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue - - index = self.pinecone.Index(index_name) - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - results = index.query( - vector=vec, - top_k=msg.limit * 2, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - - ent = r.metadata["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index a76059ef..df93ad8b 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + + # Use dimension suffix in collection name + dim = len(vec) + collection = f"t_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist") + return [] + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit * 2, + with_payload=True, + ).points + entity_set = set() entities = [] - for vec in msg.vectors: + for r in search_result: + ent = r.payload["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - # Use dimension suffix in collection name - dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, skipping this vector") - continue - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit * 2, - with_payload=True, - ).points - - for r in search_result: - ent = r.payload["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 7ed6192f..7fc20303 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) default_ident = "row-embeddings-query" default_store_uri = 'http://localhost:6333' +default_concurrency = 10 class Processor(FlowProcessor): @@ -31,6 +32,7 @@ class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) store_uri = params.get("store_uri", default_store_uri) api_key = params.get("api_key", None) @@ -47,7 +49,8 @@ class Processor(FlowProcessor): ConsumerSpec( name="request", schema=RowEmbeddingsRequest, - handler=self.on_message + handler=self.on_message, + concurrency=concurrency, ) ) @@ -93,7 +96,9 @@ class Processor(FlowProcessor): async def query_row_embeddings(self, request: RowEmbeddingsRequest): """Execute row embeddings query""" - matches = [] + vec = request.vector + if not vec: + return [] # Find the collection for this user/collection/schema qdrant_collection = self.find_collection( @@ -105,47 +110,47 @@ class Processor(FlowProcessor): f"No Qdrant collection found for " f"{request.user}/{request.collection}/{request.schema_name}" ) + return [] + + try: + # Build optional filter for index_name + query_filter = None + if request.index_name: + query_filter = Filter( + must=[ + FieldCondition( + key="index_name", + match=MatchValue(value=request.index_name) + ) + ] + ) + + # Query Qdrant + search_result = self.qdrant.query_points( + collection_name=qdrant_collection, + query=vec, + limit=request.limit, + with_payload=True, + query_filter=query_filter, + ).points + + # Convert to RowIndexMatch objects + matches = [] + for point in search_result: + payload = point.payload or {} + match = RowIndexMatch( + index_name=payload.get("index_name", ""), + index_value=payload.get("index_value", []), + text=payload.get("text", ""), + score=point.score if hasattr(point, 'score') else 0.0 + ) + matches.append(match) + return matches - for vec in request.vectors: - try: - # Build optional filter for index_name - query_filter = None - if request.index_name: - query_filter = Filter( - must=[ - FieldCondition( - key="index_name", - match=MatchValue(value=request.index_name) - ) - ] - ) - - # Query Qdrant - search_result = self.qdrant.query_points( - collection_name=qdrant_collection, - query=vec, - limit=request.limit, - with_payload=True, - query_filter=query_filter, - ).points - - # Convert to RowIndexMatch objects - for point in search_result: - payload = point.payload or {} - match = RowIndexMatch( - index_name=payload.get("index_name", ""), - index_value=payload.get("index_value", []), - text=payload.get("text", ""), - score=point.score if hasattr(point, 'score') else 0.0 - ) - matches.append(match) - - except Exception as e: - logger.error(f"Failed to query Qdrant: {e}", exc_info=True) - raise - - return matches + except Exception as e: + logger.error(f"Failed to query Qdrant: {e}", exc_info=True) + raise async def on_message(self, msg, consumer, flow): """Handle incoming query request""" @@ -203,6 +208,13 @@ class Processor(FlowProcessor): help='API key for Qdrant (default: None)' ) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): """Entry point for row-embeddings-query-qdrant command""" diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 3808cdb0..2337642f 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection logger = logging.getLogger(__name__) default_ident = "rows-query" +default_concurrency = 10 class Processor(FlowProcessor): @@ -37,6 +38,7 @@ class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) # Get Cassandra parameters cassandra_host = params.get("cassandra_host") @@ -69,7 +71,8 @@ class Processor(FlowProcessor): ConsumerSpec( name="request", schema=RowsQueryRequest, - handler=self.on_message + handler=self.on_message, + concurrency=concurrency, ) ) @@ -517,6 +520,13 @@ class Processor(FlowProcessor): help='Configuration type prefix for schemas (default: schema)' ) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): """Entry point for rows-query-cassandra command""" diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index eac33dde..f1f5ba60 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,11 +6,14 @@ null. Output is a list of quads. import logging +import json +from cassandra.query import SimpleStatement + from .... direct.cassandra_kg import ( - EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH + EntityCentricKnowledgeGraph, DEFAULT_GRAPH ) from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Term, Triple, IRI, LITERAL +from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... base import TriplesQueryService from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config @@ -20,6 +23,36 @@ logger = logging.getLogger(__name__) default_ident = "triples-query" +def serialize_triple(triple): + """Serialize a Triple object to JSON for querying (must match storage format).""" + if triple is None: + return None + + def term_to_dict(term): + if term is None: + return None + result = {"type": term.type} + if term.type == IRI: + result["iri"] = term.iri + elif term.type == LITERAL: + result["value"] = term.value + if term.datatype: + result["datatype"] = term.datatype + if term.language: + result["language"] = term.language + elif term.type == BLANK: + result["id"] = term.id + elif term.type == TRIPLE: + result["triple"] = serialize_triple(term.triple) + return result + + return json.dumps({ + "s": term_to_dict(triple.s), + "p": term_to_dict(triple.p), + "o": term_to_dict(triple.o), + }) + + def get_term_value(term): """Extract the string value from a Term""" if term is None: @@ -28,42 +61,88 @@ def get_term_value(term): return term.iri elif term.type == LITERAL: return term.value + elif term.type == TRIPLE: + # Serialize nested triple to JSON (must match storage format) + return serialize_triple(term.triple) else: # For blank nodes or other types, use id or value return term.id or term.value -def create_term(value, otype=None, dtype=None, lang=None): +def deserialize_term(term_dict): + """Deserialize a term from JSON structure.""" + if term_dict is None: + return None + term_type = term_dict.get("type", "") + if term_type == IRI: + return Term(type=IRI, iri=term_dict.get("iri", "")) + elif term_type == LITERAL: + return Term( + type=LITERAL, + value=term_dict.get("value", ""), + datatype=term_dict.get("datatype", ""), + language=term_dict.get("language", "") + ) + elif term_type == TRIPLE: + # Recursive for nested triples + nested = term_dict.get("triple") + if nested: + return Term( + type=TRIPLE, + triple=Triple( + s=deserialize_term(nested.get("s")), + p=deserialize_term(nested.get("p")), + o=deserialize_term(nested.get("o")), + ) + ) + # Fallback + return Term(type=LITERAL, value=str(term_dict)) + + +def create_term(value, term_type=None, datatype=None, language=None): """ Create a Term from a string value, optionally using type metadata. Args: value: The string value - otype: Object type - 'u' (URI), 'l' (literal), 't' (triple) - dtype: XSD datatype (for literals) - lang: Language tag (for literals) + term_type: 'u' (IRI), 'l' (literal), 't' (triple) + datatype: XSD datatype for literals + language: Language tag for literals - If otype is provided, uses it to determine Term type. - Otherwise falls back to URL detection heuristic. + If term_type is provided, uses it to determine Term type. + Otherwise falls back to URL detection heuristic for object values. """ - if otype is not None: - if otype == 'u': - return Term(type=IRI, iri=value) - elif otype == 'l': - return Term( - type=LITERAL, - value=value, - datatype=dtype or "", - language=lang or "" - ) - elif otype == 't': - # Triple/reification - treat as IRI for now - return Term(type=IRI, iri=value) - else: - # Unknown otype, fall back to heuristic - pass + if term_type == 'u': + return Term(type=IRI, iri=value) + elif term_type == 'l': + return Term( + type=LITERAL, + value=value, + datatype=datatype or "", + language=language or "" + ) + elif term_type == 't': + # Triple/reification - parse JSON and create nested Triple + try: + triple_data = json.loads(value) if isinstance(value, str) else value + if isinstance(triple_data, dict): + return Term( + type=TRIPLE, + triple=Triple( + s=deserialize_term(triple_data.get("s")), + p=deserialize_term(triple_data.get("p")), + o=deserialize_term(triple_data.get("o")), + ) + ) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse triple JSON: {e}") + # Fallback if parsing fails + return Term(type=LITERAL, value=str(value)) + elif term_type is not None: + # Unknown term_type, fall back to heuristic + pass - # Heuristic fallback for backwards compatibility + # Heuristic fallback for backwards compatibility (object values only) if value.startswith("http://") or value.startswith("https://"): return Term(type=IRI, iri=value) else: @@ -98,28 +177,30 @@ class Processor(TriplesQueryService): self.cassandra_password = password self.table = None + def ensure_connection(self, user): + """Ensure we have a connection to the correct keyspace.""" + if user != self.table: + KGClass = EntityCentricKnowledgeGraph + + if self.cassandra_username and self.cassandra_password: + self.tg = KGClass( + hosts=self.cassandra_host, + keyspace=user, + username=self.cassandra_username, + password=self.cassandra_password + ) + else: + self.tg = KGClass( + hosts=self.cassandra_host, + keyspace=user, + ) + self.table = user + async def query_triples(self, query): try: - user = query.user - - if user != self.table: - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=query.user, - username=self.cassandra_username, password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=query.user, - ) - self.table = user + self.ensure_connection(query.user) # Extract values from query s_val = get_term_value(query.s) @@ -127,13 +208,13 @@ class Processor(TriplesQueryService): o_val = get_term_value(query.o) g_val = query.g # Already a string or None - # Helper to extract object metadata from result row - def get_o_metadata(t): - """Extract otype/dtype/lang from result row if available""" - otype = getattr(t, 'otype', None) - dtype = getattr(t, 'dtype', None) - lang = getattr(t, 'lang', None) - return otype, dtype, lang + def get_object_metadata(row): + """Extract term type metadata from result row""" + return ( + getattr(row, 'otype', None), + getattr(row, 'dtype', None), + getattr(row, 'lang', None), + ) quads = [] @@ -148,8 +229,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, p_val, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: # SP specified resp = self.tg.get_sp( @@ -158,8 +239,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, p_val, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: # SO specified @@ -169,8 +250,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, t.p, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: # S only resp = self.tg.get_s( @@ -179,8 +260,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((s_val, t.p, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((s_val, t.p, t.o, g, term_type, datatype, language)) else: if p_val is not None: if o_val is not None: @@ -191,8 +272,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, p_val, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: # P only resp = self.tg.get_p( @@ -201,8 +282,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, p_val, t.o, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: # O only @@ -212,8 +293,8 @@ class Processor(TriplesQueryService): ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, t.p, o_val, g, otype, dtype, lang)) + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: # Nothing specified - get all resp = self.tg.get_all( @@ -223,16 +304,24 @@ class Processor(TriplesQueryService): for t in resp: # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - otype, dtype, lang = get_o_metadata(t) - quads.append((t.s, t.p, t.o, g, otype, dtype, lang)) + # Filter by graph + # g_val=None means all graphs (no filter) + # g_val="" means default graph only + # otherwise filter to specific named graph + if g_val is not None: + if g != g_val: + continue + term_type, datatype, language = get_object_metadata(t) + quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) # Convert to Triple objects (with g field) - # Use otype/dtype/lang for proper Term reconstruction if available + # s and p are always IRIs in RDF + # Object uses term_type/datatype/language metadata from database triples = [ Triple( - s=create_term(q[0]), - p=create_term(q[1]), - o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]), + s=create_term(q[0], term_type='u'), + p=create_term(q[1], term_type='u'), + o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]), g=q[3] if q[3] != DEFAULT_GRAPH else None ) for q in quads @@ -245,6 +334,104 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e + async def query_triples_stream(self, query): + """ + Streaming query - yields (batch, is_final) tuples. + Uses Cassandra's paging to fetch results incrementally. + """ + try: + self.ensure_connection(query.user) + + batch_size = query.batch_size if query.batch_size > 0 else 20 + limit = query.limit if query.limit > 0 else 10000 + + # Extract query pattern + s_val = get_term_value(query.s) + p_val = get_term_value(query.p) + o_val = get_term_value(query.o) + g_val = query.g + + def get_object_metadata(row): + """Extract term type metadata from result row""" + return ( + getattr(row, 'otype', None), + getattr(row, 'dtype', None), + getattr(row, 'lang', None), + ) + + # For streaming, we need to execute with fetch_size + # Use the collection table for get_all queries (most common streaming case) + + # Determine which query to use based on pattern + if s_val is None and p_val is None and o_val is None: + # Get all - use collection table with paging + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + params = [query.collection] + else: + # For specific patterns, fall back to non-streaming + # (these typically return small result sets anyway) + async for batch, is_final in self._fallback_stream(query, batch_size): + yield batch, is_final + return + + # Create statement with fetch_size for true streaming + statement = SimpleStatement(cql, fetch_size=batch_size) + result_set = self.tg.session.execute(statement, params) + + batch = [] + count = 0 + + for row in result_set: + if count >= limit: + break + + g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + + # Filter by graph + # g_val=None means all graphs (no filter) + # g_val="" means default graph only + # otherwise filter to specific named graph + if g_val is not None: + if g != g_val: + continue + + term_type, datatype, language = get_object_metadata(row) + + # s and p are always IRIs in RDF + triple = Triple( + s=create_term(row.s, term_type='u'), + p=create_term(row.p, term_type='u'), + o=create_term(row.o, term_type=term_type, datatype=datatype, language=language), + g=g if g != DEFAULT_GRAPH else None + ) + batch.append(triple) + count += 1 + + # Yield batch when full (never mark as final mid-stream) + if len(batch) >= batch_size: + yield batch, False + batch = [] + + # Always yield final batch to signal completion + # This handles: remaining rows, empty result set, or exact batch boundary + yield batch, True + + except Exception as e: + logger.error(f"Exception in streaming query: {e}", exc_info=True) + raise e + + async def _fallback_stream(self, query, batch_size): + """Fallback to non-streaming query with post-hoc batching.""" + triples = await self.query_triples(query) + + for i in range(0, len(triples), batch_size): + batch = triples[i:i + batch_size] + is_final = (i + batch_size >= len(triples)) + yield batch, is_final + + if len(triples) == 0: + yield [], True + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 9f4ad0ff..730a7226 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -1,6 +1,22 @@ import asyncio import logging +import uuid +from datetime import datetime + +# Provenance imports +from trustgraph.provenance import ( + docrag_question_uri, + docrag_grounding_uri, + docrag_exploration_uri, + docrag_synthesis_uri, + docrag_question_triples, + grounding_triples, + docrag_exploration_triples, + docrag_synthesis_triples, + set_graph, + GRAPH_RETRIEVAL, +) # Module logger logger = logging.getLogger(__name__) @@ -19,41 +35,106 @@ class Query: self.verbose = verbose self.doc_limit = doc_limit - async def get_vector(self, query): + async def extract_concepts(self, query): + """Extract key concepts from query for independent embedding.""" + response = await self.rag.prompt_client.prompt( + "extract-concepts", + variables={"query": query} + ) + concepts = [] + if isinstance(response, str): + for line in response.strip().split('\n'): + line = line.strip() + if line: + concepts.append(line) + + # Fallback to raw query if no concepts extracted + if not concepts: + concepts = [query] + + if self.verbose: + logger.debug(f"Extracted concepts: {concepts}") + + return concepts + + async def get_vectors(self, concepts): + """Compute embeddings for a list of concepts.""" if self.verbose: logger.debug("Computing embeddings...") - qembeds = await self.rag.embeddings_client.embed(query) + qembeds = await self.rag.embeddings_client.embed(concepts) if self.verbose: logger.debug("Embeddings computed") return qembeds - async def get_docs(self, query): + async def get_docs(self, concepts): + """ + Get documents (chunks) matching the extracted concepts. - vectors = await self.get_vector(query) + Returns: + tuple: (docs, chunk_ids) where: + - docs: list of document content strings + - chunk_ids: list of chunk IDs that were successfully fetched + """ + vectors = await self.get_vectors(concepts) if self.verbose: - logger.debug("Getting documents...") + logger.debug("Getting chunks from embeddings store...") - docs = await self.rag.doc_embeddings_client.query( - vectors, limit=self.doc_limit, - user=self.user, collection=self.collection, + # Query chunk matches for each concept concurrently + per_concept_limit = max( + 1, self.doc_limit // len(vectors) ) - if self.verbose: - logger.debug("Documents:") - for doc in docs: - logger.debug(f" {doc}") + async def query_concept(vec): + return await self.rag.doc_embeddings_client.query( + vector=vec, limit=per_concept_limit, + user=self.user, collection=self.collection, + ) - return docs + results = await asyncio.gather( + *[query_concept(v) for v in vectors] + ) + + # Deduplicate chunk matches by chunk_id + seen = set() + chunk_matches = [] + for matches in results: + for match in matches: + if match.chunk_id and match.chunk_id not in seen: + seen.add(match.chunk_id) + chunk_matches.append(match) + + if self.verbose: + logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...") + + # Fetch chunk content from Garage + docs = [] + chunk_ids = [] + for match in chunk_matches: + if match.chunk_id: + try: + content = await self.rag.fetch_chunk(match.chunk_id, self.user) + docs.append(content) + chunk_ids.append(match.chunk_id) + except Exception as e: + logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}") + + if self.verbose: + logger.debug("Documents fetched:") + for doc in docs: + logger.debug(f" {doc[:100]}...") + + return docs, chunk_ids class DocumentRag: def __init__( self, prompt_client, embeddings_client, doc_embeddings_client, + fetch_chunk, verbose=False, ): @@ -62,6 +143,7 @@ class DocumentRag: self.prompt_client = prompt_client self.embeddings_client = embeddings_client self.doc_embeddings_client = doc_embeddings_client + self.fetch_chunk = fetch_chunk if self.verbose: logger.debug("DocumentRag initialized") @@ -69,17 +151,69 @@ class DocumentRag: async def query( self, query, user="trustgraph", collection="default", doc_limit=20, streaming=False, chunk_callback=None, + explain_callback=None, save_answer_callback=None, ): + """ + Execute a Document RAG query with optional explainability tracking. + Args: + query: The query string + user: User identifier + collection: Collection identifier + doc_limit: Max chunks to retrieve + streaming: Enable streaming LLM response + chunk_callback: async def callback(chunk, end_of_stream) for streaming + explain_callback: async def callback(triples, explain_id) for explainability + save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian + + Returns: + str: The synthesized answer text + """ if self.verbose: logger.debug("Constructing prompt...") + # Generate explainability URIs upfront + session_id = str(uuid.uuid4()) + q_uri = docrag_question_uri(session_id) + gnd_uri = docrag_grounding_uri(session_id) + exp_uri = docrag_exploration_uri(session_id) + syn_uri = docrag_synthesis_uri(session_id) + + timestamp = datetime.utcnow().isoformat() + "Z" + + # Emit question explainability immediately + if explain_callback: + q_triples = set_graph( + docrag_question_triples(q_uri, query, timestamp), + GRAPH_RETRIEVAL + ) + await explain_callback(q_triples, q_uri) + q = Query( rag=self, user=user, collection=collection, verbose=self.verbose, doc_limit=doc_limit ) - docs = await q.get_docs(query) + # Extract concepts from query (grounding step) + concepts = await q.extract_concepts(query) + + # Emit grounding explainability after concept extraction + if explain_callback: + gnd_triples = set_graph( + grounding_triples(gnd_uri, q_uri, concepts), + GRAPH_RETRIEVAL + ) + await explain_callback(gnd_triples, gnd_uri) + + docs, chunk_ids = await q.get_docs(concepts) + + # Emit exploration explainability after chunks retrieved + if explain_callback: + exp_triples = set_graph( + docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), + GRAPH_RETRIEVAL + ) + await explain_callback(exp_triples, exp_uri) if self.verbose: logger.debug("Invoking LLM...") @@ -87,12 +221,21 @@ class DocumentRag: logger.debug(f"Query: {query}") if streaming and chunk_callback: + # Accumulate chunks for answer storage while forwarding to callback + accumulated_chunks = [] + + async def accumulating_callback(chunk, end_of_stream): + accumulated_chunks.append(chunk) + await chunk_callback(chunk, end_of_stream) + resp = await self.prompt_client.document_prompt( query=query, documents=docs, streaming=True, - chunk_callback=chunk_callback + chunk_callback=accumulating_callback ) + # Combine all chunks into full response + resp = "".join(accumulated_chunks) else: resp = await self.prompt_client.document_prompt( query=query, @@ -102,5 +245,33 @@ class DocumentRag: if self.verbose: logger.debug("Query processing complete") + # Emit synthesis explainability after answer generated + if explain_callback: + synthesis_doc_id = None + answer_text = resp if resp else "" + + # Save answer to librarian + if save_answer_callback and answer_text: + synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer" + try: + await save_answer_callback(synthesis_doc_id, answer_text) + if self.verbose: + logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + synthesis_doc_id = None + + syn_triples = set_graph( + docrag_synthesis_triples( + syn_uri, exp_uri, + document_id=synthesis_doc_id, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(syn_triples, syn_uri) + + if self.verbose: + logger.debug(f"Emitted explain for session {session_id}") + return resp diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 6490562a..9eb32e12 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -4,17 +4,30 @@ Simple RAG service, performs query using document RAG an LLM. Input is query, output is response. """ +import asyncio +import base64 import logging + +import uuid + from ... schema import DocumentRagQuery, DocumentRagResponse, Error +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue +from ... schema import Triples, Metadata +from ... provenance import GRAPH_RETRIEVAL from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec +from ... base import Consumer, Producer +from ... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) default_ident = "document-rag" +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -69,6 +82,161 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "explainability", + schema = Triples, + ) + ) + + # Librarian client for fetching chunk content from Garage + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_chunk_content(self, chunk_id, user, timeout=120): + """Fetch chunk content from librarian/Garage.""" + import uuid + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=chunk_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + # Content is base64 encoded + content = response.content + if isinstance(content, str): + content = content.encode('utf-8') + return base64.b64decode(content).decode("utf-8") + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching chunk {chunk_id}") + + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. + + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "DocumentRAG Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_request(self, msg, consumer, flow): try: @@ -77,6 +245,7 @@ class Processor(FlowProcessor): embeddings_client = flow("embeddings-request"), doc_embeddings_client = flow("document-embeddings-request"), prompt_client = flow("prompt-request"), + fetch_chunk = self.fetch_chunk_content, verbose=True, ) @@ -92,6 +261,39 @@ class Processor(FlowProcessor): else: doc_limit = self.doc_limit + # Real-time explainability callback - emits triples and IDs as they're generated + # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + async def send_explainability(triples, explain_id): + # Send triples to explainability queue - stores in same collection with named graph + await flow("explainability").send(Triples( + metadata=Metadata( + id=explain_id, + user=v.user, + collection=v.collection, # Store in user's collection + ), + triples=triples, + )) + + # Send explain ID and graph to response queue + await flow("response").send( + DocumentRagResponse( + response=None, + explain_id=explain_id, + explain_graph=GRAPH_RETRIEVAL, + message_type="explain", + ), + properties={"id": id} + ) + + # Callback to save answer content to librarian + async def save_answer(doc_id, answer_text): + await self.save_answer_content( + doc_id=doc_id, + user=v.user, + content=answer_text, + title=f"DocumentRAG Answer: {v.query[:50]}...", + ) + # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks @@ -101,6 +303,7 @@ class Processor(FlowProcessor): DocumentRagResponse( response=chunk, end_of_stream=end_of_stream, + message_type="chunk", error=None ), properties={"id": id} @@ -115,6 +318,18 @@ class Processor(FlowProcessor): doc_limit=doc_limit, streaming=True, chunk_callback=send_chunk, + explain_callback=send_explainability, + save_answer_callback=save_answer, + ) + + # Send end_of_session to signal entire session is complete + await flow("response").send( + DocumentRagResponse( + response=None, + end_of_session=True, + message_type="end", + ), + properties={"id": id} ) else: # Non-streaming path (existing behavior) @@ -122,7 +337,9 @@ class Processor(FlowProcessor): v.query, user=v.user, collection=v.collection, - doc_limit=doc_limit + doc_limit=doc_limit, + explain_callback=send_explainability, + save_answer_callback=save_answer, ) await flow("response").send( diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 7ccba248..22d4fc1b 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -1,14 +1,57 @@ import asyncio +import hashlib +import json import logging import time +import uuid from collections import OrderedDict +from datetime import datetime + +from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE + +# Provenance imports +from trustgraph.provenance import ( + question_uri, + grounding_uri as make_grounding_uri, + exploration_uri as make_exploration_uri, + focus_uri as make_focus_uri, + synthesis_uri as make_synthesis_uri, + question_triples, + grounding_triples, + exploration_triples, + focus_triples, + synthesis_triples, + set_graph, + GRAPH_RETRIEVAL, GRAPH_SOURCE, + TG_CONTAINS, PROV_WAS_DERIVED_FROM, +) # Module logger logger = logging.getLogger(__name__) LABEL="http://www.w3.org/2000/01/rdf-schema#label" + +def term_to_string(term): + """Extract string value from a Term object.""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + # Fallback + return term.iri or term.value or str(term) + + +def edge_id(s, p, o): + """Generate an 8-character hash ID for an edge (s, p, o).""" + edge_str = f"{s}|{p}|{o}" + return hashlib.sha256(edge_str.encode()).hexdigest()[:8] + + + class LRUCacheWithTTL: """LRU cache with TTL for label caching @@ -67,12 +110,32 @@ class Query: self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length - async def get_vector(self, query): + async def extract_concepts(self, query): + """Extract key concepts from query for independent embedding.""" + response = await self.rag.prompt_client.prompt( + "extract-concepts", + variables={"query": query} + ) + concepts = [] + if isinstance(response, str): + for line in response.strip().split('\n'): + line = line.strip() + if line: + concepts.append(line) + + if self.verbose: + logger.debug(f"Extracted concepts: {concepts}") + + # Fall back to raw query if extraction returns nothing + return concepts if concepts else [query] + + async def get_vectors(self, concepts): + """Embed multiple concepts concurrently.""" if self.verbose: logger.debug("Computing embeddings...") - qembeds = await self.rag.embeddings_client.embed(query) + qembeds = await self.rag.embeddings_client.embed(concepts) if self.verbose: logger.debug("Done.") @@ -80,28 +143,55 @@ class Query: return qembeds async def get_entities(self, query): + """ + Extract concepts from query, embed them, and retrieve matching entities. - vectors = await self.get_vector(query) + Returns: + tuple: (entities, concepts) where entities is a list of entity URI + strings and concepts is the list of concept strings extracted + from the query. + """ + + concepts = await self.extract_concepts(query) + + vectors = await self.get_vectors(concepts) if self.verbose: logger.debug("Getting entities...") - entities = await self.rag.graph_embeddings_client.query( - vectors=vectors, limit=self.entity_limit, - user=self.user, collection=self.collection, + # Query entity matches for each concept concurrently + per_concept_limit = max( + 1, self.entity_limit // len(vectors) ) - entities = [ - str(e) - for e in entities + entity_tasks = [ + self.rag.graph_embeddings_client.query( + vector=v, limit=per_concept_limit, + user=self.user, collection=self.collection, + ) + for v in vectors ] + results = await asyncio.gather(*entity_tasks, return_exceptions=True) + + # Deduplicate while preserving order + seen = set() + entities = [] + for result in results: + if isinstance(result, Exception) or not result: + continue + for e in result: + entity = term_to_string(e.entity) + if entity not in seen: + seen.add(entity) + entities.append(entity) + if self.verbose: logger.debug("Entities:") for ent in entities: logger.debug(f" {ent}") - return entities + return entities, concepts async def maybe_label(self, e): @@ -117,6 +207,7 @@ class Query: res = await self.rag.triples_client.query( s=e, p=LABEL, o=None, limit=1, user=self.user, collection=self.collection, + g="", ) if len(res) == 0: @@ -128,26 +219,29 @@ class Query: return label async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently""" + """Execute triple queries for multiple entities concurrently using streaming""" tasks = [] for entity in entities: - # Create concurrent tasks for all 3 query types per entity + # Create concurrent streaming tasks for all 3 query types per entity tasks.extend([ - self.rag.triples_client.query( + self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, g="", ), - self.rag.triples_client.query( + self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, g="", ), - self.rag.triples_client.query( + self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, g="", ) ]) @@ -157,7 +251,7 @@ class Query: # Combine all results all_triples = [] for result in results: - if not isinstance(result, Exception): + if not isinstance(result, Exception) and result is not None: all_triples.extend(result) return all_triples @@ -220,8 +314,16 @@ class Query: subgraph.update(batch_result) async def get_subgraph(self, query): + """ + Get subgraph by extracting concepts, finding entities, and traversing. - entities = await self.get_entities(query) + Returns: + tuple: (subgraph, entities, concepts) where subgraph is a list of + (s, p, o) tuples, entities is the seed entity list, and concepts + is the extracted concept list. + """ + + entities, concepts = await self.get_entities(query) if self.verbose: logger.debug("Getting subgraph...") @@ -229,7 +331,7 @@ class Query: # Use optimized batch traversal instead of sequential processing subgraph = await self.follow_edges_batch(entities, self.max_path_length) - return list(subgraph) + return list(subgraph), entities, concepts async def resolve_labels_batch(self, entities): """Resolve labels for multiple entities in parallel""" @@ -240,8 +342,17 @@ class Query: return await asyncio.gather(*tasks, return_exceptions=True) async def get_labelgraph(self, query): + """ + Get subgraph with labels resolved for display. - subgraph = await self.get_subgraph(query) + Returns: + tuple: (labeled_edges, uri_map, entities, concepts) where: + - labeled_edges: list of (label_s, label_p, label_o) tuples + - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) + - entities: list of seed entity URI strings + - concepts: list of concept strings extracted from query + """ + subgraph, entities, concepts = await self.get_subgraph(query) # Filter out label triples filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] @@ -263,28 +374,151 @@ class Query: else: label_map[entity] = entity # Fallback to entity itself - # Apply labels to subgraph - sg2 = [] + # Apply labels to subgraph and build URI mapping + labeled_edges = [] + uri_map = {} # Maps edge_id of labeled edge -> original URI triple + for s, p, o in filtered_subgraph: labeled_triple = ( label_map.get(s, s), label_map.get(p, p), label_map.get(o, o) ) - sg2.append(labeled_triple) + labeled_edges.append(labeled_triple) - sg2 = sg2[0:self.max_subgraph_size] + # Map from labeled edge ID to original URIs + labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) + uri_map[labeled_eid] = (s, p, o) + + labeled_edges = labeled_edges[0:self.max_subgraph_size] if self.verbose: logger.debug("Subgraph:") - for edge in sg2: + for edge in labeled_edges: logger.debug(f" {str(edge)}") if self.verbose: logger.debug("Done.") - return sg2 - + return labeled_edges, uri_map, entities, concepts + + async def trace_source_documents(self, edge_uris): + """ + Trace selected edges back to their source documents via provenance. + + Follows the chain: edge → subgraph (via tg:contains) → chunk → + page → document (via prov:wasDerivedFrom), all in urn:graph:source. + + Args: + edge_uris: List of (s, p, o) URI string tuples + + Returns: + List of unique document titles + """ + # Step 1: Find subgraphs containing these edges via tg:contains + subgraph_tasks = [] + for s, p, o in edge_uris: + quoted = Term( + type=TRIPLE, + triple=SchemaTriple( + s=Term(type=IRI, iri=s), + p=Term(type=IRI, iri=p), + o=Term(type=IRI, iri=o), + ) + ) + subgraph_tasks.append( + self.rag.triples_client.query( + s=None, p=TG_CONTAINS, o=quoted, limit=1, + user=self.user, collection=self.collection, + g=GRAPH_SOURCE, + ) + ) + + subgraph_results = await asyncio.gather( + *subgraph_tasks, return_exceptions=True + ) + + # Collect unique subgraph URIs + subgraph_uris = set() + for result in subgraph_results: + if isinstance(result, Exception) or not result: + continue + for triple in result: + subgraph_uris.add(str(triple.s)) + + if not subgraph_uris: + return [] + + # Step 2: Walk prov:wasDerivedFrom chain to find documents + # Each level: query ?entity prov:wasDerivedFrom ?parent + # Stop when we find entities typed tg:Document + current_uris = subgraph_uris + doc_uris = set() + + for depth in range(4): # Max depth: subgraph → chunk → page → doc + if not current_uris: + break + + derivation_tasks = [ + self.rag.triples_client.query( + s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5, + user=self.user, collection=self.collection, + g=GRAPH_SOURCE, + ) + for uri in current_uris + ] + + derivation_results = await asyncio.gather( + *derivation_tasks, return_exceptions=True + ) + + # URIs with no parent are root documents + next_uris = set() + for uri, result in zip(current_uris, derivation_results): + if isinstance(result, Exception) or not result: + doc_uris.add(uri) + continue + for triple in result: + next_uris.add(str(triple.o)) + + current_uris = next_uris - doc_uris + + if not doc_uris: + return [] + + # Step 3: Get all document metadata properties + # Skip structural predicates that aren't useful context + SKIP_PREDICATES = { + PROV_WAS_DERIVED_FROM, + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + } + + metadata_tasks = [ + self.rag.triples_client.query( + s=uri, p=None, o=None, limit=50, + user=self.user, collection=self.collection, + ) + for uri in doc_uris + ] + + metadata_results = await asyncio.gather( + *metadata_tasks, return_exceptions=True + ) + + doc_edges = [] + for result in metadata_results: + if isinstance(result, Exception) or not result: + continue + for triple in result: + p = str(triple.p) + if p in SKIP_PREDICATES: + continue + doc_edges.append(( + str(triple.s), p, str(triple.o) + )) + + return doc_edges + class GraphRag: """ CRITICAL SECURITY: @@ -316,12 +550,50 @@ class GraphRag: async def query( self, query, user = "trustgraph", collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, streaming = False, chunk_callback = None, + max_path_length = 2, edge_limit = 25, streaming = False, + chunk_callback = None, + explain_callback = None, save_answer_callback = None, ): + """ + Execute a GraphRAG query with real-time explainability tracking. + Args: + query: The query string + user: User identifier + collection: Collection identifier + entity_limit: Max entities to retrieve + triple_limit: Max triples per entity + max_subgraph_size: Max edges in subgraph + max_path_length: Max hops from seed entities + streaming: Enable streaming LLM response + chunk_callback: async def callback(chunk, end_of_stream) for streaming + explain_callback: async def callback(triples, explain_id) for real-time explainability + save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian + + Returns: + str: The synthesized answer text + """ if self.verbose: logger.debug("Constructing prompt...") + # Generate explainability URIs upfront + session_id = str(uuid.uuid4()) + q_uri = question_uri(session_id) + gnd_uri = make_grounding_uri(session_id) + exp_uri = make_exploration_uri(session_id) + foc_uri = make_focus_uri(session_id) + syn_uri = make_synthesis_uri(session_id) + + timestamp = datetime.utcnow().isoformat() + "Z" + + # Emit question explainability immediately + if explain_callback: + q_triples = set_graph( + question_triples(q_uri, query, timestamp), + GRAPH_RETRIEVAL + ) + await explain_callback(q_triples, q_uri) + q = Query( rag = self, user = user, collection = collection, verbose = self.verbose, entity_limit = entity_limit, @@ -330,24 +602,262 @@ class GraphRag: max_path_length = max_path_length, ) - kg = await q.get_labelgraph(query) + kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) + + # Emit grounding explain after concept extraction + if explain_callback: + gnd_triples = set_graph( + grounding_triples(gnd_uri, q_uri, concepts), + GRAPH_RETRIEVAL + ) + await explain_callback(gnd_triples, gnd_uri) + + # Emit exploration explain after graph retrieval completes + if explain_callback: + exp_triples = set_graph( + exploration_triples( + exp_uri, gnd_uri, len(kg), + entities=seed_entities, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(exp_triples, exp_uri) if self.verbose: logger.debug("Invoking LLM...") logger.debug(f"Knowledge graph: {kg}") logger.debug(f"Query: {query}") - if streaming and chunk_callback: - resp = await self.prompt_client.kg_prompt( - query, kg, - streaming=True, - chunk_callback=chunk_callback + # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} + # uri_map already maps edge_id -> (uri_s, uri_p, uri_o) + edge_map = {} + edges_with_ids = [] + for s, p, o in kg: + eid = edge_id(s, p, o) + edge_map[eid] = (s, p, o) + edges_with_ids.append({ + "id": eid, + "s": s, + "p": p, + "o": o + }) + + if self.verbose: + logger.debug(f"Built edge map with {len(edge_map)} edges") + + # Step 1a: Edge Scoring - LLM scores edges for relevance + scoring_response = await self.prompt_client.prompt( + "kg-edge-scoring", + variables={ + "query": query, + "knowledge": edges_with_ids + } + ) + + if self.verbose: + logger.debug(f"Edge scoring response: {scoring_response}") + + # Parse scoring response to get edge IDs with scores + scored_edges = [] + + def parse_scored_edge(obj): + if isinstance(obj, dict) and "id" in obj and "score" in obj: + try: + score = int(obj["score"]) + except (ValueError, TypeError): + score = 0 + scored_edges.append({"id": obj["id"], "score": score}) + + if isinstance(scoring_response, list): + for obj in scoring_response: + parse_scored_edge(obj) + elif isinstance(scoring_response, str): + for line in scoring_response.strip().split('\n'): + line = line.strip() + if not line: + continue + try: + parse_scored_edge(json.loads(line)) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge scoring line: {line}" + ) + + # Select top N edges by score + scored_edges.sort(key=lambda x: x["score"], reverse=True) + top_edges = scored_edges[:edge_limit] + selected_ids = {e["id"] for e in top_edges} + + if self.verbose: + logger.debug( + f"Scored {len(scored_edges)} edges, " + f"selected top {len(selected_ids)}" ) + + # Filter to selected edges + selected_edges = [] + for eid in selected_ids: + if eid in edge_map: + selected_edges.append(edge_map[eid]) + + # Step 1b: Edge Reasoning + Document Tracing (concurrent) + selected_edges_with_ids = [ + {"id": eid, "s": s, "p": p, "o": o} + for eid in selected_ids + if eid in edge_map + for s, p, o in [edge_map[eid]] + ] + + # Collect selected edge URIs for document tracing + selected_edge_uris = [ + uri_map[eid] + for eid in selected_ids + if eid in uri_map + ] + + # Run reasoning and document tracing concurrently + reasoning_task = self.prompt_client.prompt( + "kg-edge-reasoning", + variables={ + "query": query, + "knowledge": selected_edges_with_ids + } + ) + doc_trace_task = q.trace_source_documents(selected_edge_uris) + + reasoning_response, source_documents = await asyncio.gather( + reasoning_task, doc_trace_task, return_exceptions=True + ) + + # Handle exceptions from gather + if isinstance(reasoning_response, Exception): + logger.warning( + f"Edge reasoning failed: {reasoning_response}" + ) + reasoning_response = "" + if isinstance(source_documents, Exception): + logger.warning( + f"Document tracing failed: {source_documents}" + ) + source_documents = [] + + + if self.verbose: + logger.debug(f"Edge reasoning response: {reasoning_response}") + + # Parse reasoning response and build explainability data + reasoning_map = {} + + def parse_reasoning(obj): + if isinstance(obj, dict) and "id" in obj: + reasoning_map[obj["id"]] = obj.get("reasoning", "") + + if isinstance(reasoning_response, list): + for obj in reasoning_response: + parse_reasoning(obj) + elif isinstance(reasoning_response, str): + for line in reasoning_response.strip().split('\n'): + line = line.strip() + if not line: + continue + try: + parse_reasoning(json.loads(line)) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge reasoning line: {line}" + ) + + selected_edges_with_reasoning = [] + for eid in selected_ids: + if eid in uri_map: + uri_s, uri_p, uri_o = uri_map[eid] + selected_edges_with_reasoning.append({ + "edge": (uri_s, uri_p, uri_o), + "reasoning": reasoning_map.get(eid, ""), + }) + + if self.verbose: + logger.debug(f"Filtered to {len(selected_edges)} edges") + + # Emit focus explain after edge selection completes + if explain_callback: + foc_triples = set_graph( + focus_triples( + foc_uri, exp_uri, selected_edges_with_reasoning, session_id + ), + GRAPH_RETRIEVAL + ) + await explain_callback(foc_triples, foc_uri) + + # Step 2: Synthesis - LLM generates answer from selected edges only + selected_edge_dicts = [ + {"s": s, "p": p, "o": o} + for s, p, o in selected_edges + ] + + # Add source document metadata as knowledge edges + for s, p, o in source_documents: + selected_edge_dicts.append({ + "s": s, "p": p, "o": o, + }) + + synthesis_variables = { + "query": query, + "knowledge": selected_edge_dicts, + } + + if streaming and chunk_callback: + # Accumulate chunks for answer storage while forwarding to callback + accumulated_chunks = [] + + async def accumulating_callback(chunk, end_of_stream): + accumulated_chunks.append(chunk) + await chunk_callback(chunk, end_of_stream) + + await self.prompt_client.prompt( + "kg-synthesis", + variables=synthesis_variables, + streaming=True, + chunk_callback=accumulating_callback + ) + # Combine all chunks into full response + resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.kg_prompt(query, kg) + resp = await self.prompt_client.prompt( + "kg-synthesis", + variables=synthesis_variables, + ) if self.verbose: logger.debug("Query processing complete") + # Emit synthesis explain after synthesis completes + if explain_callback: + synthesis_doc_id = None + answer_text = resp if resp else "" + + # Save answer to librarian + if save_answer_callback and answer_text: + synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}" + try: + await save_answer_callback(synthesis_doc_id, answer_text) + if self.verbose: + logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + synthesis_doc_id = None + + syn_triples = set_graph( + synthesis_triples( + syn_uri, foc_uri, + document_id=synthesis_doc_id, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(syn_triples, syn_uri) + + if self.verbose: + logger.debug(f"Emitted explain for session {session_id}") + return resp diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index d8bfbddb..ec4a806c 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -4,18 +4,29 @@ Simple RAG service, performs query using graph RAG an LLM. Input is query, output is response. """ +import asyncio +import base64 import logging +import uuid + from ... schema import GraphRagQuery, GraphRagResponse, Error +from ... schema import Triples, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue +from ... provenance import GRAPH_RETRIEVAL from . graph_rag import GraphRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec +from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) default_ident = "graph-rag" default_concurrency = 1 +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -28,6 +39,7 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) + edge_limit = params.get("edge_limit", 25) super(Processor, self).__init__( **params | { @@ -37,6 +49,7 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, + "edge_limit": edge_limit, } ) @@ -44,6 +57,7 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length + self.default_edge_limit = edge_limit # CRITICAL SECURITY: NEVER share data between users or collections # Each user/collection combination MUST have isolated data access @@ -93,10 +107,163 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "explainability", + schema = Triples, + ) + ) + + # Librarian client for storing answer content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_librarian_requests = {} + + logger.info("Graph RAG service initialized") + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_librarian_requests: + future = self.pending_librarian_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + """ + Save answer content to the librarian. + + Args: + doc_id: ID for the answer document + user: User ID + content: Answer text content + title: Optional title + timeout: Request timeout in seconds + + Returns: + The document ID on success + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "GraphRAG Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_librarian_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_librarian_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + async def on_request(self, msg, consumer, flow): try: + v = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.info(f"Handling input {id}...") + + # Track explainability refs for end_of_session signaling + explainability_refs_emitted = [] + + # Real-time explainability callback - emits triples and IDs as they're generated + # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + async def send_explainability(triples, explain_id): + # Send triples to explainability queue - stores in same collection with named graph + await flow("explainability").send(Triples( + metadata=Metadata( + id=explain_id, + user=v.user, + collection=v.collection, # Store in user's collection, not separate explainability collection + ), + triples=triples, + )) + + # Send explain ID and graph to response queue + await flow("response").send( + GraphRagResponse( + message_type="explain", + explain_id=explain_id, + explain_graph=GRAPH_RETRIEVAL, + ), + properties={"id": id} + ) + + explainability_refs_emitted.append(explain_id) + # CRITICAL SECURITY: Create new GraphRag instance per request # This ensures proper isolation between users and collections # Flow clients are request-scoped and must not be shared @@ -108,13 +275,6 @@ class Processor(FlowProcessor): verbose=True, ) - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - logger.info(f"Handling input {id}...") - if v.entity_limit: entity_limit = v.entity_limit else: @@ -135,6 +295,20 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length + if v.edge_limit: + edge_limit = v.edge_limit + else: + edge_limit = self.default_edge_limit + + # Callback to save answer content to librarian + async def save_answer(doc_id, answer_text): + await self.save_answer_content( + doc_id=doc_id, + user=v.user, + content=answer_text, + title=f"GraphRAG Answer: {v.query[:50]}...", + ) + # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks @@ -142,6 +316,7 @@ class Processor(FlowProcessor): async def send_chunk(chunk, end_of_stream): await flow("response").send( GraphRagResponse( + message_type="chunk", response=chunk, end_of_stream=end_of_stream, error=None @@ -149,34 +324,52 @@ class Processor(FlowProcessor): properties={"id": id} ) - # Query with streaming enabled - # All chunks (including final one with end_of_stream=True) are sent via callback - await rag.query( - query = v.query, user = v.user, collection = v.collection, - entity_limit = entity_limit, triple_limit = triple_limit, - max_subgraph_size = max_subgraph_size, - max_path_length = max_path_length, - streaming = True, - chunk_callback = send_chunk, - ) - else: - # Non-streaming path (existing behavior) + # Query with streaming and real-time explain response = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_limit = edge_limit, + streaming = True, + chunk_callback = send_chunk, + explain_callback = send_explainability, + save_answer_callback = save_answer, ) + else: + # Non-streaming path with real-time explain + response = await rag.query( + query = v.query, user = v.user, collection = v.collection, + entity_limit = entity_limit, triple_limit = triple_limit, + max_subgraph_size = max_subgraph_size, + max_path_length = max_path_length, + edge_limit = edge_limit, + explain_callback = send_explainability, + save_answer_callback = save_answer, + ) + + # Send chunk with response await flow("response").send( GraphRagResponse( - response = response, - end_of_stream = True, - error = None + message_type="chunk", + response=response, + end_of_stream=True, + error=None, ), - properties = {"id": id} + properties={"id": id} ) + # Send final message to close session + await flow("response").send( + GraphRagResponse( + message_type="chunk", + response="", + end_of_session=True, + ), + properties={"id": id} + ) + logger.info("Request processing complete") except Exception as e: @@ -185,22 +378,18 @@ class Processor(FlowProcessor): logger.debug("Sending error response...") - # Send error response with end_of_stream flag if streaming was requested - error_response = GraphRagResponse( - response = None, - error = Error( - type = "graph-rag-error", - message = str(e), - ), - ) - - # If streaming was requested, indicate stream end - if v.streaming: - error_response.end_of_stream = True - + # Send error response and close session await flow("response").send( - error_response, - properties = {"id": id} + GraphRagResponse( + message_type="chunk", + error=Error( + type="graph-rag-error", + message=str(e), + ), + end_of_stream=True, + end_of_session=True, + ), + properties={"id": id} ) @staticmethod @@ -243,6 +432,9 @@ class Processor(FlowProcessor): help=f'Default max path length (default: 2)' ) + # Note: Explainability triples are now stored in the user's collection + # with the named graph urn:graph:retrieval (no separate collection needed) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index ae869413..e282f876 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -37,14 +37,14 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - if emb.chunk is None or emb.chunk == b"": continue + chunk_id = emb.chunk_id + if chunk_id == "": + continue - chunk = emb.chunk.decode("utf-8") - if chunk == "": continue - - for vec in emb.vectors: + vec = emb.vector + if vec: self.vecstore.insert( - vec, chunk, + vec, chunk_id, message.metadata.user, message.metadata.collection ) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index a0e52253..ea091d35 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -101,40 +101,41 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - if emb.chunk is None or emb.chunk == b"": continue + chunk_id = emb.chunk_id + if chunk_id == "": + continue - chunk = emb.chunk.decode("utf-8") - if chunk == "": continue + vec = emb.vector + if not vec: + continue - for vec in emb.vectors: + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + index = self.pinecone.Index(index_name) - index = self.pinecone.Index(index_name) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + records = [ + { + "id": vector_id, + "values": vec, + "metadata": { "chunk_id": chunk_id }, + } + ] - records = [ - { - "id": vector_id, - "values": vec, - "metadata": { "doc": chunk }, - } - ] - - index.upsert( - vectors = records, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index cb978048..a87f2128 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -52,41 +52,44 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - chunk = emb.chunk.decode("utf-8") - if chunk == "": return + chunk_id = emb.chunk_id + if chunk_id == "": + continue - for vec in emb.vectors: + vec = emb.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "doc": chunk, - } - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "chunk_id": chunk_id, + } + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 21aa21e6..0f27adf9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -53,11 +53,13 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): entity_value = get_term_value(entity.entity) if entity_value != "" and entity_value is not None: - for vec in entity.vectors: + vec = entity.vector + if vec: self.vecstore.insert( vec, entity_value, message.metadata.user, - message.metadata.collection + message.metadata.collection, + chunk_id=entity.chunk_id or "", ) @staticmethod diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index c4b0065b..d907e873 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -119,35 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - index = self.pinecone.Index(index_name) + index = self.pinecone.Index(index_name) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - records = [ - { - "id": vector_id, - "values": vec, - "metadata": { "entity": entity_value }, - } - ] + metadata = {"entity": entity_value} + if entity.chunk_id: + metadata["chunk_id"] = entity.chunk_id - index.upsert( - vectors = records, - ) + records = [ + { + "id": vector_id, + "values": vec, + "metadata": metadata, + } + ] + + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 0da59bb9..f887d487 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -71,38 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "entity": entity_value, - } - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + payload = { + "entity": entity_value, + } + if entity.chunk_id: + payload["chunk_id"] = entity.chunk_id + + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload=payload, + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 29848c4c..42e59012 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor): qdrant_collection = None for row_emb in embeddings.embeddings: - if not row_emb.vectors: + vector = row_emb.vector + if not vector: logger.warning( - f"No vectors for index {row_emb.index_name} - skipping" + f"No vector for index {row_emb.index_name} - skipping" ) continue - # Use first vector (there may be multiple from different models) - for vector in row_emb.vectors: - dimension = len(vector) + dimension = len(vector) - # Create/get collection name (lazily on first vector) - if qdrant_collection is None: - qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension - ) - self.ensure_collection(qdrant_collection, dimension) - - # Write to Qdrant - self.qdrant.upsert( - collection_name=qdrant_collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vector, - payload={ - "index_name": row_emb.index_name, - "index_value": row_emb.index_value, - "text": row_emb.text - } - ) - ] + # Create/get collection name (lazily on first vector) + if qdrant_collection is None: + qdrant_collection = self.get_collection_name( + user, collection, schema_name, dimension ) - embeddings_written += 1 + self.ensure_collection(qdrant_collection, dimension) + + # Write to Qdrant + self.qdrant.upsert( + collection_name=qdrant_collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "index_name": row_emb.index_name, + "index_value": row_emb.index_value, + "text": row_emb.text + } + ) + ] + ) + embeddings_written += 1 logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 5bc842de..ab13ccbc 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -9,6 +9,7 @@ import os import argparse import time import logging +import json from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, DEFAULT_GRAPH @@ -25,6 +26,37 @@ logger = logging.getLogger(__name__) default_ident = "triples-write" +def serialize_triple(triple): + """Serialize a Triple object to JSON for storage.""" + if triple is None: + return None + + def term_to_dict(term): + if term is None: + return None + + result = {"type": term.type} + if term.type == IRI: + result["iri"] = term.iri + elif term.type == LITERAL: + result["value"] = term.value + if term.datatype: + result["datatype"] = term.datatype + if term.language: + result["language"] = term.language + elif term.type == BLANK: + result["id"] = term.id + elif term.type == TRIPLE: + result["triple"] = serialize_triple(term.triple) + return result + + return json.dumps({ + "s": term_to_dict(triple.s), + "p": term_to_dict(triple.p), + "o": term_to_dict(triple.o), + }) + + def get_term_value(term): """Extract the string value from a Term""" if term is None: @@ -33,6 +65,9 @@ def get_term_value(term): return term.iri elif term.type == LITERAL: return term.value + elif term.type == TRIPLE: + # Serialize nested triple as JSON + return serialize_triple(term.triple) else: # For blank nodes or other types, use id or value return term.id or term.value diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 6ea16499..430dc3c9 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -114,7 +114,7 @@ class KnowledgeTableStore: entity_embeddings list< tuple< tuple, - list> + list > >, PRIMARY KEY ((user, document_id), id) @@ -140,7 +140,7 @@ class KnowledgeTableStore: chunks list< tuple< blob, - list> + list > >, PRIMARY KEY ((user, document_id), id) @@ -218,16 +218,6 @@ class KnowledgeTableStore: when = int(time.time() * 1000) - if m.metadata.metadata: - metadata = [ - ( - *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) - ) - for v in m.metadata.metadata - ] - else: - metadata = [] - triples = [ ( *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) @@ -243,8 +233,8 @@ class KnowledgeTableStore: self.insert_triples_stmt, ( uuid.uuid4(), m.metadata.user, - m.metadata.id, when, - metadata, triples, + m.metadata.root or m.metadata.id, when, + [], triples, ) ) @@ -259,20 +249,10 @@ class KnowledgeTableStore: when = int(time.time() * 1000) - if m.metadata.metadata: - metadata = [ - ( - *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) - ) - for v in m.metadata.metadata - ] - else: - metadata = [] - entities = [ ( term_to_tuple(v.entity), - v.vectors + v.vector ) for v in m.entities ] @@ -285,8 +265,8 @@ class KnowledgeTableStore: self.insert_graph_embeddings_stmt, ( uuid.uuid4(), m.metadata.user, - m.metadata.id, when, - metadata, entities, + m.metadata.root or m.metadata.id, when, + [], entities, ) ) @@ -301,20 +281,10 @@ class KnowledgeTableStore: when = int(time.time() * 1000) - if m.metadata.metadata: - metadata = [ - ( - *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) - ) - for v in m.metadata.metadata - ] - else: - metadata = [] - chunks = [ ( - v.chunk, - v.vectors, + v.chunk_id, + v.vector, ) for v in m.chunks ] @@ -327,8 +297,8 @@ class KnowledgeTableStore: self.insert_document_embeddings_stmt, ( uuid.uuid4(), m.metadata.user, - m.metadata.id, when, - metadata, chunks, + m.metadata.root or m.metadata.id, when, + [], chunks, ) ) @@ -423,18 +393,6 @@ class KnowledgeTableStore: for row in resp: - if row[2]: - metadata = [ - Triple( - s = tuple_to_term(elt[0], elt[1]), - p = tuple_to_term(elt[2], elt[3]), - o = tuple_to_term(elt[4], elt[5]), - ) - for elt in row[2] - ] - else: - metadata = [] - if row[3]: triples = [ Triple( @@ -453,7 +411,6 @@ class KnowledgeTableStore: id = document_id, user = user, collection = "default", # FIXME: What to put here? - metadata = metadata, ), triples = triples ) @@ -482,18 +439,6 @@ class KnowledgeTableStore: for row in resp: - if row[2]: - metadata = [ - Triple( - s = tuple_to_term(elt[0], elt[1]), - p = tuple_to_term(elt[2], elt[3]), - o = tuple_to_term(elt[4], elt[5]), - ) - for elt in row[2] - ] - else: - metadata = [] - if row[3]: entities = [ EntityEmbeddings( @@ -511,7 +456,6 @@ class KnowledgeTableStore: id = document_id, user = user, collection = "default", # FIXME: What to put here? - metadata = metadata, ), entities = entities ) diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index 8bbe2bad..11dd9022 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -112,6 +112,34 @@ class LibraryTableStore: ON document (object_id) """); + # Add parent_id and document_type columns for child document support + logger.debug("document table parent_id column...") + + try: + self.cassandra.execute(""" + ALTER TABLE document ADD parent_id text + """); + except Exception as e: + # Column may already exist + if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): + logger.debug(f"parent_id column may already exist: {e}") + + try: + self.cassandra.execute(""" + ALTER TABLE document ADD document_type text + """); + except Exception as e: + # Column may already exist + if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): + logger.debug(f"document_type column may already exist: {e}") + + logger.debug("document parent index...") + + self.cassandra.execute(""" + CREATE INDEX IF NOT EXISTS document_parent + ON document (parent_id) + """); + logger.debug("processing table...") self.cassandra.execute(""" @@ -127,6 +155,32 @@ class LibraryTableStore: ); """); + logger.debug("upload_session table...") + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS upload_session ( + upload_id text PRIMARY KEY, + user text, + document_id text, + document_metadata text, + s3_upload_id text, + object_id uuid, + total_size bigint, + chunk_size int, + total_chunks int, + chunks_received map, + created_at timestamp, + updated_at timestamp + ) WITH default_time_to_live = 86400; + """); + + logger.debug("upload_session user index...") + + self.cassandra.execute(""" + CREATE INDEX IF NOT EXISTS upload_session_user + ON upload_session (user) + """); + logger.info("Cassandra schema OK.") def prepare_statements(self): @@ -136,9 +190,10 @@ class LibraryTableStore: ( id, user, time, kind, title, comments, - metadata, tags, object_id + metadata, tags, object_id, + parent_id, document_type ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """) self.update_document_stmt = self.cassandra.prepare(""" @@ -149,7 +204,8 @@ class LibraryTableStore: """) self.get_document_stmt = self.cassandra.prepare(""" - SELECT time, kind, title, comments, metadata, tags, object_id + SELECT time, kind, title, comments, metadata, tags, object_id, + parent_id, document_type FROM document WHERE user = ? AND id = ? """) @@ -168,14 +224,16 @@ class LibraryTableStore: self.list_document_stmt = self.cassandra.prepare(""" SELECT - id, time, kind, title, comments, metadata, tags, object_id + id, time, kind, title, comments, metadata, tags, object_id, + parent_id, document_type FROM document WHERE user = ? """) self.list_document_by_tag_stmt = self.cassandra.prepare(""" SELECT - id, time, kind, title, comments, metadata, tags, object_id + id, time, kind, title, comments, metadata, tags, object_id, + parent_id, document_type FROM document WHERE user = ? AND tags CONTAINS ? ALLOW FILTERING @@ -210,6 +268,57 @@ class LibraryTableStore: WHERE user = ? """) + # Upload session prepared statements + self.insert_upload_session_stmt = self.cassandra.prepare(""" + INSERT INTO upload_session + ( + upload_id, user, document_id, document_metadata, + s3_upload_id, object_id, total_size, chunk_size, + total_chunks, chunks_received, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """) + + self.get_upload_session_stmt = self.cassandra.prepare(""" + SELECT + upload_id, user, document_id, document_metadata, + s3_upload_id, object_id, total_size, chunk_size, + total_chunks, chunks_received, created_at, updated_at + FROM upload_session + WHERE upload_id = ? + """) + + self.update_upload_session_chunk_stmt = self.cassandra.prepare(""" + UPDATE upload_session + SET chunks_received = chunks_received + ?, + updated_at = ? + WHERE upload_id = ? + """) + + self.delete_upload_session_stmt = self.cassandra.prepare(""" + DELETE FROM upload_session + WHERE upload_id = ? + """) + + self.list_upload_sessions_stmt = self.cassandra.prepare(""" + SELECT + upload_id, document_id, document_metadata, + total_size, chunk_size, total_chunks, + chunks_received, created_at, updated_at + FROM upload_session + WHERE user = ? + """) + + # Child document queries + self.list_children_stmt = self.cassandra.prepare(""" + SELECT + id, user, time, kind, title, comments, metadata, tags, + object_id, parent_id, document_type + FROM document + WHERE parent_id = ? + ALLOW FILTERING + """) + async def document_exists(self, user, id): resp = self.cassandra.execute( @@ -236,6 +345,10 @@ class LibraryTableStore: for v in document.metadata ] + # Get parent_id and document_type from document, defaulting if not set + parent_id = getattr(document, 'parent_id', '') or '' + document_type = getattr(document, 'document_type', 'source') or 'source' + while True: try: @@ -245,7 +358,8 @@ class LibraryTableStore: ( document.id, document.user, int(document.time * 1000), document.kind, document.title, document.comments, - metadata, document.tags, object_id + metadata, document.tags, object_id, + parent_id, document_type ) ) @@ -349,9 +463,58 @@ class LibraryTableStore: p=tuple_to_term(m[2], m[3]), o=tuple_to_term(m[4], m[5]) ) - for m in row[5] + for m in (row[5] or []) ], tags = row[6] if row[6] else [], + parent_id = row[8] if row[8] else "", + document_type = row[9] if row[9] else "source", + ) + for row in resp + ] + + logger.debug("Done") + + return lst + + async def list_children(self, parent_id): + """List all child documents for a given parent document ID.""" + + logger.debug(f"List children for parent {parent_id}") + + while True: + + try: + + resp = self.cassandra.execute( + self.list_children_stmt, + (parent_id,) + ) + + break + + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + lst = [ + DocumentMetadata( + id = row[0], + user = row[1], + time = int(time.mktime(row[2].timetuple())), + kind = row[3], + title = row[4], + comments = row[5], + metadata = [ + Triple( + s=tuple_to_term(m[0], m[1]), + p=tuple_to_term(m[2], m[3]), + o=tuple_to_term(m[4], m[5]) + ) + for m in (row[6] or []) + ], + tags = row[7] if row[7] else [], + parent_id = row[9] if row[9] else "", + document_type = row[10] if row[10] else "source", ) for row in resp ] @@ -394,9 +557,11 @@ class LibraryTableStore: p=tuple_to_term(m[2], m[3]), o=tuple_to_term(m[4], m[5]) ) - for m in row[4] + for m in (row[4] or []) ], tags = row[5] if row[5] else [], + parent_id = row[7] if row[7] else "", + document_type = row[8] if row[8] else "source", ) logger.debug("Done") @@ -532,3 +697,152 @@ class LibraryTableStore: logger.debug("Done") return lst + + # Upload session methods + + async def create_upload_session( + self, + upload_id, + user, + document_id, + document_metadata, + s3_upload_id, + object_id, + total_size, + chunk_size, + total_chunks, + ): + """Create a new upload session for chunked upload.""" + + logger.info(f"Creating upload session {upload_id}") + + now = int(time.time() * 1000) + + while True: + try: + self.cassandra.execute( + self.insert_upload_session_stmt, + ( + upload_id, user, document_id, document_metadata, + s3_upload_id, object_id, total_size, chunk_size, + total_chunks, {}, now, now + ) + ) + break + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + logger.debug("Upload session created") + + async def get_upload_session(self, upload_id): + """Get an upload session by ID.""" + + logger.debug(f"Get upload session {upload_id}") + + while True: + try: + resp = self.cassandra.execute( + self.get_upload_session_stmt, + (upload_id,) + ) + break + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + for row in resp: + session = { + "upload_id": row[0], + "user": row[1], + "document_id": row[2], + "document_metadata": row[3], + "s3_upload_id": row[4], + "object_id": row[5], + "total_size": row[6], + "chunk_size": row[7], + "total_chunks": row[8], + "chunks_received": row[9] if row[9] else {}, + "created_at": row[10], + "updated_at": row[11], + } + logger.debug("Done") + return session + + return None + + async def update_upload_session_chunk(self, upload_id, chunk_index, etag): + """Record a successfully uploaded chunk.""" + + logger.debug(f"Update upload session {upload_id} chunk {chunk_index}") + + now = int(time.time() * 1000) + + while True: + try: + self.cassandra.execute( + self.update_upload_session_chunk_stmt, + ( + {chunk_index: etag}, + now, + upload_id + ) + ) + break + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + logger.debug("Chunk recorded") + + async def delete_upload_session(self, upload_id): + """Delete an upload session.""" + + logger.info(f"Deleting upload session {upload_id}") + + while True: + try: + self.cassandra.execute( + self.delete_upload_session_stmt, + (upload_id,) + ) + break + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + logger.debug("Upload session deleted") + + async def list_upload_sessions(self, user): + """List all upload sessions for a user.""" + + logger.debug(f"List upload sessions for {user}") + + while True: + try: + resp = self.cassandra.execute( + self.list_upload_sessions_stmt, + (user,) + ) + break + except Exception as e: + logger.error("Exception occurred", exc_info=True) + raise e + + sessions = [] + for row in resp: + chunks_received = row[6] if row[6] else {} + sessions.append({ + "upload_id": row[0], + "document_id": row[1], + "document_metadata": row[2], + "total_size": row[3], + "chunk_size": row[4], + "total_chunks": row[5], + "chunks_received": len(chunks_received), + "created_at": row[7], + "updated_at": row[8], + }) + + logger.debug("Done") + return sessions diff --git a/trustgraph-flow/trustgraph/tool_service/__init__.py b/trustgraph-flow/trustgraph/tool_service/__init__.py new file mode 100644 index 00000000..76d859b9 --- /dev/null +++ b/trustgraph-flow/trustgraph/tool_service/__init__.py @@ -0,0 +1 @@ +# Tool service implementations diff --git a/trustgraph-flow/trustgraph/tool_service/joke/__init__.py b/trustgraph-flow/trustgraph/tool_service/joke/__init__.py new file mode 100644 index 00000000..4322f49f --- /dev/null +++ b/trustgraph-flow/trustgraph/tool_service/joke/__init__.py @@ -0,0 +1,2 @@ +# Joke tool service +from .service import run diff --git a/trustgraph-flow/trustgraph/tool_service/joke/service.py b/trustgraph-flow/trustgraph/tool_service/joke/service.py new file mode 100644 index 00000000..d9b7cde0 --- /dev/null +++ b/trustgraph-flow/trustgraph/tool_service/joke/service.py @@ -0,0 +1,204 @@ +""" +Joke Tool Service - An example dynamic tool service. + +This service demonstrates the tool service integration by: +- Using the 'user' field to personalize responses +- Using config params (style) to customize joke style +- Using arguments (topic) to generate topic-specific jokes + +Example tool-service config: +{ + "id": "joke-service", + "topic": "joke", + "config-params": [ + {"name": "style", "required": false} + ] +} + +Example tool config: +{ + "type": "tool-service", + "name": "tell-joke", + "description": "Tell a joke on a given topic", + "service": "joke-service", + "style": "pun", + "arguments": [ + { + "name": "topic", + "type": "string", + "description": "The topic for the joke (e.g., programming, animals, food)" + } + ] +} +""" + +import random +import logging + +from ... base import DynamicToolService + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "joke-service" +default_topic = "joke" + +# Joke database organized by topic and style +JOKES = { + "programming": { + "pun": [ + "Why do programmers prefer dark mode? Because light attracts bugs!", + "Why do Java developers wear glasses? Because they can't C#!", + "A SQL query walks into a bar, walks up to two tables and asks... 'Can I join you?'", + "Why was the JavaScript developer sad? Because he didn't Node how to Express himself!", + ], + "dad-joke": [ + "I told my computer I needed a break, and now it won't stop sending me Kit-Kat ads.", + "My son asked me to explain what a linked list is. I said 'I'll tell you, and then I'll tell you again, and again...'", + "I asked my computer for a joke about UDP. I'm not sure if it got it.", + ], + "one-liner": [ + "There are only 10 types of people: those who understand binary and those who don't.", + "A programmer's wife tells him: 'Go to the store and get a loaf of bread. If they have eggs, get a dozen.' He returns with 12 loaves.", + "99 little bugs in the code, 99 little bugs. Take one down, patch it around, 127 little bugs in the code.", + ], + }, + "llama": { + "pun": [ + "Why did the llama get a ticket? Because he was caught spitting in a no-spitting zone!", + "What do you call a llama who's a great musician? A llama del Rey!", + "Why did the llama cross the road? To prove he wasn't a chicken!", + ], + "dad-joke": [ + "What did the llama say when he got kicked out of the zoo? 'Alpaca my bags!'", + "Why don't llamas ever get lost? Because they always know the way to the Andes!", + "What do you call a llama with no legs? A woolly rug!", + ], + "one-liner": [ + "Llamas are great at meditation. They're always saying 'Dalai Llama.'", + "I asked a llama for directions. He said 'No probllama!'", + "Never trust a llama. They're always up to something woolly.", + ], + }, + "animals": { + "pun": [ + "What do you call a fish without eyes? A fsh!", + "Why don't scientists trust atoms? Because they make up everything... just like that cat who blamed the dog!", + "What do you call a bear with no teeth? A gummy bear!", + ], + "dad-joke": [ + "I tried to catch some fog earlier. I mist. My dog wasn't impressed either.", + "What do you call a dog that does magic tricks? A Labracadabrador!", + "Why do cows wear bells? Because their horns don't work!", + ], + "one-liner": [ + "I'm reading a book about anti-gravity. It's impossible to put down, unlike my cat.", + "A horse walks into a bar. The bartender asks 'Why the long face?'", + "What's orange and sounds like a parrot? A carrot!", + ], + }, + "food": { + "pun": [ + "I'm on a seafood diet. I see food and I eat it!", + "Why did the tomato turn red? Because it saw the salad dressing!", + "What do you call cheese that isn't yours? Nacho cheese!", + ], + "dad-joke": [ + "I used to hate facial hair, but then it grew on me. Speaking of growing, have you tried my garden salad?", + "Why don't eggs tell jokes? They'd crack each other up!", + "I told my wife she was drawing her eyebrows too high. She looked surprised, then made me a sandwich.", + ], + "one-liner": [ + "I'm reading a book about submarines and sandwiches. It's a sub-genre.", + "Broken puppets for sale. No strings attached. Also, free spaghetti!", + "I ordered a chicken and an egg online. I'll let you know which comes first.", + ], + }, + "default": { + "pun": [ + "Time flies like an arrow. Fruit flies like a banana!", + "I used to be a banker, but I lost interest.", + "I'm reading a book on the history of glue. I can't put it down!", + ], + "dad-joke": [ + "I'm afraid for the calendar. Its days are numbered.", + "I only know 25 letters of the alphabet. I don't know y.", + "Did you hear about the claustrophobic astronaut? He just needed a little space.", + ], + "one-liner": [ + "I told my wife she was drawing her eyebrows too high. She looked surprised.", + "I'm not lazy, I'm on energy-saving mode.", + "Parallel lines have so much in common. It's a shame they'll never meet.", + ], + }, +} + + +class Processor(DynamicToolService): + """ + Joke tool service that demonstrates the tool service integration. + """ + + def __init__(self, **params): + super(Processor, self).__init__(**params) + logger.info("Joke service initialized") + + async def invoke(self, user, config, arguments): + """ + Generate a joke based on the topic and style. + + Args: + user: The user requesting the joke + config: Config values including 'style' (pun, dad-joke, one-liner) + arguments: Arguments including 'topic' (programming, animals, food) + + Returns: + A personalized joke string + """ + # Get style from config (default: random) + style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"])) + + # Get topic from arguments (default: random) + topic = arguments.get("topic", "").lower() + + # Map topic to our categories + if "program" in topic or "code" in topic or "computer" in topic or "software" in topic: + category = "programming" + elif "llama" in topic: + category = "llama" + elif "animal" in topic or "dog" in topic or "cat" in topic or "bird" in topic: + category = "animals" + elif "food" in topic or "eat" in topic or "cook" in topic or "drink" in topic: + category = "food" + else: + category = "default" + + # Normalize style + if style not in ["pun", "dad-joke", "one-liner"]: + style = random.choice(["pun", "dad-joke", "one-liner"]) + + # Get jokes for this category and style + jokes = JOKES.get(category, JOKES["default"]).get(style, JOKES["default"]["pun"]) + + # Pick a random joke + joke = random.choice(jokes) + + # Personalize the response + response = f"Hey {user}! Here's a {style} for you:\n\n{joke}" + + logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}") + + return response + + @staticmethod + def add_args(parser): + DynamicToolService.add_args(parser) + # Override the topic default for this service + for action in parser._actions: + if '--topic' in action.option_strings: + action.default = default_topic + break + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index 2c84d21c..e551ed5d 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -443,13 +443,22 @@ class McpServer: gen = manager.request("graph-rag", request_data, flow_id) + text_chunks = [] async for response in gen: + # Handle new message format with message_type + message_type = response.get("message_type", "chunk") - # Extract vectors from response - text = response.get("response", "") - break - - return GraphRagResponse(response=text) + # Only collect text from chunk messages + if message_type == "chunk": + chunk_text = response.get("response", "") + if chunk_text: + text_chunks.append(chunk_text) + + # Check if session is complete + if response.get("end_of_session"): + break + + return GraphRagResponse(response="".join(text_chunks)) async def agent( self, diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index d089180a..b94a954b 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 48f92207..82da9155 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.0,<2.1", + "trustgraph-base>=2.1,<2.2", "pulsar-client", "google-genai", "google-api-core",