diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index d02df438..48154284 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=1.8.999 + run: make update-package-versions VERSION=2.0.999 - name: Setup environment run: python3 -m venv env diff --git a/Makefile b/Makefile index c7b4797f..13679d0d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ VERSION=0.0.0 DOCKER=podman -all: container +all: containers # Not used wheels: @@ -49,7 +49,9 @@ update-package-versions: echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py -container: update-package-versions +FORCE: + +containers: FORCE ${DOCKER} build -f containers/Containerfile.base \ -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ @@ -70,8 +72,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - ${DOCKER} build -f containers/Containerfile.vertexai \ - -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . +# ${DOCKER} build -f containers/Containerfile.vertexai \ +# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ # -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ diff --git a/containers/Containerfile.flow b/containers/Containerfile.flow index 8b20050d..c6bc7e50 100644 --- a/containers/Containerfile.flow +++ b/containers/Containerfile.flow @@ -16,7 +16,7 @@ RUN dnf install -y python3.13 && \ dnf clean all RUN pip3 install --no-cache-dir \ - anthropic cohere mistralai openai google-generativeai \ + anthropic cohere mistralai openai \ ollama \ langchain==0.3.25 langchain-core==0.3.60 \ langchain-text-splitters==0.3.8 \ diff --git a/docs/tech-specs/entity-centric-graph.md b/docs/tech-specs/entity-centric-graph.md new file mode 100644 index 00000000..aa695811 --- /dev/null +++ b/docs/tech-specs/entity-centric-graph.md @@ -0,0 +1,260 @@ +# Entity-Centric Knowledge Graph Storage on Cassandra + +## Overview + +This document describes a storage model for RDF-style knowledge graphs on Apache Cassandra. The model uses an **entity-centric** approach where every entity knows every quad it participates in and the role it plays. This replaces a traditional multi-table SPO permutation approach with just two tables. + +## Background and Motivation + +### The Traditional Approach + +A standard RDF quad store on Cassandra requires multiple denormalised tables to cover query patterns — typically 6 or more tables representing different permutations of Subject, Predicate, Object, and Dataset (SPOD). Each quad is written to every table, resulting in significant write amplification, operational overhead, and schema complexity. + +Additionally, label resolution (fetching human-readable names for entities) requires separate round-trip queries, which is particularly costly in AI and GraphRAG use cases where labels are essential for LLM context. + +### The Entity-Centric Insight + +Every quad `(D, S, P, O)` involves up to 4 entities. By writing a row for each entity's participation in the quad, we guarantee that **any query with at least one known element will hit a partition key**. 
This covers all 16 query patterns with a single data table.

Key benefits:

- **2 tables** instead of 7+
- **5 writes per quad** (4 entity rows + 1 manifest row) instead of 6+
- **Label resolution for free** — an entity's labels are co-located with its relationships, naturally warming the application cache
- **All 15 non-trivial query patterns** served by single-partition reads
- **Simpler operations** — one data table to tune, compact, and repair

## Schema

### Table 1: quads_by_entity

The primary data table. Every entity has a partition containing all quads it participates in. Named to reflect the query pattern (lookup by entity).

```sql
CREATE TABLE quads_by_entity (
    collection text,  -- Collection/tenant scope (always specified)
    entity text,      -- The entity this row is about
    role text,        -- 'S', 'P', 'O', 'G' — how this entity participates
    p text,           -- Predicate of the quad
    otype text,       -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
    s text,           -- Subject of the quad
    o text,           -- Object of the quad
    d text,           -- Dataset/graph of the quad
    dtype text,       -- XSD datatype (when otype = 'L'), e.g. 'xsd:string'
    lang text,        -- Language tag (when otype = 'L'), e.g. 'en', 'fr'
    PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
);
```

**Partition key**: `(collection, entity)` — scoped to collection, one partition per entity.

**Clustering column order rationale**:

1. **role** — most queries start with "where is this entity a subject/object"
2. **p** — next most common filter, "give me all `knows` relationships"
3. **otype** — enables filtering by URI-valued vs literal-valued relationships
4. **s, o, d** — remaining columns for uniqueness

### Table 2: quads_by_collection

Supports collection-level queries and deletion. Provides a manifest of all quads belonging to a collection. Named to reflect the query pattern (lookup by collection).

```sql
CREATE TABLE quads_by_collection (
    collection text,
    d text,           -- Dataset/graph of the quad
    s text,           -- Subject of the quad
    p text,           -- Predicate of the quad
    o text,           -- Object of the quad
    otype text,       -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
    dtype text,       -- XSD datatype (when otype = 'L')
    lang text,        -- Language tag (when otype = 'L')
    PRIMARY KEY (collection, d, s, p, o)
);
```

Clustered by dataset first, enabling deletion at either collection or dataset granularity.

## Write Path

For each incoming quad `(D, S, P, O)` within a collection `C`, write **4 rows** to `quads_by_entity` and **1 row** to `quads_by_collection`.
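The entity fan-out is mechanical. A minimal sketch (hypothetical helper; the real writer would issue these as a batch of CQL inserts, plus the single `quads_by_collection` manifest insert):

```python
def quads_by_entity_rows(collection, d, s, p, o, otype="U", dtype="", lang=""):
    """Expand one quad into its quads_by_entity rows: one per participating
    entity. Literal objects get no entity row (see the literal example below)."""
    participants = [(d, "G"), (s, "S"), (p, "P")]
    if otype == "U":
        participants.append((o, "O"))
    return [
        {"collection": collection, "entity": entity, "role": role,
         "p": p, "otype": otype, "s": s, "o": o, "d": d,
         "dtype": dtype, "lang": lang}
        for entity, role in participants
    ]
```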
+ +### Example + +Given the quad in collection `tenant1`: + +``` +Dataset: https://example.org/graph1 +Subject: https://example.org/Alice +Predicate: https://example.org/knows +Object: https://example.org/Bob +``` + +Write 4 rows to `quads_by_entity`: + +| collection | entity | role | p | otype | s | o | d | +|---|---|---|---|---|---|---|---| +| tenant1 | https://example.org/graph1 | G | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 | +| tenant1 | https://example.org/Alice | S | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 | +| tenant1 | https://example.org/knows | P | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 | +| tenant1 | https://example.org/Bob | O | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 | + +Write 1 row to `quads_by_collection`: + +| collection | d | s | p | o | otype | dtype | lang | +|---|---|---|---|---|---|---|---| +| tenant1 | https://example.org/graph1 | https://example.org/Alice | https://example.org/knows | https://example.org/Bob | U | | | + +### Literal Example + +For a label triple: + +``` +Dataset: https://example.org/graph1 +Subject: https://example.org/Alice +Predicate: http://www.w3.org/2000/01/rdf-schema#label +Object: "Alice Smith" (lang: en) +``` + +The `otype` is `'L'`, `dtype` is `'xsd:string'`, and `lang` is `'en'`. The literal value `"Alice Smith"` is stored in `o`. Only 3 rows are needed in `quads_by_entity` — no row is written for the literal as entity, since literals are not independently queryable entities. + +## Query Patterns + +### All 16 DSPO Patterns + +In the table below, "Perfect prefix" means the query uses a contiguous prefix of the clustering columns. "Partition scan + filter" means Cassandra reads a slice of one partition and filters in memory — still efficient, just not a pure prefix match. + +| # | Known | Lookup entity | Clustering prefix | Efficiency | +|---|---|---|---|---| +| 1 | D,S,P,O | entity=S, role='S', p=P | Full match | Perfect prefix | +| 2 | D,S,P,? | entity=S, role='S', p=P | Filter on D | Partition scan + filter | +| 3 | D,S,?,O | entity=S, role='S' | Filter on D, O | Partition scan + filter | +| 4 | D,?,P,O | entity=O, role='O', p=P | Filter on D | Partition scan + filter | +| 5 | ?,S,P,O | entity=S, role='S', p=P | Filter on O | Partition scan + filter | +| 6 | D,S,?,? | entity=S, role='S' | Filter on D | Partition scan + filter | +| 7 | D,?,P,? | entity=P, role='P' | Filter on D | Partition scan + filter | +| 8 | D,?,?,O | entity=O, role='O' | Filter on D | Partition scan + filter | +| 9 | ?,S,P,? | entity=S, role='S', p=P | — | **Perfect prefix** | +| 10 | ?,S,?,O | entity=S, role='S' | Filter on O | Partition scan + filter | +| 11 | ?,?,P,O | entity=O, role='O', p=P | — | **Perfect prefix** | +| 12 | D,?,?,? | entity=D, role='G' | — | **Perfect prefix** | +| 13 | ?,S,?,? | entity=S, role='S' | — | **Perfect prefix** | +| 14 | ?,?,P,? | entity=P, role='P' | — | **Perfect prefix** | +| 15 | ?,?,?,O | entity=O, role='O' | — | **Perfect prefix** | +| 16 | ?,?,?,? | — | Full scan | Exploration only | + +**Key result**: 7 of the 15 non-trivial patterns are perfect clustering prefix hits. The remaining 8 are single-partition reads with in-partition filtering. Every query with at least one known element hits a partition key. + +Pattern 16 (?,?,?,?) 
does not occur in practice since collection is always specified, reducing it to pattern 12. + +### Common Query Examples + +**Everything about an entity:** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'; +``` + +**All outgoing relationships for an entity:** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice' +AND role = 'S'; +``` + +**Specific predicate for an entity:** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice' +AND role = 'S' AND p = 'https://example.org/knows'; +``` + +**Label for an entity (specific language):** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice' +AND role = 'S' AND p = 'http://www.w3.org/2000/01/rdf-schema#label' +AND otype = 'L'; +``` + +Then filter by `lang = 'en'` application-side if needed. + +**Only URI-valued relationships (entity-to-entity links):** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice' +AND role = 'S' AND p = 'https://example.org/knows' AND otype = 'U'; +``` + +**Reverse lookup — what points to this entity:** + +```sql +SELECT * FROM quads_by_entity +WHERE collection = 'tenant1' AND entity = 'https://example.org/Bob' +AND role = 'O'; +``` + +## Label Resolution and Cache Warming + +One of the most significant advantages of the entity-centric model is that **label resolution becomes a free side effect**. + +In the traditional multi-table model, fetching labels requires separate round-trip queries: retrieve triples, identify entity URIs in the results, then fetch `rdfs:label` for each. This N+1 pattern is expensive. + +In the entity-centric model, querying an entity returns **all** its quads — including its labels, types, and other properties. When the application caches query results, labels are pre-warmed before anything asks for them. + +Two usage regimes confirm this works well in practice: + +- **Human-facing queries**: naturally small result sets, labels essential. Entity reads pre-warm the cache. +- **AI/bulk queries**: large result sets with hard limits. Labels either unnecessary or needed only for a curated subset of entities already in cache. + +The theoretical concern of resolving labels for huge result sets (e.g. 30,000 entities) is mitigated by the practical observation that no human or AI consumer usefully processes that many labels. Application-level query limits ensure cache pressure remains manageable. + +## Wide Partitions and Reification + +Reification (RDF-star style statements about statements) creates hub entities — e.g. a source document that supports thousands of extracted facts. This can produce wide partitions. + +Mitigating factors: + +- **Application-level query limits**: all GraphRAG and human-facing queries enforce hard limits, so wide partitions are never fully scanned on the hot read path +- **Cassandra handles partial reads efficiently**: a clustering column scan with an early stop is fast even on large partitions +- **Collection deletion** (the only operation that might traverse full partitions) is an accepted background process + +## Collection Deletion + +Triggered by API call, runs in the background (eventually consistent). + +1. Read `quads_by_collection` for the target collection to get all quads +2. Extract unique entities from the quads (s, p, o, d values) +3. 
For each unique entity, delete the partition from `quads_by_entity` +4. Delete the rows from `quads_by_collection` + +The `quads_by_collection` table provides the index needed to locate all entity partitions without a full table scan. Partition-level deletes are efficient since `(collection, entity)` is the partition key. + +## Migration Path from Multi-Table Model + +The entity-centric model can coexist with the existing multi-table model during migration: + +1. Deploy `quads_by_entity` and `quads_by_collection` tables alongside existing tables +2. Dual-write new quads to both old and new tables +3. Backfill existing data into the new tables +4. Migrate read paths one query pattern at a time +5. Decommission old tables once all reads are migrated + +## Summary + +| Aspect | Traditional (6-table) | Entity-centric (2-table) | +|---|---|---| +| Tables | 7+ | 2 | +| Writes per quad | 6+ | 5 (4 data + 1 manifest) | +| Label resolution | Separate round trips | Free via cache warming | +| Query patterns | 16 across 6 tables | 16 on 1 table | +| Schema complexity | High | Low | +| Operational overhead | 6 tables to tune/repair | 1 data table | +| Reification support | Additional complexity | Natural fit | +| Object type filtering | Not available | Native (via otype clustering) | + diff --git a/docs/tech-specs/graph-contexts.md b/docs/tech-specs/graph-contexts.md new file mode 100644 index 00000000..54737012 --- /dev/null +++ b/docs/tech-specs/graph-contexts.md @@ -0,0 +1,573 @@ +# Graph Contexts Technical Specification + +## Overview + +This specification describes changes to TrustGraph's core graph primitives to +align with RDF 1.2 and support full RDF Dataset semantics. This is a breaking +change for the 2.x release series. + +### Versioning + +- **2.0**: Early adopter release. Core features available, may not be fully + production-ready. +- **2.1 / 2.2**: Production release. Stability and completeness validated. + +Flexibility on maturity is intentional - early adopters can access new +capabilities before all features are production-hardened. + +## Goals + +The primary goals for this work are to enable metadata about facts/statements: + +- **Temporal information**: Associate facts with time metadata + - When a fact was believed to be true + - When a fact became true + - When a fact was discovered to be false + +- **Provenance/Sources**: Track which sources support a fact + - "This fact was supported by source X" + - Link facts back to their origin documents + +- **Veracity/Trust**: Record assertions about truth + - "Person P asserted this was true" + - "Person Q claims this is false" + - Enable trust scoring and conflict detection + +**Hypothesis**: Reification (RDF-star / quoted triples) is the key mechanism +to achieve these outcomes, as all require making statements about statements. + +## Background + +To express "the fact (Alice knows Bob) was discovered on 2024-01-15" or +"source X supports the claim (Y causes Z)", you need to reference an edge +as a thing you can make statements about. Standard triples don't support this. + +### Current Limitations + +The current `Value` class in `trustgraph-base/trustgraph/schema/core/primitives.py` +can represent: +- URI nodes (`is_uri=True`) +- Literal values (`is_uri=False`) + +The `type` field exists but is not used to represent XSD datatypes. + +## Technical Design + +### RDF Features to Support + +#### Core Features (Related to Reification Goals) + +These features are directly related to the temporal, provenance, and veracity +goals: + +1. 
**RDF 1.2 Quoted Triples (RDF-star)**
   - Edges that point at other edges
   - A Triple can appear as the subject or object of another Triple
   - Enables statements about statements (reification)
   - Core mechanism for annotating individual facts

2. **RDF Dataset / Named Graphs**
   - Support for multiple named graphs within a dataset
   - Each graph identified by an IRI
   - Moves from triples (s, p, o) to quads (s, p, o, g)
   - Includes a default graph plus zero or more named graphs
   - The graph IRI can be a subject in statements, e.g.:
     ```
     <http://ex/graph1> <http://ex/importedOn> "2024-01-15"
     <http://ex/graph1> <http://ex/confidence> "high"
     ```
   - Note: Named graphs are a separate feature from reification. They have
     uses beyond statement annotation (partitioning, access control, dataset
     organization) and should be treated as a distinct capability.

3. **Blank Nodes** (Limited Support)
   - Anonymous nodes without a global URI
   - Supported for compatibility when loading external RDF data
   - **Limited status**: No guarantees about stable identity after loading
   - Find them via wildcard queries (match by connections, not by ID)
   - Not a first-class feature - don't rely on precise blank node handling

#### Opportunistic Fixes (2.0 Breaking Change)

These features are not directly related to the reification goals but are
valuable improvements to include while making breaking changes:

4. **Literal Datatypes**
   - Properly use the `type` field for XSD datatypes
   - Examples: xsd:string, xsd:integer, xsd:dateTime, etc.
   - Fixes current limitation: cannot represent dates or integers properly

5. **Language Tags**
   - Support for language attributes on string literals (@en, @fr, etc.)
   - Note: A literal has either a language tag OR a datatype, not both
     (except for rdf:langString)
   - Important for AI/multilingual use cases

### Data Models

#### Term (rename from Value)

The `Value` class will be renamed to `Term` to better reflect RDF terminology.
This rename serves two purposes:
1. Aligns naming with RDF concepts (a "Term" can be an IRI, literal, blank
   node, or quoted triple - not just a "value")
2. Forces code review at the breaking change interface - any code still
   referencing `Value` is visibly broken and needs updating

A Term can represent:

- **IRI/URI** - A named node/resource
- **Blank Node** - An anonymous node with local scope
- **Literal** - A data value with either:
  - A datatype (XSD type), OR
  - A language tag
- **Quoted Triple** - A triple used as a term (RDF 1.2)

##### Chosen Approach: Single Class with Type Discriminator

Serialization requirements drive the structure - a type discriminator is needed
in the wire format regardless of the Python representation. A single class with
a type field is the natural fit and aligns with the current `Value` pattern.
+ +Single-character type codes provide compact serialization: + +```python +from dataclasses import dataclass + +# Term type constants +IRI = "i" # IRI/URI node +BLANK = "b" # Blank node +LITERAL = "l" # Literal value +TRIPLE = "t" # Quoted triple (RDF-star) + +@dataclass +class Term: + type: str = "" # One of: IRI, BLANK, LITERAL, TRIPLE + + # For IRI terms (type == IRI) + iri: str = "" + + # For blank nodes (type == BLANK) + id: str = "" + + # For literals (type == LITERAL) + value: str = "" + datatype: str = "" # XSD datatype URI (mutually exclusive with language) + language: str = "" # Language tag (mutually exclusive with datatype) + + # For quoted triples (type == TRIPLE) + triple: "Triple | None" = None +``` + +Usage examples: + +```python +# IRI term +node = Term(type=IRI, iri="http://example.org/Alice") + +# Literal with datatype +age = Term(type=LITERAL, value="42", datatype="xsd:integer") + +# Literal with language tag +label = Term(type=LITERAL, value="Hello", language="en") + +# Blank node +anon = Term(type=BLANK, id="_:b1") + +# Quoted triple (statement about a statement) +inner = Triple( + s=Term(type=IRI, iri="http://example.org/Alice"), + p=Term(type=IRI, iri="http://example.org/knows"), + o=Term(type=IRI, iri="http://example.org/Bob"), +) +reified = Term(type=TRIPLE, triple=inner) +``` + +##### Alternatives Considered + +**Option B: Union of specialized classes** (`Term = IRI | BlankNode | Literal | QuotedTriple`) +- Rejected: Serialization would still need a type discriminator, adding complexity + +**Option C: Base class with subclasses** +- Rejected: Same serialization issue, plus dataclass inheritance quirks + +#### Triple / Quad + +The `Triple` class gains an optional graph field to become a quad: + +```python +@dataclass +class Triple: + s: Term | None = None # Subject + p: Term | None = None # Predicate + o: Term | None = None # Object + g: str | None = None # Graph name (IRI), None = default graph +``` + +Design decisions: +- **Field name**: `g` for consistency with `s`, `p`, `o` +- **Optional**: `None` means the default graph (unnamed) +- **Type**: Plain string (IRI) rather than Term + - Graph names are always IRIs + - Blank nodes as graph names ruled out (too confusing) + - No need for the full Term machinery + +Note: The class name stays `Triple` even though it's technically a quad now. +This avoids churn and "triple" is still the common terminology for the s/p/o +portion. The graph context is metadata about where the triple lives. + +### Candidate Query Patterns + +The current query engine accepts combinations of S, P, O terms. With quoted +triples, a triple itself becomes a valid term in those positions. Below are +candidate query patterns that support the original goals. + +#### Graph Parameter Semantics + +Following SPARQL conventions for backward compatibility: + +- **`g` omitted / None**: Query the default graph only +- **`g` = specific IRI**: Query that named graph only +- **`g` = wildcard / `*`**: Query across all graphs (equivalent to SPARQL + `GRAPH ?g { ... }`) + +This keeps simple queries simple and makes named graph queries opt-in. + +Cross-graph queries (g=wildcard) are fully supported. The Cassandra schema +includes dedicated tables (SPOG, POSG, OSPG) where g is a clustering column +rather than a partition key, enabling efficient queries across all graphs. + +#### Temporal Queries + +**Find all facts discovered after a given date:** +``` +S: ? 
# any quoted triple
P: <http://ex/discoveredOn>
O: > "2024-01-15"^^xsd:date     # date comparison
```

**Find when a specific fact was believed true:**
```
S: << <http://ex/Alice> <http://ex/knows> <http://ex/Bob> >>   # quoted triple as subject
P: <http://ex/believedTrueOn>
O: ?                            # returns the date
```

**Find facts that became false:**
```
S: ?                            # any quoted triple
P: <http://ex/foundFalseOn>
O: ?                            # has any value (exists)
```

#### Provenance Queries

**Find all facts supported by a specific source:**
```
S: ?                            # any quoted triple
P: <http://ex/supportedBy>
O: <http://ex/sourceX>
```

**Find which sources support a specific fact:**
```
S: << <http://ex/Y> <http://ex/causes> <http://ex/Z> >>        # quoted triple as subject
P: <http://ex/supportedBy>
O: ?                            # returns source IRIs
```

#### Veracity Queries

**Find assertions a person marked as true:**
```
S: ?                            # any quoted triple
P: <http://ex/assertedTrueBy>
O: <http://ex/PersonP>
```

**Find conflicting assertions (same fact, different veracity):**
```
# First query: facts asserted true
S: ?
P: <http://ex/assertedTrueBy>
O: ?

# Second query: facts asserted false
S: ?
P: <http://ex/assertedFalseBy>
O: ?

# Application logic: find intersection of subjects
```

**Find facts with trust score below threshold:**
```
S: ?                            # any quoted triple
P: <http://ex/trustScore>
O: < 0.5                        # numeric comparison
```

### Architecture

Significant changes required across multiple components:

#### This Repository (trustgraph)

- **Schema primitives** (`trustgraph-base/trustgraph/schema/core/primitives.py`)
  - Value → Term rename
  - New Term structure with type discriminator
  - Triple gains `g` field for graph context

- **Message translators** (`trustgraph-base/trustgraph/messaging/translators/`)
  - Update for new Term/Triple structures
  - Serialization/deserialization for new fields

- **Gateway components**
  - Handle new Term and quad structures

- **Knowledge cores**
  - Core changes to support quads and reification

- **Knowledge manager**
  - Schema changes propagate here

- **Storage layers**
  - Cassandra: Schema redesign (see Implementation Details)
  - Other backends: Deferred to later phases

- **Command-line utilities**
  - Update for new data structures

- **REST API documentation**
  - OpenAPI spec updates

#### External Repositories

- **Python API** (this repo)
  - Client library updates for new structures

- **TypeScript APIs** (separate repo)
  - Client library updates

- **Workbench** (separate repo)
  - Significant state management changes

### APIs

#### REST API

- Documented in OpenAPI spec
- Will need updates for new Term/Triple structures
- New endpoints may be needed for graph context operations

#### Python API (this repo)

- Client library changes to match new primitives
- Breaking changes to Term (was Value) and Triple

#### TypeScript API (separate repo)

- Parallel changes to Python API
- Separate release coordination

#### Workbench (separate repo)

- Significant state management changes
- UI updates for graph context features

### Implementation Details

#### Phased Storage Implementation

Multiple graph store backends exist (Cassandra, Neo4j, etc.). Implementation
will proceed in phases:

1. **Phase 1: Cassandra**
   - Start with the home-grown Cassandra store
   - Full control over the storage layer enables rapid iteration
   - Schema will be redesigned from scratch for quads + reification
   - Validate the data model and query patterns against real use cases

#### Cassandra Schema Design

Cassandra requires multiple tables to support different query access patterns
(each table efficiently queries by its partition key + clustering columns).
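To make the pattern space concrete before the tables that follow, a small illustrative sketch (hypothetical; the table names are those defined below) enumerating the 16 bind patterns and the table each routes to:

```python
from itertools import product

# Each of g, s, p, o is either bound (True) or a wildcard (False): 2^4 = 16.
for g, s, p, o in product([True, False], repeat=4):
    bound = "".join(n for n, b in zip("gspo", (g, s, p, o)) if b)
    if not (s or p or o):
        table = "COLL"            # enumeration reads within the collection
    elif g:
        table = "GSPO/GPOS/GOSP"  # family B: g in the partition key
    else:
        table = "SPOG/POSG/OSPG"  # family A: g as a clustering column
    print(f"bound: {bound or '(none)':6} -> {table}")
```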
##### Query Patterns

With quads (g, s, p, o), each position can be specified or wildcard, giving
16 possible query patterns:

| # | g | s | p | o | Description |
|---|---|---|---|---|-------------|
| 1 | ? | ? | ? | ? | All quads |
| 2 | ? | ? | ? | o | By object |
| 3 | ? | ? | p | ? | By predicate |
| 4 | ? | ? | p | o | By predicate + object |
| 5 | ? | s | ? | ? | By subject |
| 6 | ? | s | ? | o | By subject + object |
| 7 | ? | s | p | ? | By subject + predicate |
| 8 | ? | s | p | o | Full triple (which graphs?) |
| 9 | g | ? | ? | ? | By graph |
| 10 | g | ? | ? | o | By graph + object |
| 11 | g | ? | p | ? | By graph + predicate |
| 12 | g | ? | p | o | By graph + predicate + object |
| 13 | g | s | ? | ? | By graph + subject |
| 14 | g | s | ? | o | By graph + subject + object |
| 15 | g | s | p | ? | By graph + subject + predicate |
| 16 | g | s | p | o | Exact quad |

##### Table Design

Cassandra constraint: You can only efficiently query by partition key, then
filter on clustering columns left-to-right. For g-wildcard queries, g must be
a clustering column. For g-specified queries, g in the partition key is more
efficient.

**Two table families needed:**

**Family A: g-wildcard queries** (g in clustering columns)

| Table | Partition | Clustering | Supports patterns |
|-------|-----------|------------|-------------------|
| SPOG | (user, collection, s) | p, o, g | 5, 7, 8 |
| POSG | (user, collection, p) | o, s, g | 3, 4 |
| OSPG | (user, collection, o) | s, p, g | 2, 6 |

**Family B: g-specified queries** (g in partition key)

| Table | Partition | Clustering | Supports patterns |
|-------|-----------|------------|-------------------|
| GSPO | (user, collection, g, s) | p, o | 13, 15, 16 |
| GPOS | (user, collection, g, p) | o, s | 11, 12 |
| GOSP | (user, collection, g, o) | s, p | 10, 14 |

**Collection table** (for iteration and bulk deletion)

| Table | Partition | Clustering | Purpose |
|-------|-----------|------------|---------|
| COLL | (user, collection) | g, s, p, o | Enumerate all quads in collection; serves patterns 1 and 9 |

Note that pattern 9 (by graph) cannot hit GSPO's partition key without a
subject, since s is part of that key; it is served by COLL, where g is the
first clustering column.

##### Write and Delete Paths

**Write path**: Insert into all 7 tables.

**Delete collection path**:
1. Iterate COLL table for `(user, collection)`
2. For each quad, delete from all 6 query tables
3. Delete from COLL table (or range delete)

**Delete single quad path**: Delete from all 7 tables directly.

##### Storage Cost

Each quad is stored 7 times. This is the cost of flexible querying combined
with efficient collection deletion.

##### Quoted Triples in Storage

Subject or object can be a triple itself. Options:

**Option A: Serialize quoted triples to canonical string**
```
S: "<< <http://ex/Alice> <http://ex/knows> <http://ex/Bob> >>"
P: http://ex/discoveredOn
O: "2024-01-15"
G: null
```
- Store quoted triple as serialized string in S or O columns
- Query by exact match on serialized form
- Pro: Simple, fits existing index patterns
- Con: Can't query "find triples where quoted subject's predicate is X"

**Option B: Triple IDs / Hashes**
```
Triple table:
  id: hash(s,p,o,g)
  s, p, o, g: ...

Metadata table:
  subject_triple_id: <hash of the referenced triple>
  p: http://ex/discoveredOn
  o: "2024-01-15"
```
- Assign each triple an ID (hash of components)
- Reification metadata references triples by ID
- Pro: Clean separation, can index triple IDs
- Con: Requires computing/managing triple identity, two-phase lookups

**Recommendation**: Start with Option A (serialized strings) for simplicity.
+Option B may be needed if advanced query patterns over quoted triple +components are required. + +2. **Phase 2+: Other Backends** + - Neo4j and other stores implemented in subsequent stages + - Lessons learned from Cassandra inform these implementations + +This approach de-risks the design by validating on a fully-controlled backend +before committing to implementations across all stores. + +#### Value → Term Rename + +The `Value` class will be renamed to `Term`. This affects ~78 files across +the codebase. The rename acts as a forcing function: any code still using +`Value` is immediately identifiable as needing review/update for 2.0 +compatibility. + +## Security Considerations + +Named graphs are not a security feature. Users and collections remain the +security boundaries. Named graphs are purely for data organization and +reification support. + +## Performance Considerations + +- Quoted triples add nesting depth - may impact query performance +- Named graph indexing strategies needed for efficient graph-scoped queries +- Cassandra schema design will need to accommodate quad storage efficiently + +### Vector Store Boundary + +Vector stores always reference IRIs only: +- Never edges (quoted triples) +- Never literal values +- Never blank nodes + +This keeps the vector store simple - it handles semantic similarity of named +entities. The graph structure handles relationships, reification, and metadata. +Quoted triples and named graphs don't complicate vector operations. + +## Testing Strategy + +Use existing test strategy. As this is a breaking change, extensive focus on +the end-to-end test suite to validate the new structures work correctly across +all components. + +## Migration Plan + +- 2.0 is a breaking release; no backward compatibility required +- Existing data may need migration to new schema (TBD based on final design) +- Consider migration tooling for converting existing triples + +## Open Questions + +- **Blank nodes**: Limited support confirmed. May need to decide on + skolemization strategy (generate IRIs on load, or preserve blank node IDs). +- **Query syntax**: What is the concrete syntax for specifying quoted triples + in queries? Need to define the query API. +- ~~**Predicate vocabulary**~~: Resolved. Any valid RDF predicates permitted, + including custom user-defined. Minimal assumptions about RDF validity. + Very few locked-in values (e.g., `rdfs:label` used in some places). + Strategy: avoid locking anything in unless absolutely necessary. +- ~~**Vector store impact**~~: Resolved. Vector stores always point to IRIs + only - never edges, literals, or blank nodes. Quoted triples and + reification don't affect the vector store. +- ~~**Named graph semantics**~~: Resolved. Queries default to the default + graph (matches SPARQL behavior, backward compatible). Explicit graph + parameter required to query named graphs or all graphs. + +## References + +- [RDF 1.2 Concepts](https://www.w3.org/TR/rdf12-concepts/) +- [RDF-star and SPARQL-star](https://w3c.github.io/rdf-star/) +- [RDF Dataset](https://www.w3.org/TR/rdf11-concepts/#section-dataset) diff --git a/docs/tech-specs/jsonl-prompt-output.md b/docs/tech-specs/jsonl-prompt-output.md new file mode 100644 index 00000000..d8872fd4 --- /dev/null +++ b/docs/tech-specs/jsonl-prompt-output.md @@ -0,0 +1,455 @@ +# JSONL Prompt Output Technical Specification + +## Overview + +This specification describes the implementation of JSONL (JSON Lines) output +format for prompt responses in TrustGraph. 
JSONL enables truncation-resilient +extraction of structured data from LLM responses, addressing critical issues +with JSON array outputs being corrupted when LLM responses hit output token +limits. + +This implementation supports the following use cases: + +1. **Truncation-Resilient Extraction**: Extract valid partial results even when + LLM output is truncated mid-response +2. **Large-Scale Extraction**: Handle extraction of many items without risk of + complete failure due to token limits +3. **Mixed-Type Extraction**: Support extraction of multiple entity types + (definitions, relationships, entities, attributes) in a single prompt +4. **Streaming-Compatible Output**: Enable future streaming/incremental + processing of extraction results + +## Goals + +- **Backward Compatibility**: Existing prompts using `response-type: "text"` and + `response-type: "json"` continue to work without modification +- **Truncation Resilience**: Partial LLM outputs yield partial valid results + rather than complete failure +- **Schema Validation**: Support JSON Schema validation for individual objects +- **Discriminated Unions**: Support mixed-type outputs using a `type` field + discriminator +- **Minimal API Changes**: Extend existing prompt configuration with new + response type and schema key + +## Background + +### Current Architecture + +The prompt service supports two response types: + +1. `response-type: "text"` - Raw text response returned as-is +2. `response-type: "json"` - JSON parsed from response, validated against + optional `schema` + +Current implementation in `trustgraph-flow/trustgraph/template/prompt_manager.py`: + +```python +class Prompt: + def __init__(self, template, response_type = "text", terms=None, schema=None): + self.template = template + self.response_type = response_type + self.terms = terms + self.schema = schema +``` + +### Current Limitations + +When extraction prompts request output as JSON arrays (`[{...}, {...}, ...]`): + +- **Truncation corruption**: If the LLM hits output token limits mid-array, the + entire response becomes invalid JSON and cannot be parsed +- **All-or-nothing parsing**: Must receive complete output before parsing +- **No partial results**: A truncated response yields zero usable data +- **Unreliable for large extractions**: More extracted items = higher failure risk + +This specification addresses these limitations by introducing JSONL format for +extraction prompts, where each extracted item is a complete JSON object on its +own line. + +## Technical Design + +### Response Type Extension + +Add a new response type `"jsonl"` alongside existing `"text"` and `"json"` types. + +#### Configuration Changes + +**New response type value:** + +``` +"response-type": "jsonl" +``` + +**Schema interpretation:** + +The existing `"schema"` key is used for both `"json"` and `"jsonl"` response +types. The interpretation depends on the response type: + +- `"json"`: Schema describes the entire response (typically an array or object) +- `"jsonl"`: Schema describes each individual line/object + +```json +{ + "response-type": "jsonl", + "schema": { + "type": "object", + "properties": { + "entity": { "type": "string" }, + "definition": { "type": "string" } + }, + "required": ["entity", "definition"] + } +} +``` + +This avoids changes to prompt configuration tooling and editors. 
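To make the per-line interpretation concrete, a minimal standalone illustration (using the `jsonschema` package) of validating a single JSONL line against the schema above:

```python
import json
from jsonschema import validate

line_schema = {
    "type": "object",
    "properties": {
        "entity": {"type": "string"},
        "definition": {"type": "string"},
    },
    "required": ["entity", "definition"],
}

# Each line is parsed and validated independently, so one bad line
# cannot invalidate the others.
line = '{"entity": "chlorophyll", "definition": "Green pigment in plants"}'
validate(instance=json.loads(line), schema=line_schema)  # raises on mismatch
```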
+ +### JSONL Format Specification + +#### Simple Extraction + +For prompts extracting a single type of object (definitions, relationships, +topics, rows), the output is one JSON object per line with no wrapper: + +**Prompt output format:** +``` +{"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"} +{"entity": "chlorophyll", "definition": "Green pigment in plants"} +{"entity": "mitochondria", "definition": "Powerhouse of the cell"} +``` + +**Contrast with previous JSON array format:** +```json +[ + {"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"}, + {"entity": "chlorophyll", "definition": "Green pigment in plants"}, + {"entity": "mitochondria", "definition": "Powerhouse of the cell"} +] +``` + +If the LLM truncates after line 2, the JSON array format yields invalid JSON, +while JSONL yields two valid objects. + +#### Mixed-Type Extraction (Discriminated Unions) + +For prompts extracting multiple types of objects (e.g., both definitions and +relationships, or entities, relationships, and attributes), use a `"type"` +field as discriminator: + +**Prompt output format:** +``` +{"type": "definition", "entity": "DNA", "definition": "Molecule carrying genetic instructions"} +{"type": "relationship", "subject": "DNA", "predicate": "located_in", "object": "cell nucleus", "object-entity": true} +{"type": "definition", "entity": "RNA", "definition": "Molecule that carries genetic information"} +{"type": "relationship", "subject": "RNA", "predicate": "transcribed_from", "object": "DNA", "object-entity": true} +``` + +**Schema for discriminated unions uses `oneOf`:** +```json +{ + "response-type": "jsonl", + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { "const": "definition" }, + "entity": { "type": "string" }, + "definition": { "type": "string" } + }, + "required": ["type", "entity", "definition"] + }, + { + "type": "object", + "properties": { + "type": { "const": "relationship" }, + "subject": { "type": "string" }, + "predicate": { "type": "string" }, + "object": { "type": "string" }, + "object-entity": { "type": "boolean" } + }, + "required": ["type", "subject", "predicate", "object", "object-entity"] + } + ] + } +} +``` + +#### Ontology Extraction + +For ontology-based extraction with entities, relationships, and attributes: + +**Prompt output format:** +``` +{"type": "entity", "entity": "Cornish pasty", "entity_type": "fo/Recipe"} +{"type": "entity", "entity": "beef", "entity_type": "fo/Food"} +{"type": "relationship", "subject": "Cornish pasty", "subject_type": "fo/Recipe", "relation": "fo/has_ingredient", "object": "beef", "object_type": "fo/Food"} +{"type": "attribute", "entity": "Cornish pasty", "entity_type": "fo/Recipe", "attribute": "fo/serves", "value": "4 people"} +``` + +### Implementation Details + +#### Prompt Class + +The existing `Prompt` class requires no changes. The `schema` field is reused +for JSONL, with its interpretation determined by `response_type`: + +```python +class Prompt: + def __init__(self, template, response_type="text", terms=None, schema=None): + self.template = template + self.response_type = response_type + self.terms = terms + self.schema = schema # Interpretation depends on response_type +``` + +#### PromptManager.load_config + +No changes required - existing configuration loading already handles the +`schema` key. 
+ +#### JSONL Parsing + +Add a new parsing method for JSONL responses: + +```python +def parse_jsonl(self, text): + """ + Parse JSONL response, returning list of valid objects. + + Invalid lines (malformed JSON, empty lines) are skipped with warnings. + This provides truncation resilience - partial output yields partial results. + """ + results = [] + + for line_num, line in enumerate(text.strip().split('\n'), 1): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Skip markdown code fence markers if present + if line.startswith('```'): + continue + + try: + obj = json.loads(line) + results.append(obj) + except json.JSONDecodeError as e: + # Log warning but continue - this provides truncation resilience + logger.warning(f"JSONL parse error on line {line_num}: {e}") + + return results +``` + +#### PromptManager.invoke Changes + +Extend the invoke method to handle the new response type: + +```python +async def invoke(self, id, input, llm): + logger.debug("Invoking prompt template...") + + terms = self.terms | self.prompts[id].terms | input + resp_type = self.prompts[id].response_type + + prompt = { + "system": self.system_template.render(terms), + "prompt": self.render(id, input) + } + + resp = await llm(**prompt) + + if resp_type == "text": + return resp + + if resp_type == "json": + try: + obj = self.parse_json(resp) + except: + logger.error(f"JSON parse failed: {resp}") + raise RuntimeError("JSON parse fail") + + if self.prompts[id].schema: + try: + validate(instance=obj, schema=self.prompts[id].schema) + logger.debug("Schema validation successful") + except Exception as e: + raise RuntimeError(f"Schema validation fail: {e}") + + return obj + + if resp_type == "jsonl": + objects = self.parse_jsonl(resp) + + if not objects: + logger.warning("JSONL parse returned no valid objects") + return [] + + # Validate each object against schema if provided + if self.prompts[id].schema: + validated = [] + for i, obj in enumerate(objects): + try: + validate(instance=obj, schema=self.prompts[id].schema) + validated.append(obj) + except Exception as e: + logger.warning(f"Object {i} failed schema validation: {e}") + return validated + + return objects + + raise RuntimeError(f"Response type {resp_type} not known") +``` + +### Affected Prompts + +The following prompts should be migrated to JSONL format: + +| Prompt ID | Description | Type Field | +|-----------|-------------|------------| +| `extract-definitions` | Entity/definition extraction | No (single type) | +| `extract-relationships` | Relationship extraction | No (single type) | +| `extract-topics` | Topic/definition extraction | No (single type) | +| `extract-rows` | Structured row extraction | No (single type) | +| `agent-kg-extract` | Combined definition + relationship extraction | Yes: `"definition"`, `"relationship"` | +| `extract-with-ontologies` / `ontology-extract` | Ontology-based extraction | Yes: `"entity"`, `"relationship"`, `"attribute"` | + +### API Changes + +#### Client Perspective + +JSONL parsing is transparent to prompt service API callers. The parsing occurs +server-side in the prompt service, and the response is returned via the standard +`PromptResponse.object` field as a serialized JSON array. + +When clients call the prompt service (via `PromptClient.prompt()` or similar): + +- **`response-type: "json"`** with array schema → client receives Python `list` +- **`response-type: "jsonl"`** → client receives Python `list` + +From the client's perspective, both return identical data structures. 
The +difference is entirely in how the LLM output is parsed server-side: + +- JSON array format: Single `json.loads()` call; fails completely if truncated +- JSONL format: Line-by-line parsing; yields partial results if truncated + +This means existing client code expecting a list from extraction prompts +requires no changes when migrating prompts from JSON to JSONL format. + +#### Server Return Value + +For `response-type: "jsonl"`, the `PromptManager.invoke()` method returns a +`list[dict]` containing all successfully parsed and validated objects. This +list is then serialized to JSON for the `PromptResponse.object` field. + +#### Error Handling + +- Empty results: Returns empty list `[]` with warning log +- Partial parse failure: Returns list of successfully parsed objects with + warning logs for failures +- Complete parse failure: Returns empty list `[]` with warning logs + +This differs from `response-type: "json"` which raises `RuntimeError` on +parse failure. The lenient behavior for JSONL is intentional to provide +truncation resilience. + +### Configuration Example + +Complete prompt configuration example: + +```json +{ + "prompt": "Extract all entities and their definitions from the following text. Output one JSON object per line.\n\nText:\n{{text}}\n\nOutput format per line:\n{\"entity\": \"\", \"definition\": \"\"}", + "response-type": "jsonl", + "schema": { + "type": "object", + "properties": { + "entity": { + "type": "string", + "description": "The entity name" + }, + "definition": { + "type": "string", + "description": "A clear definition of the entity" + } + }, + "required": ["entity", "definition"] + } +} +``` + +## Security Considerations + +- **Input Validation**: JSON parsing uses standard `json.loads()` which is safe + against injection attacks +- **Schema Validation**: Uses `jsonschema.validate()` for schema enforcement +- **No New Attack Surface**: JSONL parsing is strictly safer than JSON array + parsing due to line-by-line processing + +## Performance Considerations + +- **Memory**: Line-by-line parsing uses less peak memory than loading full JSON + arrays +- **Latency**: Parsing performance is comparable to JSON array parsing +- **Validation**: Schema validation runs per-object, which adds overhead but + enables partial results on validation failure + +## Testing Strategy + +### Unit Tests + +- JSONL parsing with valid input +- JSONL parsing with empty lines +- JSONL parsing with markdown code fences +- JSONL parsing with truncated final line +- JSONL parsing with invalid JSON lines interspersed +- Schema validation with `oneOf` discriminated unions +- Backward compatibility: existing `"text"` and `"json"` prompts unchanged + +### Integration Tests + +- End-to-end extraction with JSONL prompts +- Extraction with simulated truncation (artificially limited response) +- Mixed-type extraction with type discriminator +- Ontology extraction with all three types + +### Extraction Quality Tests + +- Compare extraction results: JSONL vs JSON array format +- Verify truncation resilience: JSONL yields partial results where JSON fails + +## Migration Plan + +### Phase 1: Implementation + +1. Implement `parse_jsonl()` method in `PromptManager` +2. Extend `invoke()` to handle `response-type: "jsonl"` +3. Add unit tests + +### Phase 2: Prompt Migration + +1. Update `extract-definitions` prompt and configuration +2. Update `extract-relationships` prompt and configuration +3. Update `extract-topics` prompt and configuration +4. Update `extract-rows` prompt and configuration +5. 
Update `agent-kg-extract` prompt and configuration +6. Update `extract-with-ontologies` prompt and configuration + +### Phase 3: Downstream Updates + +1. Update any code consuming extraction results to handle list return type +2. Update code that categorizes mixed-type extractions by `type` field +3. Update tests that assert on extraction output format + +## Open Questions + +None at this time. + +## References + +- Current implementation: `trustgraph-flow/trustgraph/template/prompt_manager.py` +- JSON Lines specification: https://jsonlines.org/ +- JSON Schema `oneOf`: https://json-schema.org/understanding-json-schema/reference/combining.html#oneof +- Related specification: Streaming LLM Responses (`docs/tech-specs/streaming-llm-responses.md`) diff --git a/docs/tech-specs/structured-data-2.md b/docs/tech-specs/structured-data-2.md new file mode 100644 index 00000000..1b70a6c3 --- /dev/null +++ b/docs/tech-specs/structured-data-2.md @@ -0,0 +1,613 @@ +# Structured Data Technical Specification (Part 2) + +## Overview + +This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`. + +## Problem Statements + +### 1. Naming Inconsistency: "Object" vs "Row" + +The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion: + +- "Object" is an overloaded term in software (Python objects, JSON objects, etc.) +- The data being handled is fundamentally tabular - rows in tables with defined schemas +- "Row" more accurately describes the data model and aligns with database terminology + +This inconsistency appears in module names, class names, message types, and documentation. + +### 2. Row Store Query Limitations + +The current row store implementation has significant query limitations: + +**Natural Language Mismatch**: Queries struggle with real-world data variations. For example: +- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"` +- Abbreviations, case differences, and formatting variations break exact-match queries +- Users expect semantic understanding, but the store provides literal matching + +**Schema Evolution Issues**: Changing schemas causes problems: +- Existing data may not conform to updated schemas +- Table structure changes can break queries and data integrity +- No clear migration path for schema updates + +### 3. Row Embeddings Required + +Related to problem 2, the system needs vector embeddings for row data to enable: + +- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST") +- Similarity matching for fuzzy queries +- Hybrid search combining structured filters with semantic similarity +- Better natural language query support + +The embedding service was specified but not implemented. + +### 4. Row Data Ingestion Incomplete + +The structured data ingestion pipeline is not fully operational: + +- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.) 
- The ingestion service that uses these prompts is not plumbed into the system
- No end-to-end path for loading pre-structured data into the row store

## Goals

- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
- **Complete Ingestion Pipeline**: Provide end-to-end path for loading structured data

## Technical Design

### Unified Row Storage Schema

The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.

The new design uses a single unified table for all row data:

```sql
CREATE TABLE rows (
    collection text,
    schema_name text,
    index_name text,
    index_value frozen<list<text>>,
    data map<text, text>,
    source text,
    PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
```

#### Column Definitions

| Column | Type | Description |
|--------|------|-------------|
| `collection` | `text` | Data collection/import identifier (from metadata) |
| `schema_name` | `text` | Name of the schema this row conforms to |
| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
| `data` | `map<text, text>` | Row data as key-value pairs |
| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |

#### Index Handling

Each row is stored multiple times - once per indexed field defined in the schema. The primary key fields are treated as an index with no special marker, providing future flexibility.

**Single-field index example:**
- Schema defines `email` as indexed
- `index_name = "email"`
- `index_value = ['foo@bar.com']`

**Composite index example:**
- Schema defines composite index on `region` and `status`
- `index_name = "region,status"` (field names sorted and comma-joined)
- `index_value = ['US', 'active']` (values in same order as field names)

**Primary key example:**
- Schema defines `customer_id` as primary key
- `index_name = "customer_id"`
- `index_value = ['CUST001']`

#### Query Patterns

All queries follow the same pattern regardless of which index is used:

```sql
SELECT * FROM rows
WHERE collection = 'import_2024'
  AND schema_name = 'customers'
  AND index_name = 'email'
  AND index_value = ['foo@bar.com']
```

#### Design Trade-offs

**Advantages:**
- Schema changes don't require table structure changes
- Row data is opaque to Cassandra - field additions/removals are transparent
- Consistent query pattern for all access methods
- No Cassandra secondary indexes (which can be slow at scale)
- Native Cassandra types throughout (`map`, `frozen`)

**Trade-offs:**
- Write amplification: each row insert = N inserts (one per indexed field)
- Storage overhead from duplicated row data
- Type information stored in schema config, conversion at application layer

#### Consistency Model

The design accepts certain simplifications:

1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.

2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes.
Users can delete and recreate a schema to ensure consistency if needed. + +### Partition Tracking and Deletion + +#### The Problem + +With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that have data. + +#### Partition Tracking Table + +A secondary lookup table tracks which partitions exist: + +```sql +CREATE TABLE row_partitions ( + collection text, + schema_name text, + index_name text, + PRIMARY KEY ((collection), schema_name, index_name) +) +``` + +This enables efficient discovery of partitions for deletion operations. + +#### Row Writer Behavior + +The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row: + +1. Check if `(collection, schema_name)` is in the cache +2. If not cached (first row for this pair): + - Look up the schema config to get all index names + - Insert entries into `row_partitions` for each `(collection, schema_name, index_name)` + - Add the pair to the cache +3. Proceed with writing the row data + +The row writer also monitors schema config change events. When a schema changes, relevant cache entries are cleared so the next row triggers re-registration with the updated index names. + +This approach ensures: +- Lookup table writes happen once per `(collection, schema_name)` pair, not per row +- The lookup table reflects the indexes that were active when data was written +- Schema changes mid-import are picked up correctly + +#### Deletion Operations + +**Delete collection:** +```sql +-- 1. Discover all partitions +SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X'; + +-- 2. Delete each partition from rows table +DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...'; +-- (repeat for each discovered partition) + +-- 3. Clean up the lookup table +DELETE FROM row_partitions WHERE collection = 'X'; +``` + +**Delete collection + schema:** +```sql +-- 1. Discover partitions for this schema +SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y'; + +-- 2. Delete each partition from rows table +DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...'; +-- (repeat for each discovered partition) + +-- 3. Clean up the lookup table entries +DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y'; +``` + +### Row Embeddings + +Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street"). + +#### Design Overview + +Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra. 
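A sketch of that flow, with the embedder, Qdrant search, and Cassandra read passed in as callables since the concrete clients are out of scope here (collection-name sanitization, described below, is omitted):

```python
from typing import Callable, Sequence

def semantic_row_lookup(
    embed: Callable[[str], Sequence[float]],               # embeddings service
    search: Callable[[str, Sequence[float], int], list],   # Qdrant nearest-vector search
    fetch_rows: Callable[[str, str, str, list], list],     # Cassandra rows lookup
    user: str, collection: str, schema_name: str,
    query_text: str, limit: int = 10,
) -> list:
    """Fuzzy row lookup: embed the query, find the nearest indexed values
    in Qdrant, then resolve each match to exact rows in Cassandra."""
    vector = embed(query_text)
    qdrant_collection = f"rows_{user}_{collection}_{schema_name}_{len(vector)}"
    rows = []
    for match in search(qdrant_collection, vector, limit):
        # Each match payload carries index_name/index_value, as described below
        rows += fetch_rows(collection, schema_name,
                           match.payload["index_name"],
                           match.payload["index_value"])
    return rows
```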
+
+#### Qdrant Collection Structure
+
+One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:
+
+- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
+- Names are sanitized (non-alphanumeric characters replaced with `_`, lowercased, numeric prefixes get `r_` prefix)
+- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; dimension suffix allows different embedding models to coexist
+
+#### What Gets Embedded
+
+The text representation of index values:
+
+| Index Type | Example `index_value` | Text to Embed |
+|------------|----------------------|---------------|
+| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
+| Composite | `['US', 'active']` | `"US active"` (space-joined) |
+
+#### Point Structure
+
+Each Qdrant point contains:
+
+```json
+{
+  "id": "<point id>",
+  "vector": [0.1, 0.2, ...],
+  "payload": {
+    "index_name": "street_name",
+    "index_value": ["CHESTNUT ST"],
+    "text": "CHESTNUT ST"
+  }
+}
+```
+
+| Payload Field | Description |
+|---------------|-------------|
+| `index_name` | The indexed field(s) this embedding represents |
+| `index_value` | The original list of values (for Cassandra lookup) |
+| `text` | The text that was embedded (for debugging/display) |
+
+Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
+
+#### Query Flow
+
+1. User queries for "Chestnut Street" within user U, collection X, schema Y
+2. Embed the query text
+3. Determine Qdrant collection name(s) matching prefix `rows_U_X_Y_`
+4. Search matching Qdrant collection(s) for nearest vectors
+5. Get matching points with payloads containing `index_name` and `index_value`
+6. Query Cassandra:
+   ```sql
+   SELECT * FROM rows
+   WHERE collection = 'X'
+     AND schema_name = 'Y'
+     AND index_name = '<index_name from payload>'
+     AND index_value = <index_value from payload>
+   ```
+7. Return matched rows
+
+#### Optional: Filtering by Index Name
+
+Queries can optionally filter by `index_name` in Qdrant to search only specific fields:
+
+- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
+- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`
+
+#### Architecture
+
+Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):
+
+- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
+- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant
+
+The Cassandra row writer is a separate parallel consumer:
+
+- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra
+
+All three services consume from the same flow, keeping them decoupled. This allows:
+- Independent scaling of Cassandra writes vs embedding generation vs vector storage
+- Embedding services can be disabled if not needed
+- Failures in one service don't affect the others
+- Consistent architecture with GraphRAG pipelines
+
+#### Write Path
+
+**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:
+
+1. Look up the schema to find indexed fields
+2. For each indexed field:
+   - Build the text representation of the index value
+   - Compute embedding via the embeddings service
+3. Output a `RowEmbeddings` message containing all computed vectors
+
+**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:
+
+1. For each embedding in the message:
+   - Determine Qdrant collection from `(user, collection, schema_name, dimension)`
+   - Create collection if needed (lazy creation on first write)
+   - Upsert point with vector and payload
+
+#### Message Types
+
+```python
+@dataclass
+class RowIndexEmbedding:
+    index_name: str             # The indexed field name(s)
+    index_value: list[str]      # The field value(s)
+    text: str                   # Text that was embedded
+    vectors: list[list[float]]  # Computed embedding vectors
+
+@dataclass
+class RowEmbeddings:
+    metadata: Metadata
+    schema_name: str
+    embeddings: list[RowIndexEmbedding]
+```
+
+#### Deletion Integration
+
+Qdrant collections are discovered by prefix matching on the collection name pattern:
+
+**Delete `(user, collection)`:**
+1. List all Qdrant collections matching prefix `rows_{user}_{collection}_`
+2. Delete each matching collection
+3. Delete Cassandra rows partitions (as documented above)
+4. Clean up `row_partitions` entries
+
+**Delete `(user, collection, schema_name)`:**
+1. List all Qdrant collections matching prefix `rows_{user}_{collection}_{schema_name}_`
+2. Delete each matching collection (handles multiple dimensions)
+3. Delete Cassandra rows partitions
+4. Clean up `row_partitions`
+
+#### Module Locations
+
+| Stage | Module | Entry Point |
+|-------|--------|-------------|
+| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
+| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |
+
+### Row Embeddings Query API
+
+The row embeddings query is a **separate API** from the GraphQL row query service:
+
+| API | Purpose | Backend |
+|-----|---------|---------|
+| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
+| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |
+
+This separation keeps concerns clean:
+- GraphQL service focuses on exact, structured queries
+- Embeddings API handles semantic similarity
+- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
+
+#### Request/Response Schema
+
+```python
+from dataclasses import dataclass, field
+
+@dataclass
+class RowEmbeddingsRequest:
+    vectors: list[list[float]]  # Query vectors (pre-computed embeddings)
+    user: str = ""
+    collection: str = ""
+    schema_name: str = ""
+    index_name: str = ""        # Optional: filter to specific index
+    limit: int = 10             # Max results per vector
+
+@dataclass
+class RowIndexMatch:
+    index_name: str = ""        # The matched index field(s)
+    index_value: list[str] = field(default_factory=list)  # The matched value(s)
+    text: str = ""              # Original text that was embedded
+    score: float = 0.0          # Similarity score
+
+@dataclass
+class RowEmbeddingsResponse:
+    error: Error | None = None
+    matches: list[RowIndexMatch] = field(default_factory=list)
+```
+
+#### Query Processor
+
+Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
+
+Entry point: `row-embeddings-query-qdrant`
+
+The processor:
+1. Receives `RowEmbeddingsRequest` with query vectors
+2. Finds the appropriate Qdrant collection by prefix matching
+3. Searches for nearest vectors with optional `index_name` filter
+4. 
Returns `RowEmbeddingsResponse` with matching index information + +#### API Gateway Integration + +The gateway exposes row embeddings queries via the standard request/response pattern: + +| Component | Location | +|-----------|----------| +| Dispatcher | `trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py` | +| Registration | Add `"row-embeddings"` to `request_response_dispatchers` in `manager.py` | + +Flow interface name: `row-embeddings` + +Interface definition in flow blueprint: +```json +{ + "interfaces": { + "row-embeddings": { + "request": "non-persistent://tg/request/row-embeddings:{id}", + "response": "non-persistent://tg/response/row-embeddings:{id}" + } + } +} +``` + +#### Python SDK Support + +The SDK provides methods for row embeddings queries: + +```python +# Flow-scoped query (preferred) +api = Api(url) +flow = api.flow().id("default") + +# Query with text (SDK computes embeddings) +matches = flow.row_embeddings_query( + text="Chestnut Street", + collection="my_collection", + schema_name="addresses", + index_name="street_name", # Optional filter + limit=10 +) + +# Query with pre-computed vectors +matches = flow.row_embeddings_query( + vectors=[[0.1, 0.2, ...]], + collection="my_collection", + schema_name="addresses" +) + +# Each match contains: +for match in matches: + print(match.index_name) # e.g., "street_name" + print(match.index_value) # e.g., ["CHESTNUT ST"] + print(match.text) # e.g., "CHESTNUT ST" + print(match.score) # e.g., 0.95 +``` + +#### CLI Utility + +Command: `tg-invoke-row-embeddings` + +```bash +# Query by text (computes embedding automatically) +tg-invoke-row-embeddings \ + --text "Chestnut Street" \ + --collection my_collection \ + --schema addresses \ + --index street_name \ + --limit 10 + +# Query by vector file +tg-invoke-row-embeddings \ + --vectors vectors.json \ + --collection my_collection \ + --schema addresses + +# Output formats +tg-invoke-row-embeddings --text "..." --format json +tg-invoke-row-embeddings --text "..." --format table +``` + +#### Typical Usage Pattern + +The row embeddings query is typically used as part of a fuzzy-to-exact lookup flow: + +```python +# Step 1: Fuzzy search via embeddings +matches = flow.row_embeddings_query( + text="chestnut street", + collection="geo", + schema_name="streets" +) + +# Step 2: Exact lookup via GraphQL for full row data +for match in matches: + query = f''' + query {{ + streets(where: {{ {match.index_name}: {{ eq: "{match.index_value[0]}" }} }}) {{ + street_name + city + zip_code + }} + }} + ''' + rows = flow.rows_query(query, collection="geo") +``` + +This two-step pattern enables: +- Finding "CHESTNUT ST" when user searches for "Chestnut Street" +- Retrieving complete row data with all fields +- Combining semantic similarity with structured data access + +### Row Data Ingestion + +Deferred to a subsequent phase. Will be designed alongside other ingestion changes. 
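+
+Finally, for reference, the collection naming and sanitization rules used throughout this design can be summarized in one small helper. This is a sketch only: `qdrant_collection_name` is a hypothetical name, and it assumes each name component is sanitized independently, mirroring the rules given under Qdrant Collection Structure (non-alphanumeric characters become underscores, lowercase, `r_` prefix when the name does not start with a letter):
+
+```python
+import re
+
+def qdrant_collection_name(user: str, collection: str, schema_name: str,
+                           dimension: int) -> str:
+    """Build rows_{user}_{collection}_{schema_name}_{dimension} from sanitized parts."""
+    def sanitize(name: str) -> str:
+        # Replace non-alphanumeric characters with underscores, then lowercase.
+        safe = re.sub(r"[^0-9A-Za-z]", "_", name).lower()
+        # Names starting with a digit or underscore get an r_ prefix.
+        if not safe or not safe[0].isalpha():
+            safe = "r_" + safe
+        return safe
+
+    return "_".join([
+        "rows", sanitize(user), sanitize(collection),
+        sanitize(schema_name), str(dimension),
+    ])
+
+# e.g. qdrant_collection_name("alice", "Sales-2024", "addresses", 384)
+#   -> "rows_alice_sales_2024_addresses_384"
+```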
+ +## Implementation Impact + +### Current State Analysis + +The existing implementation has two main components: + +| Component | Location | Lines | Description | +|-----------|----------|-------|-------------| +| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling | +| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete | + +**Current Query Pattern:** +```sql +SELECT * FROM {keyspace}.o_{schema_name} +WHERE collection = 'X' AND email = 'foo@bar.com' +ALLOW FILTERING +``` + +**New Query Pattern:** +```sql +SELECT * FROM {keyspace}.rows +WHERE collection = 'X' AND schema_name = 'customers' + AND index_name = 'email' AND index_value = ['foo@bar.com'] +``` + +### Key Changes + +1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.) either: + - Become post-filtering on returned data (if still needed) + - Are removed in favor of using the embeddings API for fuzzy matching + +2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code. + +### Proposed Refactor + +The refactor has two parts: + +#### 1. Break Out GraphQL Code + +Extract reusable GraphQL components into a shared module: + +``` +trustgraph-flow/trustgraph/query/graphql/ +├── __init__.py +├── types.py # Filter types (IntFilter, StringFilter, FloatFilter) +├── schema.py # Dynamic schema generation from RowSchema +└── filters.py # Filter parsing utilities +``` + +This enables: +- Reuse across different row store backends +- Cleaner separation of concerns +- Easier testing of GraphQL logic independently + +#### 2. 
Implement New Table Schema + +Refactor the Cassandra-specific code to use the unified table: + +**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`): +- Single `rows` table instead of per-schema tables +- Write N copies per row (one per index) +- Register to `row_partitions` table +- Simpler table creation (one-time setup) + +**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`): +- Query the unified `rows` table +- Use extracted GraphQL module for schema generation +- Simplified filter handling (exact match only at DB level) + +### Module Renames + +As part of the "object" → "row" naming cleanup: + +| Current | New | +|---------|-----| +| `storage/objects/cassandra/` | `storage/rows/cassandra/` | +| `query/objects/cassandra/` | `query/rows/cassandra/` | +| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` | + +### New Modules + +| Module | Purpose | +|--------|---------| +| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities | +| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API | +| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) | +| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) | + +## References + +- [Structured Data Technical Specification](structured-data.md) diff --git a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml new file mode 100644 index 00000000..916b4beb --- /dev/null +++ b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml @@ -0,0 +1,39 @@ +type: object +description: | + Row embeddings query request - find similar rows by vector similarity on indexed fields. + Enables semantic/fuzzy matching on structured data. 
+required: + - vectors + - schema_name +properties: + vectors: + type: array + description: Query embedding vector + items: + type: number + example: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156] + schema_name: + type: string + description: Schema name to search within + example: customers + index_name: + type: string + description: Optional index name to filter search to specific index + example: full_name + limit: + type: integer + description: Maximum number of matches to return + default: 10 + minimum: 1 + maximum: 1000 + example: 20 + user: + type: string + description: User identifier + default: trustgraph + example: alice + collection: + type: string + description: Collection to search + default: default + example: sales diff --git a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml new file mode 100644 index 00000000..5654e04b --- /dev/null +++ b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml @@ -0,0 +1,53 @@ +type: object +description: Row embeddings query response with matching row index information +properties: + error: + type: object + description: Error information if query failed + properties: + type: + type: string + description: Error type identifier + example: row-embeddings-query-error + message: + type: string + description: Human-readable error message + example: Schema not found + matches: + type: array + description: List of matching row index entries with similarity scores + items: + type: object + properties: + index_name: + type: string + description: Name of the indexed field(s) + example: full_name + index_value: + type: array + description: Values of the indexed fields for this row + items: + type: string + example: ["John", "Smith"] + text: + type: string + description: The text that was embedded for this index entry + example: "John Smith" + score: + type: number + description: Similarity score (higher is more similar) + example: 0.89 +example: + matches: + - index_name: full_name + index_value: ["John", "Smith"] + text: "John Smith" + score: 0.95 + - index_name: full_name + index_value: ["Jon", "Smythe"] + text: "Jon Smythe" + score: 0.82 + - index_name: full_name + index_value: ["Jonathan", "Schmidt"] + text: "Jonathan Schmidt" + score: 0.76 diff --git a/specs/api/components/schemas/query/ObjectsQueryRequest.yaml b/specs/api/components/schemas/query/RowsQueryRequest.yaml similarity index 92% rename from specs/api/components/schemas/query/ObjectsQueryRequest.yaml rename to specs/api/components/schemas/query/RowsQueryRequest.yaml index 775bbc4b..08f03ad3 100644 --- a/specs/api/components/schemas/query/ObjectsQueryRequest.yaml +++ b/specs/api/components/schemas/query/RowsQueryRequest.yaml @@ -1,6 +1,6 @@ type: object description: | - Objects query request - GraphQL query over knowledge graph. + Rows query request - GraphQL query over structured data. 
required: - query properties: diff --git a/specs/api/components/schemas/query/ObjectsQueryResponse.yaml b/specs/api/components/schemas/query/RowsQueryResponse.yaml similarity index 96% rename from specs/api/components/schemas/query/ObjectsQueryResponse.yaml rename to specs/api/components/schemas/query/RowsQueryResponse.yaml index 8fd9b6a6..a8fed63d 100644 --- a/specs/api/components/schemas/query/ObjectsQueryResponse.yaml +++ b/specs/api/components/schemas/query/RowsQueryResponse.yaml @@ -1,5 +1,5 @@ type: object -description: Objects query response (GraphQL format) +description: Rows query response (GraphQL format) properties: data: description: GraphQL response data (JSON object or null) diff --git a/specs/api/openapi.yaml b/specs/api/openapi.yaml index 55c05741..4196f9ec 100644 --- a/specs/api/openapi.yaml +++ b/specs/api/openapi.yaml @@ -121,8 +121,8 @@ paths: $ref: './paths/flow/mcp-tool.yaml' /api/v1/flow/{flow}/service/triples: $ref: './paths/flow/triples.yaml' - /api/v1/flow/{flow}/service/objects: - $ref: './paths/flow/objects.yaml' + /api/v1/flow/{flow}/service/rows: + $ref: './paths/flow/rows.yaml' /api/v1/flow/{flow}/service/nlp-query: $ref: './paths/flow/nlp-query.yaml' /api/v1/flow/{flow}/service/structured-query: @@ -133,6 +133,8 @@ paths: $ref: './paths/flow/graph-embeddings.yaml' /api/v1/flow/{flow}/service/document-embeddings: $ref: './paths/flow/document-embeddings.yaml' + /api/v1/flow/{flow}/service/row-embeddings: + $ref: './paths/flow/row-embeddings.yaml' /api/v1/flow/{flow}/service/text-load: $ref: './paths/flow/text-load.yaml' /api/v1/flow/{flow}/service/document-load: diff --git a/specs/api/paths/flow/nlp-query.yaml b/specs/api/paths/flow/nlp-query.yaml index 7032b5b9..a10f3a67 100644 --- a/specs/api/paths/flow/nlp-query.yaml +++ b/specs/api/paths/flow/nlp-query.yaml @@ -34,7 +34,7 @@ post: ``` 1. User asks: "Who does Alice know?" 2. NLP Query generates GraphQL - 3. Execute via /api/v1/flow/{flow}/service/objects + 3. Execute via /api/v1/flow/{flow}/service/rows 4. Return results to user ``` diff --git a/specs/api/paths/flow/row-embeddings.yaml b/specs/api/paths/flow/row-embeddings.yaml new file mode 100644 index 00000000..05837c06 --- /dev/null +++ b/specs/api/paths/flow/row-embeddings.yaml @@ -0,0 +1,101 @@ +post: + tags: + - Flow Services + summary: Row Embeddings Query - semantic search on structured data + description: | + Query row embeddings to find similar rows by vector similarity on indexed fields. + Enables fuzzy/semantic matching on structured data. + + ## Row Embeddings Query Overview + + Find rows whose indexed field values are semantically similar to a query: + - **Input**: Query embedding vector, schema name, optional index filter + - **Search**: Compare against stored row index embeddings + - **Output**: Matching rows with index values and similarity scores + + Core component of semantic search on structured data. + + ## Use Cases + + - **Fuzzy name matching**: Find customers by approximate name + - **Semantic field search**: Find products by description similarity + - **Data deduplication**: Identify potential duplicate records + - **Entity resolution**: Match records across datasets + + ## Process + + 1. Obtain query embedding (via embeddings service) + 2. Query stored row index embeddings for the specified schema + 3. Calculate cosine similarity + 4. Return top N most similar index entries + 5. 
Use index values to retrieve full rows via GraphQL + + ## Response Format + + Each match includes: + - `index_name`: The indexed field(s) that matched + - `index_value`: The actual values for those fields + - `text`: The text that was embedded + - `score`: Similarity score (higher = more similar) + + operationId: rowEmbeddingsQueryService + security: + - bearerAuth: [] + parameters: + - name: flow + in: path + required: true + schema: + type: string + description: Flow instance ID + example: my-flow + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml' + examples: + basicQuery: + summary: Find similar customer names + value: + vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] + schema_name: customers + limit: 10 + user: alice + collection: sales + filteredQuery: + summary: Search specific index + value: + vectors: [0.1, -0.2, 0.3, -0.4, 0.5] + schema_name: products + index_name: description + limit: 20 + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml' + examples: + similarRows: + summary: Similar rows found + value: + matches: + - index_name: full_name + index_value: ["John", "Smith"] + text: "John Smith" + score: 0.95 + - index_name: full_name + index_value: ["Jon", "Smythe"] + text: "Jon Smythe" + score: 0.82 + - index_name: full_name + index_value: ["Jonathan", "Schmidt"] + text: "Jonathan Schmidt" + score: 0.76 + '401': + $ref: '../../components/responses/Unauthorized.yaml' + '500': + $ref: '../../components/responses/Error.yaml' diff --git a/specs/api/paths/flow/objects.yaml b/specs/api/paths/flow/rows.yaml similarity index 90% rename from specs/api/paths/flow/objects.yaml rename to specs/api/paths/flow/rows.yaml index ac94a353..d648c9db 100644 --- a/specs/api/paths/flow/objects.yaml +++ b/specs/api/paths/flow/rows.yaml @@ -1,19 +1,19 @@ post: tags: - Flow Services - summary: Objects query - GraphQL over knowledge graph + summary: Rows query - GraphQL over structured data description: | - Query knowledge graph using GraphQL for object-oriented data access. + Query structured data using GraphQL for row-oriented data access. - ## Objects Query Overview + ## Rows Query Overview - GraphQL interface to knowledge graph: + GraphQL interface to structured data: - **Schema-driven**: Predefined types and relationships - **Flexible queries**: Request exactly what you need - **Nested data**: Traverse relationships in single query - **Type-safe**: Strong typing with introspection - Abstracts RDF triples into familiar object model. + Abstracts structured rows into familiar object model. ## GraphQL Benefits @@ -61,7 +61,7 @@ post: Schema defines available types via config service. Use introspection query to discover schema. 
- operationId: objectsQueryService + operationId: rowsQueryService security: - bearerAuth: [] parameters: @@ -77,7 +77,7 @@ post: content: application/json: schema: - $ref: '../../components/schemas/query/ObjectsQueryRequest.yaml' + $ref: '../../components/schemas/query/RowsQueryRequest.yaml' examples: simpleQuery: summary: Simple query @@ -129,7 +129,7 @@ post: content: application/json: schema: - $ref: '../../components/schemas/query/ObjectsQueryResponse.yaml' + $ref: '../../components/schemas/query/RowsQueryResponse.yaml' examples: successfulQuery: summary: Successful query diff --git a/specs/api/paths/flow/structured-query.yaml b/specs/api/paths/flow/structured-query.yaml index c094c50a..6d4dfe87 100644 --- a/specs/api/paths/flow/structured-query.yaml +++ b/specs/api/paths/flow/structured-query.yaml @@ -9,7 +9,7 @@ post: Combines two operations in one call: 1. **NLP Query**: Generate GraphQL from question - 2. **Objects Query**: Execute generated query + 2. **Rows Query**: Execute generated query 3. **Return Results**: Direct answer data Simplest way to query knowledge graph with natural language. @@ -21,7 +21,7 @@ post: - **Output**: Query results (data) - **Use when**: Want simple, direct answers - ### NLP Query + Objects Query (separate calls) + ### NLP Query + Rows Query (separate calls) - **Step 1**: Convert question → GraphQL - **Step 2**: Execute GraphQL → results - **Use when**: Need to inspect/modify query before execution diff --git a/specs/websocket/components/messages/ServiceRequest.yaml b/specs/websocket/components/messages/ServiceRequest.yaml index 8df44caa..26db079e 100644 --- a/specs/websocket/components/messages/ServiceRequest.yaml +++ b/specs/websocket/components/messages/ServiceRequest.yaml @@ -25,12 +25,13 @@ payload: - $ref: './requests/EmbeddingsRequest.yaml' - $ref: './requests/McpToolRequest.yaml' - $ref: './requests/TriplesRequest.yaml' - - $ref: './requests/ObjectsRequest.yaml' + - $ref: './requests/RowsRequest.yaml' - $ref: './requests/NlpQueryRequest.yaml' - $ref: './requests/StructuredQueryRequest.yaml' - $ref: './requests/StructuredDiagRequest.yaml' - $ref: './requests/GraphEmbeddingsRequest.yaml' - $ref: './requests/DocumentEmbeddingsRequest.yaml' + - $ref: './requests/RowEmbeddingsRequest.yaml' - $ref: './requests/TextLoadRequest.yaml' - $ref: './requests/DocumentLoadRequest.yaml' diff --git a/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml new file mode 100644 index 00000000..8010417d --- /dev/null +++ b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml @@ -0,0 +1,30 @@ +type: object +description: WebSocket request for row-embeddings service (flow-hosted service) +required: + - id + - service + - flow + - request +properties: + id: + type: string + description: Unique request identifier + service: + type: string + const: row-embeddings + description: Service identifier for row-embeddings service + flow: + type: string + description: Flow ID + request: + $ref: '../../../../api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml' +examples: + - id: req-1 + service: row-embeddings + flow: my-flow + request: + vectors: [0.023, -0.142, 0.089, 0.234] + schema_name: customers + limit: 10 + user: trustgraph + collection: default diff --git a/specs/websocket/components/messages/requests/ObjectsRequest.yaml b/specs/websocket/components/messages/requests/RowsRequest.yaml similarity index 60% rename from 
specs/websocket/components/messages/requests/ObjectsRequest.yaml rename to specs/websocket/components/messages/requests/RowsRequest.yaml index 61c9ef64..8eaa0919 100644 --- a/specs/websocket/components/messages/requests/ObjectsRequest.yaml +++ b/specs/websocket/components/messages/requests/RowsRequest.yaml @@ -1,5 +1,5 @@ type: object -description: WebSocket request for objects service (flow-hosted service) +description: WebSocket request for rows service (flow-hosted service) required: - id - service @@ -11,16 +11,16 @@ properties: description: Unique request identifier service: type: string - const: objects - description: Service identifier for objects service + const: rows + description: Service identifier for rows service flow: type: string description: Flow ID request: - $ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml' + $ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml' examples: - id: req-1 - service: objects + service: rows flow: my-flow request: query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }" diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index 3d184d3d..e82ccd98 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -15,10 +15,10 @@ from trustgraph.schema import ( TextCompletionRequest, TextCompletionResponse, DocumentRagQuery, DocumentRagResponse, AgentRequest, AgentResponse, AgentStep, - Chunk, Triple, Triples, Value, Error, + Chunk, Triple, Triples, Term, Error, EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings, - Metadata + Metadata, IRI, LITERAL ) @@ -43,7 +43,7 @@ def schema_registry(): "Chunk": Chunk, "Triple": Triple, "Triples": Triples, - "Value": Value, + "Term": Term, "Error": Error, "EntityContext": EntityContext, "EntityContexts": EntityContexts, @@ -98,26 +98,22 @@ def sample_message_data(): "collection": "test_collection", "metadata": [] }, - "Value": { - "value": "http://example.com/entity", - "is_uri": True, - "type": "" + "Term": { + "type": IRI, + "iri": "http://example.com/entity" }, "Triple": { - "s": Value( - value="http://example.com/subject", - is_uri=True, - type="" + "s": Term( + type=IRI, + iri="http://example.com/subject" ), - "p": Value( - value="http://example.com/predicate", - is_uri=True, - type="" + "p": Term( + type=IRI, + iri="http://example.com/predicate" ), - "o": Value( - value="Object value", - is_uri=False, - type="" + "o": Term( + type=LITERAL, + value="Object value" ) } } @@ -139,10 +135,10 @@ def invalid_message_data(): {"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit {"query": "test"}, # Missing required fields ], - "Value": [ - {"value": None, "is_uri": True, "type": ""}, # Invalid value (None) - {"value": "test", "is_uri": "not_boolean", "type": ""}, # Invalid is_uri - {"value": 123, "is_uri": True, "type": ""}, # Invalid value (not string) + "Term": [ + {"type": IRI, "iri": None}, # Invalid iri (None) + {"type": "invalid_type", "value": "test"}, # Invalid type + {"type": LITERAL, "value": 123}, # Invalid value (not string) ] } diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 6b10bd2f..746ebaed 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -15,14 +15,14 @@ from trustgraph.schema import ( TextCompletionRequest, TextCompletionResponse, DocumentRagQuery, DocumentRagResponse, AgentRequest, AgentResponse, AgentStep, - Chunk, Triple, Triples, 
Value, Error, + Chunk, Triple, Triples, Term, Error, EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings, Metadata, Field, RowSchema, StructuredDataSubmission, ExtractedObject, QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, StructuredQueryRequest, StructuredQueryResponse, - StructuredObjectEmbedding + StructuredObjectEmbedding, IRI, LITERAL ) from .conftest import validate_schema_contract, serialize_deserialize_test @@ -271,52 +271,51 @@ class TestAgentMessageContracts: class TestGraphMessageContracts: """Contract tests for Graph/Knowledge message schemas""" - def test_value_schema_contract(self, sample_message_data): - """Test Value schema contract""" + def test_term_schema_contract(self, sample_message_data): + """Test Term schema contract""" # Arrange - value_data = sample_message_data["Value"] + term_data = sample_message_data["Term"] # Act & Assert - assert validate_schema_contract(Value, value_data) - - # Test URI value - uri_value = Value(**value_data) - assert uri_value.value == "http://example.com/entity" - assert uri_value.is_uri is True + assert validate_schema_contract(Term, term_data) - # Test literal value - literal_value = Value( - value="Literal text value", - is_uri=False, - type="" + # Test URI term + uri_term = Term(**term_data) + assert uri_term.iri == "http://example.com/entity" + assert uri_term.type == IRI + + # Test literal term + literal_term = Term( + type=LITERAL, + value="Literal text value" ) - assert literal_value.value == "Literal text value" - assert literal_value.is_uri is False + assert literal_term.value == "Literal text value" + assert literal_term.type == LITERAL def test_triple_schema_contract(self, sample_message_data): """Test Triple schema contract""" # Arrange triple_data = sample_message_data["Triple"] - # Act & Assert - Triple uses Value objects, not dict validation + # Act & Assert - Triple uses Term objects, not dict validation triple = Triple( s=triple_data["s"], - p=triple_data["p"], + p=triple_data["p"], o=triple_data["o"] ) - assert triple.s.value == "http://example.com/subject" - assert triple.p.value == "http://example.com/predicate" + assert triple.s.iri == "http://example.com/subject" + assert triple.p.iri == "http://example.com/predicate" assert triple.o.value == "Object value" - assert triple.s.is_uri is True - assert triple.p.is_uri is True - assert triple.o.is_uri is False + assert triple.s.type == IRI + assert triple.p.type == IRI + assert triple.o.type == LITERAL def test_triples_schema_contract(self, sample_message_data): """Test Triples (batch) schema contract""" # Arrange metadata = Metadata(**sample_message_data["Metadata"]) triple = Triple(**sample_message_data["Triple"]) - + triples_data = { "metadata": metadata, "triples": [triple] @@ -324,11 +323,11 @@ class TestGraphMessageContracts: # Act & Assert assert validate_schema_contract(Triples, triples_data) - + triples = Triples(**triples_data) assert triples.metadata.id == "test-doc-123" assert len(triples.triples) == 1 - assert triples.triples[0].s.value == "http://example.com/subject" + assert triples.triples[0].s.iri == "http://example.com/subject" def test_chunk_schema_contract(self, sample_message_data): """Test Chunk schema contract""" @@ -349,29 +348,29 @@ class TestGraphMessageContracts: def test_entity_context_schema_contract(self): """Test EntityContext schema contract""" # Arrange - entity_value = Value(value="http://example.com/entity", is_uri=True, type="") + entity_term = Term(type=IRI, iri="http://example.com/entity") 
entity_context_data = { - "entity": entity_value, + "entity": entity_term, "context": "Context information about the entity" } # Act & Assert assert validate_schema_contract(EntityContext, entity_context_data) - + entity_context = EntityContext(**entity_context_data) - assert entity_context.entity.value == "http://example.com/entity" + assert entity_context.entity.iri == "http://example.com/entity" assert entity_context.context == "Context information about the entity" def test_entity_contexts_batch_schema_contract(self, sample_message_data): """Test EntityContexts (batch) schema contract""" # Arrange metadata = Metadata(**sample_message_data["Metadata"]) - entity_value = Value(value="http://example.com/entity", is_uri=True, type="") + entity_term = Term(type=IRI, iri="http://example.com/entity") entity_context = EntityContext( - entity=entity_value, + entity=entity_term, context="Entity context" ) - + entity_contexts_data = { "metadata": metadata, "entities": [entity_context] @@ -379,7 +378,7 @@ class TestGraphMessageContracts: # Act & Assert assert validate_schema_contract(EntityContexts, entity_contexts_data) - + entity_contexts = EntityContexts(**entity_contexts_data) assert entity_contexts.metadata.id == "test-doc-123" assert len(entity_contexts.entities) == 1 @@ -417,10 +416,10 @@ class TestMetadataMessageContracts: # Act & Assert assert validate_schema_contract(Metadata, metadata_data) - + metadata = Metadata(**metadata_data) assert len(metadata.metadata) == 1 - assert metadata.metadata[0].s.value == "http://example.com/subject" + assert metadata.metadata[0].s.iri == "http://example.com/subject" def test_error_schema_contract(self): """Test Error schema contract""" @@ -532,7 +531,7 @@ class TestSerializationContracts: # Test each schema in the registry for schema_name, schema_class in schema_registry.items(): if schema_name in sample_message_data: - # Skip Triple schema as it requires special handling with Value objects + # Skip Triple schema as it requires special handling with Term objects if schema_name == "Triple": continue @@ -541,36 +540,36 @@ class TestSerializationContracts: assert serialize_deserialize_test(schema_class, data), f"Serialization failed for {schema_name}" def test_triple_serialization_contract(self, sample_message_data): - """Test Triple schema serialization contract with Value objects""" + """Test Triple schema serialization contract with Term objects""" # Arrange triple_data = sample_message_data["Triple"] - + # Act triple = Triple( s=triple_data["s"], - p=triple_data["p"], + p=triple_data["p"], o=triple_data["o"] ) - - # Assert - Test that Value objects are properly constructed and accessible - assert triple.s.value == "http://example.com/subject" - assert triple.p.value == "http://example.com/predicate" + + # Assert - Test that Term objects are properly constructed and accessible + assert triple.s.iri == "http://example.com/subject" + assert triple.p.iri == "http://example.com/predicate" assert triple.o.value == "Object value" - assert isinstance(triple.s, Value) - assert isinstance(triple.p, Value) - assert isinstance(triple.o, Value) + assert isinstance(triple.s, Term) + assert isinstance(triple.p, Term) + assert isinstance(triple.o, Term) def test_nested_schema_serialization_contract(self, sample_message_data): """Test serialization of nested schemas""" # Test Triples (contains Metadata and Triple objects) metadata = Metadata(**sample_message_data["Metadata"]) triple = Triple(**sample_message_data["Triple"]) - + triples = Triples(metadata=metadata, 
triples=[triple]) - + # Verify nested objects maintain their contracts assert triples.metadata.id == "test-doc-123" - assert triples.triples[0].s.value == "http://example.com/subject" + assert triples.triples[0].s.iri == "http://example.com/subject" def test_array_field_serialization_contract(self): """Test serialization of array fields""" diff --git a/tests/contract/test_objects_cassandra_contracts.py b/tests/contract/test_rows_cassandra_contracts.py similarity index 83% rename from tests/contract/test_objects_cassandra_contracts.py rename to tests/contract/test_rows_cassandra_contracts.py index bb8aec8a..d1a8ba26 100644 --- a/tests/contract/test_objects_cassandra_contracts.py +++ b/tests/contract/test_rows_cassandra_contracts.py @@ -1,8 +1,8 @@ """ -Contract tests for Cassandra Object Storage +Contract tests for Cassandra Row Storage These tests verify the message contracts and schema compatibility -for the objects storage processor. +for the rows storage processor. """ import pytest @@ -10,12 +10,12 @@ import json from pulsar.schema import AvroSchema from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field -from trustgraph.storage.objects.cassandra.write import Processor +from trustgraph.storage.rows.cassandra.write import Processor @pytest.mark.contract -class TestObjectsCassandraContracts: - """Contract tests for Cassandra object storage messages""" +class TestRowsCassandraContracts: + """Contract tests for Cassandra row storage messages""" def test_extracted_object_input_contract(self): """Test that ExtractedObject schema matches expected input format""" @@ -145,50 +145,6 @@ class TestObjectsCassandraContracts: assert required_field_keys.issubset(field.keys()) assert set(field.keys()).issubset(required_field_keys | optional_field_keys) - def test_cassandra_type_mapping_contract(self): - """Test that all supported field types have Cassandra mappings""" - processor = Processor.__new__(Processor) - - # All field types that should be supported - supported_types = [ - ("string", "text"), - ("integer", "int"), # or bigint based on size - ("float", "float"), # or double based on size - ("boolean", "boolean"), - ("timestamp", "timestamp"), - ("date", "date"), - ("time", "time"), - ("uuid", "uuid") - ] - - for field_type, expected_cassandra_type in supported_types: - cassandra_type = processor.get_cassandra_type(field_type) - # For integer and float, the exact type depends on size - if field_type in ["integer", "float"]: - assert cassandra_type in ["int", "bigint", "float", "double"] - else: - assert cassandra_type == expected_cassandra_type - - def test_value_conversion_contract(self): - """Test value conversion for all supported types""" - processor = Processor.__new__(Processor) - - # Test conversions maintain data integrity - test_cases = [ - # (input_value, field_type, expected_output, expected_type) - ("123", "integer", 123, int), - ("123.45", "float", 123.45, float), - ("true", "boolean", True, bool), - ("false", "boolean", False, bool), - ("test string", "string", "test string", str), - (None, "string", None, type(None)), - ] - - for input_val, field_type, expected_val, expected_type in test_cases: - result = processor.convert_value(input_val, field_type) - assert result == expected_val - assert isinstance(result, expected_type) or result is None - @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type") def test_extracted_object_serialization_contract(self): """Test that ExtractedObject can be serialized/deserialized correctly""" @@ 
-222,43 +178,31 @@ class TestObjectsCassandraContracts: assert decoded.confidence == original.confidence assert decoded.source_span == original.source_span - def test_cassandra_table_naming_contract(self): + def test_cassandra_name_sanitization_contract(self): """Test Cassandra naming conventions and constraints""" processor = Processor.__new__(Processor) - - # Test table naming (always gets o_ prefix) - table_test_names = [ - ("simple_name", "o_simple_name"), - ("Name-With-Dashes", "o_name_with_dashes"), - ("name.with.dots", "o_name_with_dots"), - ("123_numbers", "o_123_numbers"), - ("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores - ("UPPERCASE", "o_uppercase"), - ("CamelCase", "o_camelcase"), - ("", "o_"), # Edge case - empty string becomes o_ - ] - - for input_name, expected_name in table_test_names: - result = processor.sanitize_table(input_name) - assert result == expected_name - # Verify result is valid Cassandra identifier (starts with letter) - assert result.startswith('o_') - assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_' - - # Test regular name sanitization (only adds o_ prefix if starts with number) + + # Test name sanitization for Cassandra identifiers + # - Non-alphanumeric chars (except underscore) become underscores + # - Names starting with non-letter get 'r_' prefix + # - All names converted to lowercase name_test_cases = [ ("simple_name", "simple_name"), ("Name-With-Dashes", "name_with_dashes"), ("name.with.dots", "name_with_dots"), - ("123_numbers", "o_123_numbers"), # Only this gets o_ prefix + ("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number) ("special!@#chars", "special___chars"), # 3 special chars become 3 underscores ("UPPERCASE", "uppercase"), ("CamelCase", "camelcase"), + ("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore) ] - + for input_name, expected_name in name_test_cases: result = processor.sanitize_name(input_name) - assert result == expected_name + assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}" + # Verify result is valid Cassandra identifier (starts with letter) + if result: # Skip empty string case + assert result[0].isalpha(), f"Result {result} should start with a letter" def test_primary_key_structure_contract(self): """Test that primary key structure follows Cassandra best practices""" @@ -308,8 +252,8 @@ class TestObjectsCassandraContracts: @pytest.mark.contract -class TestObjectsCassandraContractsBatch: - """Contract tests for Cassandra object storage batch processing""" +class TestRowsCassandraContractsBatch: + """Contract tests for Cassandra row storage batch processing""" def test_extracted_object_batch_input_contract(self): """Test that batched ExtractedObject schema matches expected input format""" diff --git a/tests/contract/test_objects_graphql_query_contracts.py b/tests/contract/test_rows_graphql_query_contracts.py similarity index 89% rename from tests/contract/test_objects_graphql_query_contracts.py rename to tests/contract/test_rows_graphql_query_contracts.py index ceb9dc17..db796306 100644 --- a/tests/contract/test_objects_graphql_query_contracts.py +++ b/tests/contract/test_rows_graphql_query_contracts.py @@ -1,26 +1,26 @@ """ -Contract tests for Objects GraphQL Query Service +Contract tests for Rows GraphQL Query Service These tests verify the message contracts and schema compatibility -for the objects GraphQL query processor. +for the rows GraphQL query processor. 
""" import pytest import json from pulsar.schema import AvroSchema -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError -from trustgraph.query.objects.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError +from trustgraph.query.rows.cassandra.service import Processor @pytest.mark.contract -class TestObjectsGraphQLQueryContracts: +class TestRowsGraphQLQueryContracts: """Contract tests for GraphQL query service messages""" - def test_objects_query_request_contract(self): - """Test ObjectsQueryRequest schema structure and required fields""" + def test_rows_query_request_contract(self): + """Test RowsQueryRequest schema structure and required fields""" # Create test request with all required fields - test_request = ObjectsQueryRequest( + test_request = RowsQueryRequest( user="test_user", collection="test_collection", query='{ customers { id name email } }', @@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts: assert test_request.variables["status"] == "active" assert test_request.operation_name == "GetCustomers" - def test_objects_query_request_minimal(self): - """Test ObjectsQueryRequest with minimal required fields""" + def test_rows_query_request_minimal(self): + """Test RowsQueryRequest with minimal required fields""" # Create request with only essential fields - minimal_request = ObjectsQueryRequest( + minimal_request = RowsQueryRequest( user="user", collection="collection", query='{ test }', @@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts: assert test_error.path == ["customers", "0", "nonexistent"] assert test_error.extensions["code"] == "FIELD_ERROR" - def test_objects_query_response_success_contract(self): - """Test ObjectsQueryResponse schema for successful queries""" + def test_rows_query_response_success_contract(self): + """Test RowsQueryResponse schema for successful queries""" # Create successful response - success_response = ObjectsQueryResponse( + success_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}', errors=[], @@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts: assert len(parsed_data["customers"]) == 1 assert parsed_data["customers"][0]["id"] == "1" - def test_objects_query_response_error_contract(self): - """Test ObjectsQueryResponse schema for error cases""" + def test_rows_query_response_error_contract(self): + """Test RowsQueryResponse schema for error cases""" # Create GraphQL errors - work around Pulsar Array(Record) validation bug # by creating a response without the problematic errors array first - error_response = ObjectsQueryResponse( + error_response = RowsQueryResponse( error=None, # System error is None - these are GraphQL errors data=None, # No data due to errors errors=[], # Empty errors array to avoid Pulsar bug @@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts: assert validation_error.path == ["customers", "email"] assert validation_error.extensions["details"] == "Invalid email format" - def test_objects_query_response_system_error_contract(self): - """Test ObjectsQueryResponse schema for system errors""" + def test_rows_query_response_system_error_contract(self): + """Test RowsQueryResponse schema for system errors""" from trustgraph.schema import Error # Create system error response - system_error_response = ObjectsQueryResponse( + system_error_response = RowsQueryResponse( error=Error( - type="objects-query-error", + type="rows-query-error", 
message="Failed to connect to Cassandra cluster" ), data=None, @@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts: # Verify system error structure assert system_error_response.error is not None - assert system_error_response.error.type == "objects-query-error" + assert system_error_response.error.type == "rows-query-error" assert "Cassandra" in system_error_response.error.message assert system_error_response.data is None assert len(system_error_response.errors) == 0 @@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts: def test_request_response_serialization_contract(self): """Test that request/response can be serialized/deserialized correctly""" # Create original request - original_request = ObjectsQueryRequest( + original_request = RowsQueryRequest( user="serialization_test", collection="test_data", query='{ orders(limit: 5) { id total customer { name } } }', @@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts: ) # Test request serialization using Pulsar schema - request_schema = AvroSchema(ObjectsQueryRequest) + request_schema = AvroSchema(RowsQueryRequest) # Encode and decode request encoded_request = request_schema.encode(original_request) @@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts: assert decoded_request.operation_name == original_request.operation_name # Create original response - work around Pulsar Array(Record) bug - original_response = ObjectsQueryResponse( + original_response = RowsQueryResponse( error=None, data='{"orders": []}', errors=[], # Empty to avoid Pulsar validation bug @@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts: ) # Test response serialization - response_schema = AvroSchema(ObjectsQueryResponse) + response_schema = AvroSchema(RowsQueryResponse) # Encode and decode response encoded_response = response_schema.encode(original_response) @@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts: def test_graphql_query_format_contract(self): """Test supported GraphQL query formats""" # Test basic query - basic_query = ObjectsQueryRequest( + basic_query = RowsQueryRequest( user="test", collection="test", query='{ customers { id } }', variables={}, operation_name="" ) @@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts: assert basic_query.query.strip().endswith('}') # Test query with variables - parameterized_query = ObjectsQueryRequest( + parameterized_query = RowsQueryRequest( user="test", collection="test", query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }', variables={"status": "active", "limit": "10"}, @@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts: assert parameterized_query.operation_name == "GetCustomers" # Test complex nested query - nested_query = ObjectsQueryRequest( + nested_query = RowsQueryRequest( user="test", collection="test", query=''' { @@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts: # Note: Current schema uses Map(String()) which only supports string values # This test verifies the current contract, though ideally we'd support all JSON types - variables_test = ObjectsQueryRequest( + variables_test = RowsQueryRequest( user="test", collection="test", query='{ test }', variables={ "string_var": "test_value", @@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts: def test_cassandra_context_fields_contract(self): """Test that request contains necessary fields for Cassandra operations""" # Verify request has fields needed for Cassandra keyspace/table targeting - request = ObjectsQueryRequest( + request = 
RowsQueryRequest( user="keyspace_name", # Maps to Cassandra keyspace collection="partition_collection", # Used in partition key query='{ objects { id } }', @@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts: def test_graphql_extensions_contract(self): """Test GraphQL extensions field format and usage""" # Extensions should support query metadata - response_with_extensions = ObjectsQueryResponse( + response_with_extensions = RowsQueryResponse( error=None, data='{"test": "data"}', errors=[], @@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts: ''' # Request to execute specific operation - multi_op_request = ObjectsQueryRequest( + multi_op_request = RowsQueryRequest( user="test", collection="test", query=multi_op_query, variables={}, @@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts: assert "GetOrders" in multi_op_request.query # Test single operation (operation_name optional) - single_op_request = ObjectsQueryRequest( + single_op_request = RowsQueryRequest( user="test", collection="test", query='{ customers { id } }', variables={}, operation_name="" diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index 91707d4d..71ccd787 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -15,7 +15,7 @@ from trustgraph.schema import ( QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, StructuredQueryRequest, StructuredQueryResponse, StructuredObjectEmbedding, Field, RowSchema, - Metadata, Error, Value + Metadata, Error ) from .conftest import serialize_deserialize_test diff --git a/tests/integration/test_agent_kg_extraction_integration.py b/tests/integration/test_agent_kg_extraction_integration.py index 50aadf3b..849547c8 100644 --- a/tests/integration/test_agent_kg_extraction_integration.py +++ b/tests/integration/test_agent_kg_extraction_integration.py @@ -12,7 +12,7 @@ import json from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor -from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error +from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from trustgraph.template.prompt_manager import PromptManager @@ -30,38 +30,16 @@ class TestAgentKgExtractionIntegration: # Mock agent client agent_client = AsyncMock() - # Mock successful agent response + # Mock successful agent response in JSONL format def mock_agent_response(recipient, question): - # Simulate agent processing and return structured response + # Simulate agent processing and return structured JSONL response mock_response = MagicMock() mock_response.error = None mock_response.answer = '''```json -{ - "definitions": [ - { - "entity": "Machine Learning", - "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." - }, - { - "entity": "Neural Networks", - "definition": "Computing systems inspired by biological neural networks that process information." 
- } - ], - "relationships": [ - { - "subject": "Machine Learning", - "predicate": "is_subset_of", - "object": "Artificial Intelligence", - "object-entity": true - }, - { - "subject": "Neural Networks", - "predicate": "used_in", - "object": "Machine Learning", - "object-entity": true - } - ] -} +{"type": "definition", "entity": "Machine Learning", "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."} +{"type": "definition", "entity": "Neural Networks", "definition": "Computing systems inspired by biological neural networks that process information."} +{"type": "relationship", "subject": "Machine Learning", "predicate": "is_subset_of", "object": "Artificial Intelligence", "object-entity": true} +{"type": "relationship", "subject": "Neural Networks", "predicate": "used_in", "object": "Machine Learning", "object-entity": true} ```''' return mock_response.answer @@ -100,9 +78,9 @@ class TestAgentKgExtractionIntegration: id="doc123", metadata=[ Triple( - s=Value(value="doc123", is_uri=True), - p=Value(value="http://example.org/type", is_uri=True), - o=Value(value="document", is_uri=False) + s=Term(type=IRI, iri="doc123"), + p=Term(type=IRI, iri="http://example.org/type"), + o=Term(type=LITERAL, value="document") ) ] ) @@ -120,7 +98,7 @@ class TestAgentKgExtractionIntegration: # Copy the methods we want to test extractor.to_uri = real_extractor.to_uri - extractor.parse_json = real_extractor.parse_json + extractor.parse_jsonl = real_extractor.parse_jsonl extractor.process_extraction_data = real_extractor.process_extraction_data extractor.emit_triples = real_extractor.emit_triples extractor.emit_entity_contexts = real_extractor.emit_entity_contexts @@ -156,7 +134,7 @@ class TestAgentKgExtractionIntegration: agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt) # Parse and process - extraction_data = extractor.parse_json(agent_response) + extraction_data = extractor.parse_jsonl(agent_response) triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata) # Add metadata triples @@ -200,15 +178,15 @@ class TestAgentKgExtractionIntegration: assert len(sent_triples.triples) > 0 # Check that we have definition triples - definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION] + definition_triples = [t for t in sent_triples.triples if t.p.iri == DEFINITION] assert len(definition_triples) >= 2 # Should have definitions for ML and Neural Networks - + # Check that we have label triples - label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL] + label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL] assert len(label_triples) >= 2 # Should have labels for entities - + # Check subject-of relationships - subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF] + subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF] assert len(subject_of_triples) >= 2 # Entities should be linked to document # Verify entity contexts were emitted @@ -220,7 +198,7 @@ class TestAgentKgExtractionIntegration: assert len(sent_contexts.entities) >= 2 # Should have contexts for both entities # Verify entity URIs are properly formed - entity_uris = [ec.entity.value for ec in sent_contexts.entities] + entity_uris = [ec.entity.iri for ec in sent_contexts.entities] assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris @@ -248,22 
+226,28 @@ class TestAgentKgExtractionIntegration:

    @pytest.mark.asyncio
    async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
-        """Test handling of invalid JSON responses from agent"""
+        """Test handling of invalid JSON responses from the agent - JSONL parsing is lenient and skips invalid lines"""
        # Arrange - mock invalid JSON response
        agent_client = mock_flow_context("agent-request")
-
+
        def mock_invalid_json_response(recipient, question):
            return "This is not valid JSON at all"
-
+
        agent_client.invoke = mock_invalid_json_response
-
+
        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

-        # Act & Assert
-        with pytest.raises((ValueError, json.JSONDecodeError)):
-            await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
+        # Act - JSONL parsing is lenient, invalid lines are skipped
+        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
+
+        # Assert - should emit triples (with just metadata) but no entity contexts
+        triples_publisher = mock_flow_context("triples")
+        triples_publisher.send.assert_called_once()
+
+        entity_contexts_publisher = mock_flow_context("entity-contexts")
+        entity_contexts_publisher.send.assert_not_called()

    @pytest.mark.asyncio
    async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
@@ -272,7 +256,8 @@ class TestAgentKgExtractionIntegration:
        agent_client = mock_flow_context("agent-request")

        def mock_empty_response(recipient, question):
-            return '{"definitions": [], "relationships": []}'
+            # Return empty JSONL (just empty/whitespace)
+            return ''

        agent_client.invoke = mock_empty_response

@@ -303,7 +288,8 @@ class TestAgentKgExtractionIntegration:
        agent_client = mock_flow_context("agent-request")

        def mock_malformed_response(recipient, question):
-            return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''
+            # JSONL with a definition missing its required field
+            return '{"type": "definition", "entity": "Missing Definition"}'

        agent_client.invoke = mock_malformed_response

@@ -330,7 +316,7 @@ class TestAgentKgExtractionIntegration:
        def capture_prompt(recipient, question):
            # Verify the prompt contains the test text
            assert test_text in question
-            return '{"definitions": [], "relationships": []}'
+            return ''  # Empty JSONL response

        agent_client.invoke = capture_prompt

@@ -361,7 +347,7 @@ class TestAgentKgExtractionIntegration:
        responses = []

        def mock_response(recipient, question):
-            response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
+            response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
            responses.append(response)
            return response

@@ -398,7 +384,7 @@ class TestAgentKgExtractionIntegration:
            # Verify unicode text was properly decoded and included
            assert "学习机器" in question
            assert "人工知能" in question
-            return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''
+            return '{"type": "definition", "entity": "機械学習", "definition": "人工知能の一分野"}'

        agent_client.invoke = mock_unicode_response

@@ -415,7 +401,7 @@ class TestAgentKgExtractionIntegration:
        sent_triples = triples_publisher.send.call_args[0][0]

        # Check that unicode entity was properly processed
-        entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and
t.o.value == "機械学習"] + entity_labels = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL and t.o.value == "機械学習"] assert len(entity_labels) > 0 @pytest.mark.asyncio @@ -433,7 +419,7 @@ class TestAgentKgExtractionIntegration: def mock_large_text_response(recipient, question): # Verify large text was included assert len(question) > 10000 - return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}''' + return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}' agent_client.invoke = mock_large_text_response diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index a06ec509..6c83fb05 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -12,7 +12,7 @@ from argparse import ArgumentParser # Import processors that use Cassandra configuration from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter -from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter +from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery from trustgraph.storage.knowledge.store import Processor as KgStore @@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow: assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3'] assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster): """Test complete flow from environment variables to Cassandra Cluster connection.""" env_vars = { @@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) # Trigger Cassandra connection processor.connect_cassandra() @@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd: class TestMultipleHostsHandling: """Test multiple Cassandra hosts handling end-to-end.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_multiple_hosts_passed_to_cluster(self, mock_cluster): """Test that multiple hosts are correctly passed to Cassandra cluster.""" env_vars = { @@ -333,7 +333,7 @@ class TestMultipleHostsHandling: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Verify all hosts were passed to Cluster @@ -386,8 +386,8 @@ class TestMultipleHostsHandling: class TestAuthenticationFlow: """Test authentication configuration flow end-to-end.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + 
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster): """Test that authentication is enabled when both username and password are provided.""" env_vars = { @@ -402,7 +402,7 @@ class TestAuthenticationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Auth provider should be created @@ -416,8 +416,8 @@ class TestAuthenticationFlow: assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): """Test that authentication is not used when credentials are missing.""" env_vars = { @@ -429,7 +429,7 @@ class TestAuthenticationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Auth provider should not be created @@ -439,11 +439,11 @@ class TestAuthenticationFlow: call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): """Test that authentication is not used when only username is provided.""" - processor = ObjectsWriter( + processor = RowsWriter( taskgroup=MagicMock(), cassandra_host='partial-auth-host', cassandra_username='partial-user' diff --git a/tests/integration/test_cassandra_integration.py b/tests/integration/test_cassandra_integration.py index 560f3132..2f5a4195 100644 --- a/tests/integration/test_cassandra_integration.py +++ b/tests/integration/test_cassandra_integration.py @@ -16,7 +16,7 @@ from .cassandra_test_helper import cassandra_container from trustgraph.direct.cassandra_kg import KnowledgeGraph from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor -from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest +from trustgraph.schema import Triple, Term, Metadata, Triples, TriplesQueryRequest, IRI, LITERAL @pytest.mark.integration @@ -118,19 +118,19 @@ class TestCassandraIntegration: metadata=Metadata(user="testuser", collection="testcol"), triples=[ Triple( - s=Value(value="http://example.org/person1", is_uri=True), - p=Value(value="http://example.org/name", is_uri=True), - o=Value(value="Alice Smith", is_uri=False) + s=Term(type=IRI, iri="http://example.org/person1"), + p=Term(type=IRI, iri="http://example.org/name"), + o=Term(type=LITERAL, value="Alice Smith") ), Triple( - s=Value(value="http://example.org/person1", is_uri=True), - 
p=Value(value="http://example.org/age", is_uri=True), - o=Value(value="25", is_uri=False) + s=Term(type=IRI, iri="http://example.org/person1"), + p=Term(type=IRI, iri="http://example.org/age"), + o=Term(type=LITERAL, value="25") ), Triple( - s=Value(value="http://example.org/person1", is_uri=True), - p=Value(value="http://example.org/department", is_uri=True), - o=Value(value="Engineering", is_uri=False) + s=Term(type=IRI, iri="http://example.org/person1"), + p=Term(type=IRI, iri="http://example.org/department"), + o=Term(type=LITERAL, value="Engineering") ) ] ) @@ -181,19 +181,19 @@ class TestCassandraIntegration: metadata=Metadata(user="testuser", collection="testcol"), triples=[ Triple( - s=Value(value="http://example.org/alice", is_uri=True), - p=Value(value="http://example.org/knows", is_uri=True), - o=Value(value="http://example.org/bob", is_uri=True) + s=Term(type=IRI, iri="http://example.org/alice"), + p=Term(type=IRI, iri="http://example.org/knows"), + o=Term(type=IRI, iri="http://example.org/bob") ), Triple( - s=Value(value="http://example.org/alice", is_uri=True), - p=Value(value="http://example.org/age", is_uri=True), - o=Value(value="30", is_uri=False) + s=Term(type=IRI, iri="http://example.org/alice"), + p=Term(type=IRI, iri="http://example.org/age"), + o=Term(type=LITERAL, value="30") ), Triple( - s=Value(value="http://example.org/bob", is_uri=True), - p=Value(value="http://example.org/knows", is_uri=True), - o=Value(value="http://example.org/charlie", is_uri=True) + s=Term(type=IRI, iri="http://example.org/bob"), + p=Term(type=IRI, iri="http://example.org/knows"), + o=Term(type=IRI, iri="http://example.org/charlie") ) ] ) @@ -208,7 +208,7 @@ class TestCassandraIntegration: # Test S query (find all relationships for Alice) s_query = TriplesQueryRequest( - s=Value(value="http://example.org/alice", is_uri=True), + s=Term(type=IRI, iri="http://example.org/alice"), p=None, # None for wildcard o=None, # None for wildcard limit=10, @@ -218,18 +218,18 @@ class TestCassandraIntegration: s_results = await query_processor.query_triples(s_query) print(f"Query processor results: {len(s_results)}") for result in s_results: - print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}") + print(f" S={result.s.iri}, P={result.p.iri}, O={result.o.iri if result.o.type == IRI else result.o.value}") assert len(s_results) == 2 - - s_predicates = [t.p.value for t in s_results] + + s_predicates = [t.p.iri for t in s_results] assert "http://example.org/knows" in s_predicates assert "http://example.org/age" in s_predicates print("✓ Subject queries via processor working") - + # Test P query (find all "knows" relationships) p_query = TriplesQueryRequest( s=None, # None for wildcard - p=Value(value="http://example.org/knows", is_uri=True), + p=Term(type=IRI, iri="http://example.org/knows"), o=None, # None for wildcard limit=10, user="testuser", @@ -238,8 +238,8 @@ class TestCassandraIntegration: p_results = await query_processor.query_triples(p_query) print(p_results) assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie - - p_subjects = [t.s.value for t in p_results] + + p_subjects = [t.s.iri for t in p_results] assert "http://example.org/alice" in p_subjects assert "http://example.org/bob" in p_subjects print("✓ Predicate queries via processor working") @@ -262,19 +262,19 @@ class TestCassandraIntegration: metadata=Metadata(user="concurrent_test", collection="people"), triples=[ Triple( - s=Value(value=f"http://example.org/{person_id}", is_uri=True), - 
p=Value(value="http://example.org/name", is_uri=True), - o=Value(value=name, is_uri=False) + s=Term(type=IRI, iri=f"http://example.org/{person_id}"), + p=Term(type=IRI, iri="http://example.org/name"), + o=Term(type=LITERAL, value=name) ), Triple( - s=Value(value=f"http://example.org/{person_id}", is_uri=True), - p=Value(value="http://example.org/age", is_uri=True), - o=Value(value=str(age), is_uri=False) + s=Term(type=IRI, iri=f"http://example.org/{person_id}"), + p=Term(type=IRI, iri="http://example.org/age"), + o=Term(type=LITERAL, value=str(age)) ), Triple( - s=Value(value=f"http://example.org/{person_id}", is_uri=True), - p=Value(value="http://example.org/department", is_uri=True), - o=Value(value=department, is_uri=False) + s=Term(type=IRI, iri=f"http://example.org/{person_id}"), + p=Term(type=IRI, iri="http://example.org/department"), + o=Term(type=LITERAL, value=department) ) ] ) @@ -333,36 +333,36 @@ class TestCassandraIntegration: triples=[ # People and their types Triple( - s=Value(value="http://company.org/alice", is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://company.org/Employee", is_uri=True) + s=Term(type=IRI, iri="http://company.org/alice"), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri="http://company.org/Employee") ), Triple( - s=Value(value="http://company.org/bob", is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://company.org/Employee", is_uri=True) + s=Term(type=IRI, iri="http://company.org/bob"), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri="http://company.org/Employee") ), # Relationships Triple( - s=Value(value="http://company.org/alice", is_uri=True), - p=Value(value="http://company.org/reportsTo", is_uri=True), - o=Value(value="http://company.org/bob", is_uri=True) + s=Term(type=IRI, iri="http://company.org/alice"), + p=Term(type=IRI, iri="http://company.org/reportsTo"), + o=Term(type=IRI, iri="http://company.org/bob") ), Triple( - s=Value(value="http://company.org/alice", is_uri=True), - p=Value(value="http://company.org/worksIn", is_uri=True), - o=Value(value="http://company.org/engineering", is_uri=True) + s=Term(type=IRI, iri="http://company.org/alice"), + p=Term(type=IRI, iri="http://company.org/worksIn"), + o=Term(type=IRI, iri="http://company.org/engineering") ), # Personal info Triple( - s=Value(value="http://company.org/alice", is_uri=True), - p=Value(value="http://company.org/fullName", is_uri=True), - o=Value(value="Alice Johnson", is_uri=False) + s=Term(type=IRI, iri="http://company.org/alice"), + p=Term(type=IRI, iri="http://company.org/fullName"), + o=Term(type=LITERAL, value="Alice Johnson") ), Triple( - s=Value(value="http://company.org/alice", is_uri=True), - p=Value(value="http://company.org/email", is_uri=True), - o=Value(value="alice@company.org", is_uri=False) + s=Term(type=IRI, iri="http://company.org/alice"), + p=Term(type=IRI, iri="http://company.org/email"), + o=Term(type=LITERAL, value="alice@company.org") ), ] ) diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py index 30197731..13a851df 100644 --- a/tests/integration/test_import_export_graceful_shutdown.py +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -51,10 +51,10 @@ class MockWebSocket: "metadata": { "id": "test-id", "metadata": {}, - "user": 
"test-user", + "user": "test-user", "collection": "test-collection" }, - "triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}] + "triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}] } @@ -118,7 +118,7 @@ async def test_import_graceful_shutdown_integration(mock_backend): "user": "test-user", "collection": "test-collection" }, - "triples": [{"s": {"v": f"subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"object-{i}", "e": False}}] + "triples": [{"s": {"t": "l", "v": f"subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"object-{i}"}}] } messages.append(msg_data) @@ -163,7 +163,7 @@ async def test_export_no_message_loss_integration(mock_backend): "user": "test-user", "collection": "test-collection" }, - "triples": [{"s": {"v": f"export-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"export-object-{i}", "e": False}}] + "triples": [{"s": {"t": "l", "v": f"export-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"export-object-{i}"}}] } # Create Triples object instead of raw dict from trustgraph.schema import Triples, Metadata @@ -302,7 +302,7 @@ async def test_concurrent_import_export_shutdown(): "user": "test-user", "collection": "test-collection" }, - "triples": [{"s": {"v": f"concurrent-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}] + "triples": [{"s": {"t": "l", "v": f"concurrent-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}] } await import_handler.receive(msg) @@ -359,7 +359,7 @@ async def test_websocket_close_during_message_processing(): "user": "test-user", "collection": "test-collection" }, - "triples": [{"s": {"v": f"slow-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}] + "triples": [{"s": {"t": "l", "v": f"slow-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}] } task = asyncio.create_task(import_handler.receive(msg)) message_tasks.append(task) @@ -423,7 +423,7 @@ async def test_backpressure_during_shutdown(): # Simulate receiving and processing a message msg_data = { "metadata": {"id": f"msg-{i}"}, - "triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}] + "triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}] } await ws.send_json(msg_data) # Check if we should stop diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index dd13789f..2baa1d4d 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -15,8 +15,8 @@ from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor -from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error -from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings +from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL +from trustgraph.schema import EntityContext, EntityContexts, 
GraphEmbeddings, EntityEmbeddings from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF @@ -147,6 +147,8 @@ class TestKnowledgeGraphPipelineIntegration: processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor) processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor) processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor) + processor.triples_batch_size = 50 + processor.entity_batch_size = 5 return processor @pytest.fixture @@ -156,6 +158,7 @@ class TestKnowledgeGraphPipelineIntegration: processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor) processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor) processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor) + processor.triples_batch_size = 50 return processor @pytest.mark.asyncio @@ -253,24 +256,24 @@ class TestKnowledgeGraphPipelineIntegration: if s and o: s_uri = definitions_processor.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) - o_value = Value(value=str(o), is_uri=False) - + s_term = Term(type=IRI, iri=str(s_uri)) + o_term = Term(type=LITERAL, value=str(o)) + # Generate triples as the processor would triples.append(Triple( - s=s_value, - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=s, is_uri=False) + s=s_term, + p=Term(type=IRI, iri=RDF_LABEL), + o=Term(type=LITERAL, value=s) )) - + triples.append(Triple( - s=s_value, - p=Value(value=DEFINITION, is_uri=True), - o=o_value + s=s_term, + p=Term(type=IRI, iri=DEFINITION), + o=o_term )) - + entities.append(EntityContext( - entity=s_value, + entity=s_term, context=defn["definition"] )) @@ -279,16 +282,16 @@ class TestKnowledgeGraphPipelineIntegration: assert len(entities) == 3 # 1 entity context per entity # Verify triple structure - label_triples = [t for t in triples if t.p.value == RDF_LABEL] - definition_triples = [t for t in triples if t.p.value == DEFINITION] - + label_triples = [t for t in triples if t.p.iri == RDF_LABEL] + definition_triples = [t for t in triples if t.p.iri == DEFINITION] + assert len(label_triples) == 3 assert len(definition_triples) == 3 - + # Verify entity contexts for entity in entities: - assert entity.entity.is_uri is True - assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES) + assert entity.entity.type == IRI + assert entity.entity.iri.startswith(TRUSTGRAPH_ENTITIES) assert len(entity.context) > 0 @pytest.mark.asyncio @@ -309,52 +312,52 @@ class TestKnowledgeGraphPipelineIntegration: s = rel["subject"] p = rel["predicate"] o = rel["object"] - + if s and p and o: s_uri = relationships_processor.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) - + s_term = Term(type=IRI, iri=str(s_uri)) + p_uri = relationships_processor.to_uri(p) - p_value = Value(value=str(p_uri), is_uri=True) - + p_term = Term(type=IRI, iri=str(p_uri)) + if rel["object-entity"]: o_uri = relationships_processor.to_uri(o) - o_value = Value(value=str(o_uri), is_uri=True) + o_term = Term(type=IRI, iri=str(o_uri)) else: - o_value = Value(value=str(o), is_uri=False) - + o_term = Term(type=LITERAL, value=str(o)) + # Main relationship triple - triples.append(Triple(s=s_value, p=p_value, o=o_value)) - + triples.append(Triple(s=s_term, p=p_term, o=o_term)) + # Label triples triples.append(Triple( - s=s_value, - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=str(s), is_uri=False) + s=s_term, + 
p=Term(type=IRI, iri=RDF_LABEL), + o=Term(type=LITERAL, value=str(s)) )) - + triples.append(Triple( - s=p_value, - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=str(p), is_uri=False) + s=p_term, + p=Term(type=IRI, iri=RDF_LABEL), + o=Term(type=LITERAL, value=str(p)) )) - + if rel["object-entity"]: triples.append(Triple( - s=o_value, - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=str(o), is_uri=False) + s=o_term, + p=Term(type=IRI, iri=RDF_LABEL), + o=Term(type=LITERAL, value=str(o)) )) # Assert assert len(triples) > 0 # Verify relationship triples exist - relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")] + relationship_triples = [t for t in triples if t.p.iri.endswith("is_subset_of") or t.p.iri.endswith("is_used_in")] assert len(relationship_triples) >= 2 - + # Verify label triples - label_triples = [t for t in triples if t.p.value == RDF_LABEL] + label_triples = [t for t in triples if t.p.iri == RDF_LABEL] assert len(label_triples) > 0 @pytest.mark.asyncio @@ -374,9 +377,9 @@ class TestKnowledgeGraphPipelineIntegration: ), triples=[ Triple( - s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True), - p=Value(value=DEFINITION, is_uri=True), - o=Value(value="A subset of AI", is_uri=False) + s=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), + p=Term(type=IRI, iri=DEFINITION), + o=Term(type=LITERAL, value="A subset of AI") ) ] ) @@ -405,9 +408,14 @@ class TestKnowledgeGraphPipelineIntegration: collection="test_collection", metadata=[] ), - entities=[] + entities=[ + EntityEmbeddings( + entity=Term(type=IRI, iri="http://example.org/entity"), + vectors=[[0.1, 0.2, 0.3]] + ) + ] ) - + mock_msg = MagicMock() mock_msg.value.return_value = sample_embeddings @@ -496,12 +504,12 @@ class TestKnowledgeGraphPipelineIntegration: await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context) # Assert - # Should still call producers but with empty results + # Should NOT call producers with empty results (avoids Cassandra NULL issues) triples_producer = mock_flow_context("triples") entity_contexts_producer = mock_flow_context("entity-contexts") - - triples_producer.send.assert_called_once() - entity_contexts_producer.send.assert_called_once() + + triples_producer.send.assert_not_called() + entity_contexts_producer.send.assert_not_called() @pytest.mark.asyncio async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk): @@ -602,9 +610,9 @@ class TestKnowledgeGraphPipelineIntegration: collection="test_collection", metadata=[ Triple( - s=Value(value="doc:test", is_uri=True), - p=Value(value="dc:title", is_uri=True), - o=Value(value="Test Document", is_uri=False) + s=Term(type=IRI, iri="doc:test"), + p=Term(type=IRI, iri="dc:title"), + o=Term(type=LITERAL, value="Test Document") ) ] ) diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index 7b2245ce..dd48affe 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -11,7 +11,7 @@ import json import asyncio from unittest.mock import AsyncMock, MagicMock, patch -from trustgraph.extract.kg.objects.processor import Processor +from trustgraph.extract.kg.rows.processor import Processor from trustgraph.schema import ( Chunk, ExtractedObject, Metadata, RowSchema, Field, PromptRequest, PromptResponse @@ -220,7 +220,7 @@ class 
TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Mock flow with failing prompt service @@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration diff --git a/tests/integration/test_objects_cassandra_integration.py b/tests/integration/test_objects_cassandra_integration.py deleted file mode 100644 index 3310b396..00000000 --- a/tests/integration/test_objects_cassandra_integration.py +++ /dev/null @@ -1,608 +0,0 @@ -""" -Integration tests for Cassandra Object Storage - -These tests verify the end-to-end functionality of storing ExtractedObjects -in Cassandra, including table creation, data insertion, and error handling. 
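(The edits above move the extractor from `trustgraph.extract.kg.objects.processor` to `trustgraph.extract.kg.rows.processor`; each test rebinds its `convert_values_to_strings` helper from the new path. Here is a hypothetical sketch of that helper, inferred only from its name and the string-typed fixture values in these tests; the real implementation may differ.)

```python
# Hypothetical sketch -- the real helper is
# trustgraph.extract.kg.rows.processor.convert_values_to_strings.
def convert_values_to_strings(values: dict) -> dict:
    # Coerce every extracted field to str so row storage sees a uniform
    # type, e.g. {"age": 30, "price": 19.99} -> {"age": "30", "price": "19.99"}.
    return {key: str(value) for key, value in values.items()}
```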
-""" - -import pytest -from unittest.mock import MagicMock, AsyncMock, patch -import json -import uuid - -from trustgraph.storage.objects.cassandra.write import Processor -from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field - - -@pytest.mark.integration -class TestObjectsCassandraIntegration: - """Integration tests for Cassandra object storage""" - - @pytest.fixture - def mock_cassandra_session(self): - """Mock Cassandra session for integration tests""" - session = MagicMock() - - # Track if keyspaces have been created - created_keyspaces = set() - - # Mock the execute method to return a valid result for keyspace checks - def execute_mock(query, *args, **kwargs): - result = MagicMock() - query_str = str(query) - - # Track keyspace creation - if "CREATE KEYSPACE" in query_str: - # Extract keyspace name from query - import re - match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str) - if match: - created_keyspaces.add(match.group(1)) - - # For keyspace existence checks - if "system_schema.keyspaces" in query_str: - # Check if this keyspace was created - if args and args[0] in created_keyspaces: - result.one.return_value = MagicMock() # Exists - else: - result.one.return_value = None # Doesn't exist - else: - result.one.return_value = None - - return result - - session.execute = MagicMock(side_effect=execute_mock) - return session - - @pytest.fixture - def mock_cassandra_cluster(self, mock_cassandra_session): - """Mock Cassandra cluster""" - cluster = MagicMock() - cluster.connect.return_value = mock_cassandra_session - cluster.shutdown = MagicMock() - return cluster - - @pytest.fixture - def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session): - """Create processor with mocked Cassandra dependencies""" - processor = MagicMock() - processor.graph_host = "localhost" - processor.graph_username = None - processor.graph_password = None - processor.config_key = "schema" - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.cluster = None - processor.session = None - - # Bind actual methods - processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) - processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.create_collection = Processor.create_collection.__get__(processor, Processor) - - return processor, mock_cassandra_cluster, mock_cassandra_session - - @pytest.mark.asyncio - async def test_end_to_end_object_storage(self, processor_with_mocks): - """Test complete flow from schema config to object storage""" - processor, mock_cluster, mock_session = processor_with_mocks - - # Mock Cluster creation - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Step 1: Configure schema - config = { - "schema": { - "customer_records": json.dumps({ - "name": "customer_records", - "description": "Customer information", - 
"fields": [ - {"name": "customer_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string", "required": True}, - {"name": "email", "type": "string", "indexed": True}, - {"name": "age", "type": "integer"} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - assert "customer_records" in processor.schemas - - # Step 1.5: Create the collection first (simulate tg-set-collection) - await processor.create_collection("test_user", "import_2024", {}) - - # Step 2: Process an ExtractedObject - test_obj = ExtractedObject( - metadata=Metadata( - id="doc-001", - user="test_user", - collection="import_2024", - metadata=[] - ), - schema_name="customer_records", - values=[{ - "customer_id": "CUST001", - "name": "John Doe", - "email": "john@example.com", - "age": "30" - }], - confidence=0.95, - source_span="Customer: John Doe..." - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - await processor.on_object(msg, None, None) - - # Verify Cassandra interactions - assert mock_cluster.connect.called - - # Verify keyspace creation - keyspace_calls = [call for call in mock_session.execute.call_args_list - if "CREATE KEYSPACE" in str(call)] - assert len(keyspace_calls) == 1 - assert "test_user" in str(keyspace_calls[0]) - - # Verify table creation - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "o_customer_records" in str(table_calls[0]) # Table gets o_ prefix - assert "collection text" in str(table_calls[0]) - assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0]) - - # Verify index creation - index_calls = [call for call in mock_session.execute.call_args_list - if "CREATE INDEX" in str(call)] - assert len(index_calls) == 1 - assert "email" in str(index_calls[0]) - - # Verify data insertion - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - insert_call = insert_calls[0] - assert "test_user.o_customer_records" in str(insert_call) # Table gets o_ prefix - - # Check inserted values - values = insert_call[0][1] - assert "import_2024" in values # collection - assert "CUST001" in values # customer_id - assert "John Doe" in values # name - assert "john@example.com" in values # email - assert 30 in values # age (converted to int) - - @pytest.mark.asyncio - async def test_multi_schema_handling(self, processor_with_mocks): - """Test handling multiple schemas and objects""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure multiple schemas - config = { - "schema": { - "products": json.dumps({ - "name": "products", - "fields": [ - {"name": "product_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string"}, - {"name": "price", "type": "float"} - ] - }), - "orders": json.dumps({ - "name": "orders", - "fields": [ - {"name": "order_id", "type": "string", "primary_key": True}, - {"name": "customer_id", "type": "string"}, - {"name": "total", "type": "float"} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - assert len(processor.schemas) == 2 - - # Create collections first - await processor.create_collection("shop", "catalog", {}) - await processor.create_collection("shop", "sales", {}) - - # Process objects for different schemas - product_obj = ExtractedObject( - metadata=Metadata(id="p1", user="shop", 
collection="catalog", metadata=[]), - schema_name="products", - values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], - confidence=0.9, - source_span="Product..." - ) - - order_obj = ExtractedObject( - metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), - schema_name="orders", - values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], - confidence=0.85, - source_span="Order..." - ) - - # Process both objects - for obj in [product_obj, order_obj]: - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Verify separate tables were created - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 2 - assert any("o_products" in str(call) for call in table_calls) # Tables get o_ prefix - assert any("o_orders" in str(call) for call in table_calls) # Tables get o_ prefix - - @pytest.mark.asyncio - async def test_missing_required_fields(self, processor_with_mocks): - """Test handling of objects with missing required fields""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema with required field - processor.schemas["test_schema"] = RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True, required=True), - Field(name="required_field", type="string", size=100, required=True) - ] - ) - - # Create collection first - await processor.create_collection("test", "test", {}) - - # Create object missing required field - test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), - schema_name="test_schema", - values=[{"id": "123"}], # missing required_field - confidence=0.8, - source_span="Test" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - # Should still process (Cassandra doesn't enforce NOT NULL) - await processor.on_object(msg, None, None) - - # Verify insert was attempted - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - - @pytest.mark.asyncio - async def test_schema_without_primary_key(self, processor_with_mocks): - """Test handling schemas without defined primary keys""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema without primary key - processor.schemas["events"] = RowSchema( - name="events", - description="Event log", - fields=[ - Field(name="event_type", type="string", size=50), - Field(name="timestamp", type="timestamp", size=0) - ] - ) - - # Create collection first - await processor.create_collection("logger", "app_events", {}) - - # Process object - test_obj = ExtractedObject( - metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]), - schema_name="events", - values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}], - confidence=1.0, - source_span="Event" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - await processor.on_object(msg, None, None) - - # Verify synthetic_id was added - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "synthetic_id uuid" in str(table_calls[0]) - - # Verify insert includes UUID 
- insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - values = insert_calls[0][0][1] - # Check that a UUID was generated (will be in values list) - uuid_found = any(isinstance(v, uuid.UUID) for v in values) - assert uuid_found - - @pytest.mark.asyncio - async def test_authentication_handling(self, processor_with_mocks): - """Test Cassandra authentication""" - processor, mock_cluster, mock_session = processor_with_mocks - processor.cassandra_username = "cassandra_user" - processor.cassandra_password = "cassandra_pass" - - with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class: - with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth: - mock_cluster_class.return_value = mock_cluster - - # Trigger connection - processor.connect_cassandra() - - # Verify authentication was configured - mock_auth.assert_called_once_with( - username="cassandra_user", - password="cassandra_pass" - ) - mock_cluster_class.assert_called_once() - call_kwargs = mock_cluster_class.call_args[1] - assert 'auth_provider' in call_kwargs - - @pytest.mark.asyncio - async def test_error_handling_during_insert(self, processor_with_mocks): - """Test error handling when insertion fails""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["test"] = RowSchema( - name="test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Make insert fail - mock_result = MagicMock() - mock_result.one.return_value = MagicMock() # Keyspace exists - mock_session.execute.side_effect = [ - mock_result, # keyspace existence check succeeds - None, # table creation succeeds - Exception("Connection timeout") # insert fails - ] - - test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), - schema_name="test", - values=[{"id": "123"}], - confidence=0.9, - source_span="Test" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - # Should raise the exception - with pytest.raises(Exception, match="Connection timeout"): - await processor.on_object(msg, None, None) - - @pytest.mark.asyncio - async def test_collection_partitioning(self, processor_with_mocks): - """Test that objects are properly partitioned by collection""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["data"] = RowSchema( - name="data", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Process objects from different collections - collections = ["import_jan", "import_feb", "import_mar"] - - # Create all collections first - for coll in collections: - await processor.create_collection("analytics", coll, {}) - - for coll in collections: - obj = ExtractedObject( - metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]), - schema_name="data", - values=[{"id": f"ID-{coll}"}], - confidence=0.9, - source_span="Data" - ) - - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Verify all inserts include collection in values - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 3 - - # Check each insert has the correct collection - for i, 
call in enumerate(insert_calls): - values = call[0][1] - assert collections[i] in values - - @pytest.mark.asyncio - async def test_batch_object_processing(self, processor_with_mocks): - """Test processing objects with batched values""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema - config = { - "schema": { - "batch_customers": json.dumps({ - "name": "batch_customers", - "description": "Customer batch data", - "fields": [ - {"name": "customer_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string", "required": True}, - {"name": "email", "type": "string", "indexed": True} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - - # Process batch object with multiple values - batch_obj = ExtractedObject( - metadata=Metadata( - id="batch-001", - user="test_user", - collection="batch_import", - metadata=[] - ), - schema_name="batch_customers", - values=[ - { - "customer_id": "CUST001", - "name": "John Doe", - "email": "john@example.com" - }, - { - "customer_id": "CUST002", - "name": "Jane Smith", - "email": "jane@example.com" - }, - { - "customer_id": "CUST003", - "name": "Bob Johnson", - "email": "bob@example.com" - } - ], - confidence=0.92, - source_span="Multiple customers extracted from document" - ) - - # Create collection first - await processor.create_collection("test_user", "batch_import", {}) - - msg = MagicMock() - msg.value.return_value = batch_obj - - await processor.on_object(msg, None, None) - - # Verify table creation - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "o_batch_customers" in str(table_calls[0]) - - # Verify multiple inserts for batch values - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - # Should have 3 separate inserts for the 3 objects in the batch - assert len(insert_calls) == 3 - - # Check each insert has correct data - for i, call in enumerate(insert_calls): - values = call[0][1] - assert "batch_import" in values # collection - assert f"CUST00{i+1}" in values # customer_id - if i == 0: - assert "John Doe" in values - assert "john@example.com" in values - elif i == 1: - assert "Jane Smith" in values - assert "jane@example.com" in values - elif i == 2: - assert "Bob Johnson" in values - assert "bob@example.com" in values - - @pytest.mark.asyncio - async def test_empty_batch_processing(self, processor_with_mocks): - """Test processing objects with empty values array""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["empty_test"] = RowSchema( - name="empty_test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Create collection first - await processor.create_collection("test", "empty", {}) - - # Process empty batch object - empty_obj = ExtractedObject( - metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), - schema_name="empty_test", - values=[], # Empty batch - confidence=1.0, - source_span="No objects found" - ) - - msg = MagicMock() - msg.value.return_value = empty_obj - - await processor.on_object(msg, None, None) - - # Should still create table - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - 
assert len(table_calls) == 1 - - # Should not create any insert statements for empty batch - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 0 - - @pytest.mark.asyncio - async def test_mixed_single_and_batch_objects(self, processor_with_mocks): - """Test processing mix of single and batch objects""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["mixed_test"] = RowSchema( - name="mixed_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="data", type="string", size=100) - ] - ) - - # Create collection first - await processor.create_collection("test", "mixed", {}) - - # Single object (backward compatibility) - single_obj = ExtractedObject( - metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]), - schema_name="mixed_test", - values=[{"id": "single-1", "data": "single data"}], # Array with single item - confidence=0.9, - source_span="Single object" - ) - - # Batch object - batch_obj = ExtractedObject( - metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]), - schema_name="mixed_test", - values=[ - {"id": "batch-1", "data": "batch data 1"}, - {"id": "batch-2", "data": "batch data 2"} - ], - confidence=0.85, - source_span="Batch objects" - ) - - # Process both - for obj in [single_obj, batch_obj]: - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Should have 3 total inserts (1 + 2) - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 3 \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py new file mode 100644 index 00000000..2cb973a7 --- /dev/null +++ b/tests/integration/test_rows_cassandra_integration.py @@ -0,0 +1,492 @@ +""" +Integration tests for Cassandra Row Storage (Unified Table Implementation) + +These tests verify the end-to-end functionality of storing ExtractedObjects +in the unified Cassandra rows table, including table creation, data insertion, +and error handling. 
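(The replacement suite below exercises the unified-table model: rows for every schema land in a single `rows` table plus a `row_partitions` manifest, with one insert per indexed field. A sketch of the write fan-out these assertions imply; the column order and string encoding are assumptions read off the test expectations, not the writer's actual code.)

```python
# Sketch of the fan-out the assertions below count: one insert per indexed
# field (the primary key included), all into <keyspace>.rows.
def fan_out(collection: str, schema_name: str, record: dict, index_fields: list[str]):
    for index_name in index_fields:  # e.g. ["customer_id", "email"] -> 2 inserts
        yield (
            collection,                               # partition scope
            schema_name,                              # logical schema, one shared table
            index_name,                               # lookup path (values[2] in the tests)
            str(record[index_name]),                  # index value for this path
            {k: str(v) for k, v in record.items()},   # full record as a map column
        )
```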
+""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +import json + +from trustgraph.storage.rows.cassandra.write import Processor +from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +@pytest.mark.integration +class TestRowsCassandraIntegration: + """Integration tests for Cassandra row storage with unified table""" + + @pytest.fixture + def mock_cassandra_session(self): + """Mock Cassandra session for integration tests""" + session = MagicMock() + + # Track if keyspaces have been created + created_keyspaces = set() + + # Mock the execute method to return a valid result for keyspace checks + def execute_mock(query, *args, **kwargs): + result = MagicMock() + query_str = str(query) + + # Track keyspace creation + if "CREATE KEYSPACE" in query_str: + import re + match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str) + if match: + created_keyspaces.add(match.group(1)) + + # For keyspace existence checks + if "system_schema.keyspaces" in query_str: + if args and args[0] in created_keyspaces: + result.one.return_value = MagicMock() # Exists + else: + result.one.return_value = None # Doesn't exist + else: + result.one.return_value = None + + return result + + session.execute = MagicMock(side_effect=execute_mock) + return session + + @pytest.fixture + def mock_cassandra_cluster(self, mock_cassandra_session): + """Mock Cassandra cluster""" + cluster = MagicMock() + cluster.connect.return_value = mock_cassandra_session + cluster.shutdown = MagicMock() + return cluster + + @pytest.fixture + def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session): + """Create processor with mocked Cassandra dependencies""" + processor = MagicMock() + processor.cassandra_host = ["localhost"] + processor.cassandra_username = None + processor.cassandra_password = None + processor.config_key = "schema" + processor.schemas = {} + processor.known_keyspaces = set() + processor.tables_initialized = set() + processor.registered_partitions = set() + processor.cluster = None + processor.session = None + + # Bind actual methods from the new unified table implementation + processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) + processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) + processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) + processor.on_object = Processor.on_object.__get__(processor, Processor) + processor.collection_exists = MagicMock(return_value=True) + + return processor, mock_cassandra_cluster, mock_cassandra_session + + @pytest.mark.asyncio + async def test_end_to_end_object_storage(self, processor_with_mocks): + """Test complete flow from schema config to object storage""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Step 1: Configure schema + config = { + "schema": { + "customer_records": json.dumps({ + "name": "customer_records", + "description": "Customer information", + "fields": [ + 
{"name": "customer_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string", "required": True}, + {"name": "email", "type": "string", "indexed": True}, + {"name": "age", "type": "integer"} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + assert "customer_records" in processor.schemas + + # Step 2: Process an ExtractedObject + test_obj = ExtractedObject( + metadata=Metadata( + id="doc-001", + user="test_user", + collection="import_2024", + metadata=[] + ), + schema_name="customer_records", + values=[{ + "customer_id": "CUST001", + "name": "John Doe", + "email": "john@example.com", + "age": "30" + }], + confidence=0.95, + source_span="Customer: John Doe..." + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify Cassandra interactions + assert mock_cluster.connect.called + + # Verify keyspace creation + keyspace_calls = [call for call in mock_session.execute.call_args_list + if "CREATE KEYSPACE" in str(call)] + assert len(keyspace_calls) == 1 + assert "test_user" in str(keyspace_calls[0]) + + # Verify unified table creation (rows table, not per-schema table) + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 2 # rows table + row_partitions table + assert any("rows" in str(call) for call in table_calls) + assert any("row_partitions" in str(call) for call in table_calls) + + # Verify the rows table has correct structure + rows_table_call = [call for call in table_calls if ".rows" in str(call)][0] + assert "collection text" in str(rows_table_call) + assert "schema_name text" in str(rows_table_call) + assert "index_name text" in str(rows_table_call) + assert "data map" in str(rows_table_call) + + # Verify data insertion into unified table + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + # Should have 2 data inserts: one for customer_id (primary), one for email (indexed) + assert len(rows_insert_calls) == 2 + + @pytest.mark.asyncio + async def test_multi_schema_handling(self, processor_with_mocks): + """Test handling multiple schemas stored in unified table""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Configure multiple schemas + config = { + "schema": { + "products": json.dumps({ + "name": "products", + "fields": [ + {"name": "product_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string"}, + {"name": "price", "type": "float"} + ] + }), + "orders": json.dumps({ + "name": "orders", + "fields": [ + {"name": "order_id", "type": "string", "primary_key": True}, + {"name": "customer_id", "type": "string"}, + {"name": "total", "type": "float"} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + assert len(processor.schemas) == 2 + + # Process objects for different schemas + product_obj = ExtractedObject( + metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]), + schema_name="products", + values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], + confidence=0.9, + source_span="Product..." 
+ ) + + order_obj = ExtractedObject( + metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), + schema_name="orders", + values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], + confidence=0.85, + source_span="Order..." + ) + + # Process both objects + for obj in [product_obj, order_obj]: + msg = MagicMock() + msg.value.return_value = obj + await processor.on_object(msg, None, None) + + # All data goes into the same unified rows table + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + # Should only create 2 tables: rows + row_partitions (not per-schema tables) + assert len(table_calls) == 2 + + # Verify data inserts go to unified rows table + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) > 0 + for call in rows_insert_calls: + assert ".rows" in str(call) + + @pytest.mark.asyncio + async def test_multi_index_storage(self, processor_with_mocks): + """Test that rows are stored with multiple indexes""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Schema with multiple indexed fields + processor.schemas["indexed_data"] = RowSchema( + name="indexed_data", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True), + Field(name="status", type="string", size=50, indexed=True), + Field(name="description", type="string", size=200) # Not indexed + ] + ) + + test_obj = ExtractedObject( + metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + schema_name="indexed_data", + values=[{ + "id": "123", + "category": "electronics", + "status": "active", + "description": "A product" + }], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Should have 3 data inserts (one per indexed field: id, category, status) + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 3 + + # Verify different index names were used + index_names = set() + for call in rows_insert_calls: + values = call[0][1] + index_names.add(values[2]) # index_name is 3rd parameter + + assert index_names == {"id", "category", "status"} + + @pytest.mark.asyncio + async def test_authentication_handling(self, processor_with_mocks): + """Test Cassandra authentication""" + processor, mock_cluster, mock_session = processor_with_mocks + processor.cassandra_username = "cassandra_user" + processor.cassandra_password = "cassandra_pass" + + with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class: + with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth: + mock_cluster_class.return_value = mock_cluster + + # Trigger connection + processor.connect_cassandra() + + # Verify authentication was configured + mock_auth.assert_called_once_with( + username="cassandra_user", + password="cassandra_pass" + ) + mock_cluster_class.assert_called_once() + call_kwargs = mock_cluster_class.call_args[1] + assert 'auth_provider' in call_kwargs + + @pytest.mark.asyncio + async def test_batch_object_processing(self, 
processor_with_mocks): + """Test processing objects with batched values""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Configure schema + config = { + "schema": { + "batch_customers": json.dumps({ + "name": "batch_customers", + "description": "Customer batch data", + "fields": [ + {"name": "customer_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string", "required": True}, + {"name": "email", "type": "string", "indexed": True} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + + # Process batch object with multiple values + batch_obj = ExtractedObject( + metadata=Metadata( + id="batch-001", + user="test_user", + collection="batch_import", + metadata=[] + ), + schema_name="batch_customers", + values=[ + { + "customer_id": "CUST001", + "name": "John Doe", + "email": "john@example.com" + }, + { + "customer_id": "CUST002", + "name": "Jane Smith", + "email": "jane@example.com" + }, + { + "customer_id": "CUST003", + "name": "Bob Johnson", + "email": "bob@example.com" + } + ], + confidence=0.92, + source_span="Multiple customers extracted from document" + ) + + msg = MagicMock() + msg.value.return_value = batch_obj + + await processor.on_object(msg, None, None) + + # Verify unified table creation + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 2 # rows + row_partitions + + # Each row in batch gets 2 data inserts (customer_id primary + email indexed) + # 3 rows * 2 indexes = 6 data inserts + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 6 + + @pytest.mark.asyncio + async def test_empty_batch_processing(self, processor_with_mocks): + """Test processing objects with empty values array""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["empty_test"] = RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + + # Process empty batch object + empty_obj = ExtractedObject( + metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), + schema_name="empty_test", + values=[], # Empty batch + confidence=1.0, + source_span="No objects found" + ) + + msg = MagicMock() + msg.value.return_value = empty_obj + + await processor.on_object(msg, None, None) + + # Should not create any data insert statements for empty batch + # (partition registration may still happen) + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 0 + + @pytest.mark.asyncio + async def test_data_stored_as_map(self, processor_with_mocks): + """Test that data is stored as map""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["map_test"] = RowSchema( + name="map_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="count", type="integer", size=0) + ] + ) + + test_obj = ExtractedObject( + 
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + schema_name="map_test", + values=[{"id": "123", "name": "Test Item", "count": "42"}], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify insert uses map for data + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) >= 1 + + # Check that data is passed as a dict (will be map in Cassandra) + insert_call = rows_insert_calls[0] + values = insert_call[0][1] + # Values are: (collection, schema_name, index_name, index_value, data, source) + # values[4] should be the data map + data_map = values[4] + assert isinstance(data_map, dict) + assert data_map["id"] == "123" + assert data_map["name"] == "Test Item" + assert data_map["count"] == "42" + + @pytest.mark.asyncio + async def test_partition_registration(self, processor_with_mocks): + """Test that partitions are registered for efficient querying""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["partition_test"] = RowSchema( + name="partition_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True) + ] + ) + + test_obj = ExtractedObject( + metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]), + schema_name="partition_test", + values=[{"id": "123", "category": "test"}], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify partition registration + partition_inserts = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and "row_partitions" in str(call)] + # Should register partitions for each index (id, category) + assert len(partition_inserts) == 2 + + # Verify cache was updated + assert ("my_collection", "partition_test") in processor.registered_partitions diff --git a/tests/integration/test_objects_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py similarity index 98% rename from tests/integration/test_objects_graphql_query_integration.py rename to tests/integration/test_rows_graphql_query_integration.py index 13b12532..a717901b 100644 --- a/tests/integration/test_objects_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -1,5 +1,5 @@ """ -Integration tests for Objects GraphQL Query Service +Integration tests for Rows GraphQL Query Service These tests verify end-to-end functionality including: - Real Cassandra database operations @@ -24,8 +24,8 @@ except Exception: DOCKER_AVAILABLE = False CassandraContainer = None -from trustgraph.query.objects.cassandra.service import Processor -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError +from trustgraph.query.rows.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata @@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration: processor.connect_cassandra() # Create mock message - request = ObjectsQueryRequest( + request = RowsQueryRequest( 
user="msg_test_user", collection="msg_test_collection", query='{ customer_objects { customer_id name } }', @@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration: # Verify response structure sent_response = mock_response_producer.send.call_args[0][0] - assert isinstance(sent_response, ObjectsQueryResponse) + assert isinstance(sent_response, RowsQueryResponse) # Should have no system error (even if no data) assert sent_response.error is None diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index cf8037d0..d5fb5672 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -2,7 +2,7 @@ Integration tests for Structured Query Service These tests verify the end-to-end functionality of the structured query service, -testing orchestration between nlp-query and objects-query services. +testing orchestration between nlp-query and rows-query services. Following the TEST_STRATEGY.md approach for integration testing. """ @@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.schema import ( StructuredQueryRequest, StructuredQueryResponse, QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, - ObjectsQueryRequest, ObjectsQueryResponse, + RowsQueryRequest, RowsQueryResponse, Error, GraphQLError ) from trustgraph.retrieval.structured_query.service import Processor @@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects Query Service Response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}', errors=None, @@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration: # Verify Objects service call mock_objects_client.request.assert_called_once() objects_call_args = mock_objects_client.request.call_args[0][0] - assert isinstance(objects_call_args, ObjectsQueryRequest) + assert isinstance(objects_call_args, RowsQueryRequest) assert "customers" in objects_call_args.query assert "orders" in objects_call_args.query assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string @@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects service failure - objects_error_response = ObjectsQueryResponse( + objects_error_response = RowsQueryResponse( error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"), data=None, errors=None, @@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration: assert response.error is not None assert response.error.type == "structured-query-error" - assert "Objects query service error" in response.error.message + assert "Rows 
query service error" in response.error.message assert "nonexistent_table" in response.error.message @pytest.mark.asyncio @@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration: ) ] - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, # No data when validation fails errors=validation_errors, @@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration: ] } - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=json.dumps(complex_data), errors=None, @@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock empty Objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": []}', # Empty result set errors=None, @@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration: confidence=0.9 ) - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=f'{{"test_{i}": [{{"id": "{i}"}}]}}', errors=None, @@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration: if service_name == "nlp-query-request": service_call_count += 1 return nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": service_call_count += 1 return objects_client elif service_name == "response": @@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}', errors=None, @@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response diff --git a/tests/pytest.ini b/tests/pytest.ini index b763299c..b032a9d4 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -19,4 +19,5 @@ markers = integration: marks tests as integration tests unit: marks tests as unit tests contract: marks tests as contract tests (service interface validation) - vertexai: marks tests as vertex ai specific tests \ No newline at end of file + vertexai: marks tests as vertex ai specific tests + asyncio: marks tests that use asyncio \ No newline at end of file diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py 
b/tests/unit/test_base/test_subscriber_graceful_shutdown.py index ea5d04cc..0587e3d6 100644 --- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -88,8 +88,13 @@ async def test_subscriber_deferred_acknowledgment_success(): @pytest.mark.asyncio -async def test_subscriber_deferred_acknowledgment_failure(): - """Verify Subscriber negative acks on delivery failure.""" +async def test_subscriber_dropped_message_still_acks(): + """Verify Subscriber acks even when message is dropped (backpressure). + + This prevents redelivery storms on shared topics - if we negative_ack + a dropped message, it gets redelivered to all subscribers, none of + whom can handle it either, causing a tight redelivery loop. + """ mock_backend = MagicMock() mock_consumer = MagicMock() mock_backend.create_consumer.return_value = mock_consumer @@ -103,24 +108,66 @@ async def test_subscriber_deferred_acknowledgment_failure(): max_size=1, # Very small queue backpressure_strategy="drop_new" ) - + # Start subscriber to initialize consumer await subscriber.start() - + # Create queue and fill it queue = await subscriber.subscribe("test-queue") await queue.put({"existing": "data"}) - - # Create mock message - should be dropped - msg = create_mock_message("msg-1", {"data": "test"}) - - # Process message (should fail due to full queue + drop_new strategy) + + # Create mock message - should be dropped due to full queue + msg = create_mock_message("test-queue", {"data": "test"}) + + # Process message (should be dropped due to full queue + drop_new strategy) await subscriber._process_message(msg) - - # Should negative acknowledge failed delivery - mock_consumer.negative_acknowledge.assert_called_once_with(msg) - mock_consumer.acknowledge.assert_not_called() - + + # Should acknowledge even though delivery failed - prevents redelivery storm + mock_consumer.acknowledge.assert_called_once_with(msg) + mock_consumer.negative_acknowledge.assert_not_called() + + # Clean up + await subscriber.stop() + + +@pytest.mark.asyncio +async def test_subscriber_orphaned_message_acks(): + """Verify Subscriber acks orphaned messages (no matching waiter). + + On shared response topics, if a message arrives for a waiter that + no longer exists (e.g., client disconnected, request timed out), + we must acknowledge it to prevent redelivery storms. 
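+
+ (Negative-acking here would behave like the dropped-message case above:
+ the broker would redeliver the message to every subscriber on the shared
+ subscription, none of which has a matching waiter.)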
+ """ + mock_backend = MagicMock() + mock_consumer = MagicMock() + mock_backend.create_consumer.return_value = mock_consumer + + subscriber = Subscriber( + backend=mock_backend, + topic="test-topic", + subscription="test-subscription", + consumer_name="test-consumer", + schema=dict, + max_size=10, + backpressure_strategy="block" + ) + + # Start subscriber to initialize consumer + await subscriber.start() + + # Don't create any queues - message will be orphaned + # This simulates a response arriving after the waiter has unsubscribed + + # Create mock message with an ID that has no matching waiter + msg = create_mock_message("non-existent-waiter-id", {"data": "orphaned"}) + + # Process message (should be orphaned - no matching waiter) + await subscriber._process_message(msg) + + # Should acknowledge orphaned message - prevents redelivery storm + mock_consumer.acknowledge.assert_called_once_with(msg) + mock_consumer.negative_acknowledge.assert_not_called() + # Clean up await subscriber.stop() diff --git a/tests/unit/test_cli/test_tool_commands.py b/tests/unit/test_cli/test_tool_commands.py index 913fe416..9c204614 100644 --- a/tests/unit/test_cli/test_tool_commands.py +++ b/tests/unit/test_cli/test_tool_commands.py @@ -55,6 +55,9 @@ class TestSetToolStructuredQuery: mcp_tool=None, collection="sales_data", template=None, + schema_name=None, + index_name=None, + limit=None, arguments=[], group=None, state=None, @@ -92,6 +95,9 @@ class TestSetToolStructuredQuery: mcp_tool=None, collection=None, # No collection specified template=None, + schema_name=None, + index_name=None, + limit=None, arguments=[], group=None, state=None, @@ -132,6 +138,9 @@ class TestSetToolStructuredQuery: mcp_tool=None, collection='sales_data', template=None, + schema_name=None, + index_name=None, + limit=None, arguments=[], group=None, state=None, @@ -201,6 +210,144 @@ class TestSetToolStructuredQuery: assert 'Exception:' in printed_output or 'invalid choice:' in printed_output.lower() +class TestSetToolRowEmbeddingsQuery: + """Test the set_tool function with row-embeddings-query type.""" + + @patch('trustgraph.cli.set_tool.Api') + def test_set_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys): + """Test setting a row-embeddings-query tool with all parameters.""" + mock_api_class.return_value, mock_config = mock_api + mock_config.get.return_value = [] + + set_tool( + url="http://test.com", + id="customer_search", + name="find_customer", + description="Find customers by name using semantic search", + type="row-embeddings-query", + mcp_tool=None, + collection="sales", + template=None, + schema_name="customers", + index_name="full_name", + limit=20, + arguments=[], + group=None, + state=None, + applicable_states=None + ) + + captured = capsys.readouterr() + assert "Tool set." 
in captured.out + + # Verify the tool was stored correctly + call_args = mock_config.put.call_args[0][0] + assert len(call_args) == 1 + config_value = call_args[0] + assert config_value.type == "tool" + assert config_value.key == "customer_search" + + stored_tool = json.loads(config_value.value) + assert stored_tool["name"] == "find_customer" + assert stored_tool["type"] == "row-embeddings-query" + assert stored_tool["collection"] == "sales" + assert stored_tool["schema-name"] == "customers" + assert stored_tool["index-name"] == "full_name" + assert stored_tool["limit"] == 20 + + @patch('trustgraph.cli.set_tool.Api') + def test_set_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys): + """Test setting row-embeddings-query tool with minimal parameters.""" + mock_api_class.return_value, mock_config = mock_api + mock_config.get.return_value = [] + + set_tool( + url="http://test.com", + id="product_search", + name="find_product", + description="Find products using semantic search", + type="row-embeddings-query", + mcp_tool=None, + collection=None, + template=None, + schema_name="products", + index_name=None, # No index filter + limit=None, # Use default + arguments=[], + group=None, + state=None, + applicable_states=None + ) + + captured = capsys.readouterr() + assert "Tool set." in captured.out + + call_args = mock_config.put.call_args[0][0] + stored_tool = json.loads(call_args[0].value) + assert stored_tool["type"] == "row-embeddings-query" + assert stored_tool["schema-name"] == "products" + assert "index-name" not in stored_tool # Should not be included if None + assert "limit" not in stored_tool # Should not be included if None + assert "collection" not in stored_tool # Should not be included if None + + def test_set_main_row_embeddings_query_with_all_options(self): + """Test set main() with row-embeddings-query tool type and all options.""" + test_args = [ + 'tg-set-tool', + '--id', 'customer_search', + '--name', 'find_customer', + '--type', 'row-embeddings-query', + '--description', 'Find customers by name', + '--schema-name', 'customers', + '--collection', 'sales', + '--index-name', 'full_name', + '--limit', '25', + '--api-url', 'http://custom.com' + ] + + with patch('sys.argv', test_args), \ + patch('trustgraph.cli.set_tool.set_tool') as mock_set: + + set_main() + + mock_set.assert_called_once_with( + url='http://custom.com', + id='customer_search', + name='find_customer', + description='Find customers by name', + type='row-embeddings-query', + mcp_tool=None, + collection='sales', + template=None, + schema_name='customers', + index_name='full_name', + limit=25, + arguments=[], + group=None, + state=None, + applicable_states=None, + token=None + ) + + def test_valid_types_includes_row_embeddings_query(self): + """Test that 'row-embeddings-query' is included in valid tool types.""" + test_args = [ + 'tg-set-tool', + '--id', 'test_tool', + '--name', 'test_tool', + '--type', 'row-embeddings-query', + '--description', 'Test tool', + '--schema-name', 'test_schema' + ] + + with patch('sys.argv', test_args), \ + patch('trustgraph.cli.set_tool.set_tool') as mock_set: + + # Should not raise an exception about invalid type + set_main() + mock_set.assert_called_once() + + class TestShowToolsStructuredQuery: """Test the show_tools function with structured-query tools.""" @@ -259,9 +406,9 @@ class TestShowToolsStructuredQuery: @patch('trustgraph.cli.show_tools.Api') def test_show_mixed_tool_types(self, mock_api_class, mock_api, capsys): - """Test displaying multiple tool types 
including structured-query.""" + """Test displaying multiple tool types including structured-query and row-embeddings-query.""" mock_api_class.return_value, mock_config = mock_api - + tools = [ { "name": "ask_knowledge", @@ -270,37 +417,47 @@ class TestShowToolsStructuredQuery: "collection": "docs" }, { - "name": "query_data", + "name": "query_data", "description": "Query structured data", "type": "structured-query", "collection": "sales" }, + { + "name": "find_customer", + "description": "Find customers by semantic search", + "type": "row-embeddings-query", + "schema-name": "customers", + "collection": "crm" + }, { "name": "complete_text", "description": "Generate text", "type": "text-completion" } ] - + config_values = [ ConfigValue(type="tool", key=f"tool_{i}", value=json.dumps(tool)) for i, tool in enumerate(tools) ] mock_config.get_values.return_value = config_values - + show_config("http://test.com") - + captured = capsys.readouterr() output = captured.out - + # All tool types should be displayed assert "knowledge-query" in output - assert "structured-query" in output + assert "structured-query" in output + assert "row-embeddings-query" in output assert "text-completion" in output - + # Collections should be shown for appropriate tools assert "docs" in output # knowledge-query collection assert "sales" in output # structured-query collection + assert "crm" in output # row-embeddings-query collection + assert "customers" in output # row-embeddings-query schema-name def test_show_main_parses_args_correctly(self): """Test that show main() parses arguments correctly.""" @@ -317,6 +474,76 @@ class TestShowToolsStructuredQuery: mock_show.assert_called_once_with(url='http://custom.com', token=None) +class TestShowToolsRowEmbeddingsQuery: + """Test the show_tools function with row-embeddings-query tools.""" + + @patch('trustgraph.cli.show_tools.Api') + def test_show_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys): + """Test displaying a row-embeddings-query tool with all fields.""" + mock_api_class.return_value, mock_config = mock_api + + tool_config = { + "name": "find_customer", + "description": "Find customers by name using semantic search", + "type": "row-embeddings-query", + "collection": "sales", + "schema-name": "customers", + "index-name": "full_name", + "limit": 20 + } + + config_value = ConfigValue( + type="tool", + key="customer_search", + value=json.dumps(tool_config) + ) + mock_config.get_values.return_value = [config_value] + + show_config("http://test.com") + + captured = capsys.readouterr() + output = captured.out + + # Check that tool information is displayed + assert "customer_search" in output + assert "find_customer" in output + assert "row-embeddings-query" in output + assert "sales" in output # Collection + assert "customers" in output # Schema name + assert "full_name" in output # Index name + assert "20" in output # Limit + + @patch('trustgraph.cli.show_tools.Api') + def test_show_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys): + """Test displaying row-embeddings-query tool with minimal fields.""" + mock_api_class.return_value, mock_config = mock_api + + tool_config = { + "name": "find_product", + "description": "Find products using semantic search", + "type": "row-embeddings-query", + "schema-name": "products" + # No collection, index-name, or limit + } + + config_value = ConfigValue( + type="tool", + key="product_search", + value=json.dumps(tool_config) + ) + mock_config.get_values.return_value = [config_value] + + 
show_config("http://test.com") + + captured = capsys.readouterr() + output = captured.out + + # Should display the tool with schema-name + assert "product_search" in output + assert "row-embeddings-query" in output + assert "products" in output # Schema name + + class TestStructuredQueryToolValidation: """Test validation specific to structured-query tools.""" diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index e0ad9339..96c9c427 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock from unittest.mock import call from trustgraph.cores.knowledge import KnowledgeManager -from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Value, EntityEmbeddings +from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL @pytest.fixture @@ -71,15 +71,15 @@ def sample_triples(): return Triples( metadata=Metadata( id="test-doc-id", - user="test-user", + user="test-user", collection="default", # This should be overridden metadata=[] ), triples=[ Triple( - s=Value(value="http://example.org/john", is_uri=True), - p=Value(value="http://example.org/name", is_uri=True), - o=Value(value="John Smith", is_uri=False) + s=Term(type=IRI, iri="http://example.org/john"), + p=Term(type=IRI, iri="http://example.org/name"), + o=Term(type=LITERAL, value="John Smith") ) ] ) @@ -97,7 +97,7 @@ def sample_graph_embeddings(): ), entities=[ EntityEmbeddings( - entity=Value(value="http://example.org/john", is_uri=True), + entity=Term(type=IRI, iri="http://example.org/john"), vectors=[[0.1, 0.2, 0.3]] ) ] diff --git a/tests/unit/test_direct/test_entity_centric_kg.py b/tests/unit/test_direct/test_entity_centric_kg.py new file mode 100644 index 00000000..5f64b581 --- /dev/null +++ b/tests/unit/test_direct/test_entity_centric_kg.py @@ -0,0 +1,599 @@ +""" +Unit tests for EntityCentricKnowledgeGraph class + +Tests the entity-centric knowledge graph implementation without requiring +an actual Cassandra connection. Uses mocking to verify correct behavior. 
+""" + +import pytest +from unittest.mock import MagicMock, patch, call +import os + + +class TestEntityCentricKnowledgeGraph: + """Test cases for EntityCentricKnowledgeGraph""" + + @pytest.fixture + def mock_cluster(self): + """Create a mock Cassandra cluster""" + with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cluster_cls: + mock_cluster = MagicMock() + mock_session = MagicMock() + mock_cluster.connect.return_value = mock_session + mock_cluster_cls.return_value = mock_cluster + yield mock_cluster_cls, mock_cluster, mock_session + + @pytest.fixture + def entity_kg(self, mock_cluster): + """Create an EntityCentricKnowledgeGraph instance with mocked Cassandra""" + from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph + mock_cluster_cls, mock_cluster, mock_session = mock_cluster + + # Create instance + kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace') + return kg, mock_session + + def test_init_creates_entity_centric_schema(self, mock_cluster): + """Test that initialization creates the 2-table entity-centric schema""" + from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph + mock_cluster_cls, mock_cluster, mock_session = mock_cluster + + kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace') + + # Verify schema tables were created + execute_calls = mock_session.execute.call_args_list + executed_statements = [str(c) for c in execute_calls] + + # Check for keyspace creation + keyspace_created = any('create keyspace' in str(c).lower() for c in execute_calls) + assert keyspace_created + + # Check for quads_by_entity table + entity_table_created = any('quads_by_entity' in str(c) for c in execute_calls) + assert entity_table_created + + # Check for quads_by_collection table + collection_table_created = any('quads_by_collection' in str(c) for c in execute_calls) + assert collection_table_created + + # Check for collection_metadata table + metadata_table_created = any('collection_metadata' in str(c) for c in execute_calls) + assert metadata_table_created + + def test_prepare_statements_initialized(self, entity_kg): + """Test that prepared statements are initialized""" + kg, mock_session = entity_kg + + # Verify prepare was called for various statements + assert mock_session.prepare.called + prepare_calls = mock_session.prepare.call_args_list + + # Check that key prepared statements exist + prepared_queries = [str(c) for c in prepare_calls] + + # Insert statements + insert_entity_stmt = any('INSERT INTO' in str(c) and 'quads_by_entity' in str(c) + for c in prepare_calls) + assert insert_entity_stmt + + insert_collection_stmt = any('INSERT INTO' in str(c) and 'quads_by_collection' in str(c) + for c in prepare_calls) + assert insert_collection_stmt + + def test_insert_uri_object_creates_4_entity_rows(self, entity_kg): + """Test that inserting a quad with URI object creates 4 entity rows""" + kg, mock_session = entity_kg + + # Reset mocks to track only insert-related calls + mock_session.reset_mock() + + kg.insert( + collection='test_collection', + s='http://example.org/Alice', + p='http://example.org/knows', + o='http://example.org/Bob', + g='http://example.org/graph1', + otype='u' + ) + + # Verify batch was executed + mock_session.execute.assert_called() + + def test_insert_literal_object_creates_3_entity_rows(self, entity_kg): + """Test that inserting a quad with literal object creates 3 entity rows""" + kg, mock_session = entity_kg + + mock_session.reset_mock() + + kg.insert( + 
collection='test_collection', + s='http://example.org/Alice', + p='http://www.w3.org/2000/01/rdf-schema#label', + o='Alice Smith', + g=None, + otype='l', + dtype='xsd:string', + lang='en' + ) + + # Verify batch was executed + mock_session.execute.assert_called() + + def test_insert_default_graph(self, entity_kg): + """Test that None graph is stored as empty string""" + kg, mock_session = entity_kg + + mock_session.reset_mock() + + kg.insert( + collection='test_collection', + s='http://example.org/Alice', + p='http://example.org/knows', + o='http://example.org/Bob', + g=None, + otype='u' + ) + + mock_session.execute.assert_called() + + def test_insert_auto_detects_otype(self, entity_kg): + """Test that otype is auto-detected when not provided""" + kg, mock_session = entity_kg + + mock_session.reset_mock() + + # URI should be auto-detected + kg.insert( + collection='test_collection', + s='http://example.org/Alice', + p='http://example.org/knows', + o='http://example.org/Bob' + ) + mock_session.execute.assert_called() + + mock_session.reset_mock() + + # Literal should be auto-detected + kg.insert( + collection='test_collection', + s='http://example.org/Alice', + p='http://example.org/name', + o='Alice' + ) + mock_session.execute.assert_called() + + def test_get_s_returns_quads_for_subject(self, entity_kg): + """Test get_s queries by subject""" + kg, mock_session = entity_kg + + # Mock the query result + mock_result = [ + MagicMock(p='http://example.org/knows', o='http://example.org/Bob', + d='', otype='u', dtype='', lang='', s='http://example.org/Alice') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_s('test_collection', 'http://example.org/Alice') + + # Verify query was executed + mock_session.execute.assert_called() + + # Results should be QuadResult objects + assert len(results) == 1 + assert results[0].s == 'http://example.org/Alice' + assert results[0].p == 'http://example.org/knows' + assert results[0].o == 'http://example.org/Bob' + + def test_get_p_returns_quads_for_predicate(self, entity_kg): + """Test get_p queries by predicate""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(s='http://example.org/Alice', o='http://example.org/Bob', + d='', otype='u', dtype='', lang='', p='http://example.org/knows') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_p('test_collection', 'http://example.org/knows') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_o_returns_quads_for_object(self, entity_kg): + """Test get_o queries by object""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(s='http://example.org/Alice', p='http://example.org/knows', + d='', otype='u', dtype='', lang='', o='http://example.org/Bob') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_o('test_collection', 'http://example.org/Bob') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_sp_returns_quads_for_subject_predicate(self, entity_kg): + """Test get_sp queries by subject and predicate""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(o='http://example.org/Bob', d='', otype='u', dtype='', lang='') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_sp('test_collection', 'http://example.org/Alice', + 'http://example.org/knows') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_po_returns_quads_for_predicate_object(self, entity_kg): + """Test get_po queries by predicate and 
object""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(s='http://example.org/Alice', d='', otype='u', dtype='', lang='', + o='http://example.org/Bob') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_po('test_collection', 'http://example.org/knows', + 'http://example.org/Bob') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_os_returns_quads_for_object_subject(self, entity_kg): + """Test get_os queries by object and subject""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(p='http://example.org/knows', d='', otype='u', dtype='', lang='', + s='http://example.org/Alice', o='http://example.org/Bob') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_os('test_collection', 'http://example.org/Bob', + 'http://example.org/Alice') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_spo_returns_quads_for_subject_predicate_object(self, entity_kg): + """Test get_spo queries by subject, predicate, and object""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(d='', otype='u', dtype='', lang='', + o='http://example.org/Bob') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_spo('test_collection', 'http://example.org/Alice', + 'http://example.org/knows', 'http://example.org/Bob') + + mock_session.execute.assert_called() + assert len(results) == 1 + + def test_get_g_returns_quads_for_graph(self, entity_kg): + """Test get_g queries by graph""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(s='http://example.org/Alice', p='http://example.org/knows', + o='http://example.org/Bob', otype='u', dtype='', lang='') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_g('test_collection', 'http://example.org/graph1') + + mock_session.execute.assert_called() + + def test_get_all_returns_all_quads_in_collection(self, entity_kg): + """Test get_all returns all quads""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows', + o='http://example.org/Bob', otype='u', dtype='', lang='') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_all('test_collection') + + mock_session.execute.assert_called() + + def test_graph_wildcard_returns_all_graphs(self, entity_kg): + """Test that g='*' returns quads from all graphs""" + from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(p='http://example.org/knows', d='http://example.org/graph1', + otype='u', dtype='', lang='', s='http://example.org/Alice', + o='http://example.org/Bob'), + MagicMock(p='http://example.org/knows', d='http://example.org/graph2', + otype='u', dtype='', lang='', s='http://example.org/Alice', + o='http://example.org/Charlie') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD) + + # Should return quads from both graphs + assert len(results) == 2 + + def test_specific_graph_filters_results(self, entity_kg): + """Test that specifying a graph filters results""" + kg, mock_session = entity_kg + + mock_result = [ + MagicMock(p='http://example.org/knows', d='http://example.org/graph1', + otype='u', dtype='', lang='', s='http://example.org/Alice', + o='http://example.org/Bob'), + MagicMock(p='http://example.org/knows', d='http://example.org/graph2', + otype='u', dtype='', lang='', 
s='http://example.org/Alice', + o='http://example.org/Charlie') + ] + mock_session.execute.return_value = mock_result + + results = kg.get_s('test_collection', 'http://example.org/Alice', + g='http://example.org/graph1') + + # Should only return quads from graph1 + assert len(results) == 1 + assert results[0].g == 'http://example.org/graph1' + + def test_collection_exists_returns_true_when_exists(self, entity_kg): + """Test collection_exists returns True for existing collection""" + kg, mock_session = entity_kg + + mock_result = [MagicMock(collection='test_collection')] + mock_session.execute.return_value = mock_result + + exists = kg.collection_exists('test_collection') + + assert exists is True + + def test_collection_exists_returns_false_when_not_exists(self, entity_kg): + """Test collection_exists returns False for non-existing collection""" + kg, mock_session = entity_kg + + mock_session.execute.return_value = [] + + exists = kg.collection_exists('nonexistent_collection') + + assert exists is False + + def test_create_collection_inserts_metadata(self, entity_kg): + """Test create_collection inserts metadata row""" + kg, mock_session = entity_kg + + mock_session.reset_mock() + kg.create_collection('test_collection') + + # Verify INSERT was executed for collection_metadata + mock_session.execute.assert_called() + + def test_delete_collection_removes_all_data(self, entity_kg): + """Test delete_collection removes entity partitions and collection rows""" + kg, mock_session = entity_kg + + # Mock reading quads from collection + mock_quads = [ + MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows', + o='http://example.org/Bob', otype='u') + ] + mock_session.execute.return_value = mock_quads + + mock_session.reset_mock() + kg.delete_collection('test_collection') + + # Verify delete operations were executed + assert mock_session.execute.called + + def test_close_shuts_down_connections(self, entity_kg): + """Test close shuts down session and cluster""" + kg, mock_session = entity_kg + + kg.close() + + mock_session.shutdown.assert_called_once() + kg.cluster.shutdown.assert_called_once() + + +class TestQuadResult: + """Test cases for QuadResult class""" + + def test_quad_result_stores_all_fields(self): + """Test QuadResult stores all quad fields""" + from trustgraph.direct.cassandra_kg import QuadResult + + result = QuadResult( + s='http://example.org/Alice', + p='http://example.org/knows', + o='http://example.org/Bob', + g='http://example.org/graph1', + otype='u', + dtype='', + lang='' + ) + + assert result.s == 'http://example.org/Alice' + assert result.p == 'http://example.org/knows' + assert result.o == 'http://example.org/Bob' + assert result.g == 'http://example.org/graph1' + assert result.otype == 'u' + assert result.dtype == '' + assert result.lang == '' + + def test_quad_result_defaults(self): + """Test QuadResult default values""" + from trustgraph.direct.cassandra_kg import QuadResult + + result = QuadResult( + s='http://example.org/s', + p='http://example.org/p', + o='literal value', + g='' + ) + + assert result.otype == 'u' # Default otype + assert result.dtype == '' + assert result.lang == '' + + def test_quad_result_with_literal_metadata(self): + """Test QuadResult with literal metadata""" + from trustgraph.direct.cassandra_kg import QuadResult + + result = QuadResult( + s='http://example.org/Alice', + p='http://www.w3.org/2000/01/rdf-schema#label', + o='Alice Smith', + g='', + otype='l', + dtype='xsd:string', + lang='en' + ) + + assert result.otype == 'l' + 
assert result.dtype == 'xsd:string' + assert result.lang == 'en' + + +class TestWriteHelperFunctions: + """Test cases for helper functions in write.py""" + + def test_get_term_otype_for_iri(self): + """Test get_term_otype returns 'u' for IRI terms""" + from trustgraph.storage.triples.cassandra.write import get_term_otype + from trustgraph.schema import Term, IRI + + term = Term(type=IRI, iri='http://example.org/Alice') + assert get_term_otype(term) == 'u' + + def test_get_term_otype_for_literal(self): + """Test get_term_otype returns 'l' for LITERAL terms""" + from trustgraph.storage.triples.cassandra.write import get_term_otype + from trustgraph.schema import Term, LITERAL + + term = Term(type=LITERAL, value='Alice Smith') + assert get_term_otype(term) == 'l' + + def test_get_term_otype_for_blank(self): + """Test get_term_otype returns 'u' for BLANK terms""" + from trustgraph.storage.triples.cassandra.write import get_term_otype + from trustgraph.schema import Term, BLANK + + term = Term(type=BLANK, id='_:b1') + assert get_term_otype(term) == 'u' + + def test_get_term_otype_for_triple(self): + """Test get_term_otype returns 't' for TRIPLE terms""" + from trustgraph.storage.triples.cassandra.write import get_term_otype + from trustgraph.schema import Term, TRIPLE + + term = Term(type=TRIPLE) + assert get_term_otype(term) == 't' + + def test_get_term_otype_for_none(self): + """Test get_term_otype returns 'u' for None""" + from trustgraph.storage.triples.cassandra.write import get_term_otype + + assert get_term_otype(None) == 'u' + + def test_get_term_dtype_for_literal(self): + """Test get_term_dtype extracts datatype from LITERAL""" + from trustgraph.storage.triples.cassandra.write import get_term_dtype + from trustgraph.schema import Term, LITERAL + + term = Term(type=LITERAL, value='42', datatype='xsd:integer') + assert get_term_dtype(term) == 'xsd:integer' + + def test_get_term_dtype_for_non_literal(self): + """Test get_term_dtype returns empty string for non-LITERAL""" + from trustgraph.storage.triples.cassandra.write import get_term_dtype + from trustgraph.schema import Term, IRI + + term = Term(type=IRI, iri='http://example.org/Alice') + assert get_term_dtype(term) == '' + + def test_get_term_dtype_for_none(self): + """Test get_term_dtype returns empty string for None""" + from trustgraph.storage.triples.cassandra.write import get_term_dtype + + assert get_term_dtype(None) == '' + + def test_get_term_lang_for_literal(self): + """Test get_term_lang extracts language from LITERAL""" + from trustgraph.storage.triples.cassandra.write import get_term_lang + from trustgraph.schema import Term, LITERAL + + term = Term(type=LITERAL, value='Alice Smith', language='en') + assert get_term_lang(term) == 'en' + + def test_get_term_lang_for_non_literal(self): + """Test get_term_lang returns empty string for non-LITERAL""" + from trustgraph.storage.triples.cassandra.write import get_term_lang + from trustgraph.schema import Term, IRI + + term = Term(type=IRI, iri='http://example.org/Alice') + assert get_term_lang(term) == '' + + +class TestServiceHelperFunctions: + """Test cases for helper functions in service.py""" + + def test_create_term_with_uri_otype(self): + """Test create_term creates IRI Term for otype='u'""" + from trustgraph.query.triples.cassandra.service import create_term + from trustgraph.schema import IRI + + term = create_term('http://example.org/Alice', otype='u') + + assert term.type == IRI + assert term.iri == 'http://example.org/Alice' + + def 
test_create_term_with_literal_otype(self): + """Test create_term creates LITERAL Term for otype='l'""" + from trustgraph.query.triples.cassandra.service import create_term + from trustgraph.schema import LITERAL + + term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en') + + assert term.type == LITERAL + assert term.value == 'Alice Smith' + assert term.datatype == 'xsd:string' + assert term.language == 'en' + + def test_create_term_with_triple_otype(self): + """Test create_term creates IRI Term for otype='t'""" + from trustgraph.query.triples.cassandra.service import create_term + from trustgraph.schema import IRI + + term = create_term('http://example.org/statement1', otype='t') + + assert term.type == IRI + assert term.iri == 'http://example.org/statement1' + + def test_create_term_heuristic_fallback_uri(self): + """Test create_term uses URL heuristic when otype not provided""" + from trustgraph.query.triples.cassandra.service import create_term + from trustgraph.schema import IRI + + term = create_term('http://example.org/Alice') + + assert term.type == IRI + assert term.iri == 'http://example.org/Alice' + + def test_create_term_heuristic_fallback_literal(self): + """Test create_term uses literal heuristic when otype not provided""" + from trustgraph.query.triples.cassandra.service import create_term + from trustgraph.schema import LITERAL + + term = create_term('Alice Smith') + + assert term.type == LITERAL + assert term.value == 'Alice Smith' diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py new file mode 100644 index 00000000..47405431 --- /dev/null +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -0,0 +1,380 @@ +""" +Unit tests for trustgraph.embeddings.row_embeddings.embeddings +Tests the Stage 1 processor that computes embeddings for row index fields. 
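+
+Helper surface exercised here, as a sketch drawn from the tests below
+(names come from the Processor class under test):
+
+    processor = Processor(taskgroup=AsyncMock(), id='test-row-embeddings')
+    names = processor.get_index_names(schema)    # primary + indexed fields
+    text = processor.build_text_for_embedding(['John', 'Doe'])  # 'John Doe'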
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + + +class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): + """Test row embeddings processor functionality""" + + async def test_processor_initialization(self): + """Test basic processor initialization""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-row-embeddings' + } + + processor = Processor(**config) + + assert hasattr(processor, 'schemas') + assert processor.schemas == {} + assert processor.batch_size == 10 # default + + async def test_processor_initialization_with_custom_batch_size(self): + """Test processor initialization with custom batch size""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-row-embeddings', + 'batch_size': 25 + } + + processor = Processor(**config) + + assert processor.batch_size == 25 + + async def test_get_index_names_single_index(self): + """Test getting index names with single indexed field""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import RowSchema, Field + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + schema = RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + Field(name='email', type='text', indexed=False), + ] + ) + + index_names = processor.get_index_names(schema) + + # Should include primary key and indexed field + assert 'id' in index_names + assert 'name' in index_names + assert 'email' not in index_names + + async def test_get_index_names_no_indexes(self): + """Test getting index names when no fields are indexed""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import RowSchema, Field + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + schema = RowSchema( + name='logs', + description='Log records', + fields=[ + Field(name='timestamp', type='text'), + Field(name='message', type='text'), + ] + ) + + index_names = processor.get_index_names(schema) + + assert index_names == [] + + async def test_build_index_value_single_field(self): + """Test building index value for single field""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'id': 'CUST001', + 'name': 'John Doe', + 'email': 'john@example.com' + } + + result = processor.build_index_value(value_map, 'name') + + assert result == ['John Doe'] + + async def test_build_index_value_composite_index(self): + """Test building index value for composite index""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'first_name': 'John', + 'last_name': 'Doe', + 'city': 'New York' + } + + result = processor.build_index_value(value_map, 'first_name, last_name') + + assert result == ['John', 'Doe'] + + async def test_build_index_value_missing_field(self): + """Test building index value when field is missing""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + 
config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'name': 'John Doe' + } + + result = processor.build_index_value(value_map, 'missing_field') + + assert result == [''] + + async def test_build_text_for_embedding_single_value(self): + """Test building text representation for single value""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + result = processor.build_text_for_embedding(['John Doe']) + + assert result == 'John Doe' + + async def test_build_text_for_embedding_multiple_values(self): + """Test building text representation for multiple values""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + result = processor.build_text_for_embedding(['John', 'Doe', 'NYC']) + + assert result == 'John Doe NYC' + + async def test_on_schema_config_loads_schemas(self): + """Test that schema configuration is loaded correctly""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + import json + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + + schema_def = { + 'name': 'customers', + 'description': 'Customer records', + 'fields': [ + {'name': 'id', 'type': 'text', 'primary_key': True}, + {'name': 'name', 'type': 'text', 'indexed': True}, + {'name': 'email', 'type': 'text'} + ] + } + + config_data = { + 'schema': { + 'customers': json.dumps(schema_def) + } + } + + await processor.on_schema_config(config_data, 1) + + assert 'customers' in processor.schemas + assert processor.schemas['customers'].name == 'customers' + assert len(processor.schemas['customers'].fields) == 3 + + async def test_on_schema_config_handles_missing_type(self): + """Test that missing schema type is handled gracefully""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + + config_data = { + 'other_type': {} + } + + await processor.on_schema_config(config_data, 1) + + assert processor.schemas == {} + + async def test_on_message_drops_unknown_collection(self): + """Test that messages for unknown collections are dropped""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + # No collections registered + + metadata = MagicMock() + metadata.user = 'unknown_user' + metadata.collection = 'unknown_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='customers', + values=[{'id': '123', 'name': 'Test'}] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + mock_flow = MagicMock() + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Flow should not be called for output + mock_flow.assert_not_called() + + async def test_on_message_drops_unknown_schema(self): + """Test that messages for unknown schemas are dropped""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject + + config = { + 'taskgroup': AsyncMock(), + 'id': 
'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + # No schemas registered + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='unknown_schema', + values=[{'id': '123', 'name': 'Test'}] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + mock_flow = MagicMock() + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Flow should not be called for output + mock_flow.assert_not_called() + + async def test_on_message_processes_embeddings(self): + """Test processing a message and computing embeddings""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject, RowSchema, Field + import json + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + # Set up schema + processor.schemas['customers'] = RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + ] + ) + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='customers', + values=[ + {'id': 'CUST001', 'name': 'John Doe'}, + {'id': 'CUST002', 'name': 'Jane Smith'} + ] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + # Mock the flow + mock_embeddings_request = AsyncMock() + mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]] + + mock_output = AsyncMock() + + def flow_factory(name): + if name == 'embeddings-request': + return mock_embeddings_request + elif name == 'output': + return mock_output + return MagicMock() + + mock_flow = MagicMock(side_effect=flow_factory) + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Should have called embed for each unique text + # 4 values: CUST001, John Doe, CUST002, Jane Smith + assert mock_embeddings_request.embed.call_count == 4 + + # Should have sent output + mock_output.send.assert_called() + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_extract/test_ontology/test_entity_contexts.py b/tests/unit/test_extract/test_ontology/test_entity_contexts.py index c867b05a..fde24e58 100644 --- a/tests/unit/test_extract/test_ontology/test_entity_contexts.py +++ b/tests/unit/test_extract/test_ontology/test_entity_contexts.py @@ -7,7 +7,7 @@ collecting labels and definitions for entity embedding and retrieval. 
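Taken together, the rewritten assertions in this file pin down the contract of `build_entity_contexts`: only literal objects count, only the first `rdfs:label` per entity is used, every `skos:definition` or `schema.org/description` is kept, and entities with neither get no context. A minimal sketch consistent with those assertions follows; the predicate sets, helper internals, and the space separator are assumptions, not the shipped implementation.

```python
from trustgraph.schema.core.primitives import Term, IRI, LITERAL
from trustgraph.schema.knowledge.graph import EntityContext

# Assumed predicate groupings, inferred from the assertions in this module
LABEL_PREDICATES = {"http://www.w3.org/2000/01/rdf-schema#label"}
DEFINITION_PREDICATES = {
    "http://www.w3.org/2004/02/skos/core#definition",
    "https://schema.org/description",
}

def build_entity_contexts(triples):
    entities = {}  # subject IRI -> {"label": str | None, "defs": [str]}
    for t in triples:
        if t.o.type != LITERAL:
            continue                        # URI-valued objects are ignored
        slot = entities.setdefault(t.s.iri, {"label": None, "defs": []})
        if t.p.iri in LABEL_PREDICATES:
            if slot["label"] is None:       # only the first label is used
                slot["label"] = t.o.value
        elif t.p.iri in DEFINITION_PREDICATES:
            slot["defs"].append(t.o.value)  # all definitions are kept
    contexts = []
    for iri, slot in entities.items():
        if slot["label"] is None and not slot["defs"]:
            continue                        # no label/definition -> no context
        parts = ([f"Label: {slot['label']}"] if slot["label"] else []) + slot["defs"]
        # Joined with a space here; the tests only assert substring containment
        contexts.append(EntityContext(entity=Term(type=IRI, iri=iri),
                                      context=" ".join(parts)))
    return contexts
```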
import pytest from trustgraph.extract.kg.ontology.extract import Processor -from trustgraph.schema.core.primitives import Triple, Value +from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL from trustgraph.schema.knowledge.graph import EntityContext @@ -25,9 +25,9 @@ class TestEntityContextBuilding: """Test that entity context is built from rdfs:label.""" triples = [ Triple( - s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Cornish Pasty", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/cornish-pasty"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Cornish Pasty") ) ] @@ -35,16 +35,16 @@ class TestEntityContextBuilding: assert len(contexts) == 1, "Should create one entity context" assert isinstance(contexts[0], EntityContext) - assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty" + assert contexts[0].entity.iri == "https://example.com/entity/cornish-pasty" assert "Label: Cornish Pasty" in contexts[0].context def test_builds_context_from_definition(self, processor): """Test that entity context includes definitions.""" triples = [ Triple( - s=Value(value="https://example.com/entity/pasty", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value="A baked pastry filled with savory ingredients", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/pasty"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="A baked pastry filled with savory ingredients") ) ] @@ -57,14 +57,14 @@ class TestEntityContextBuilding: """Test that label and definition are combined in context.""" triples = [ Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Pasty Recipe", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Pasty Recipe") ), Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value="Traditional Cornish pastry recipe", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="Traditional Cornish pastry recipe") ) ] @@ -80,14 +80,14 @@ class TestEntityContextBuilding: """Test that only the first label is used in context.""" triples = [ Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="First Label", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="First Label") ), Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Second Label", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Second Label") ) ] @@ -101,14 +101,14 @@ 
class TestEntityContextBuilding: """Test that all definitions are included in context.""" triples = [ Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value="First definition", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="First definition") ), Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value="Second definition", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="Second definition") ) ] @@ -123,9 +123,9 @@ class TestEntityContextBuilding: """Test that schema.org description is treated as definition.""" triples = [ Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="https://schema.org/description", is_uri=True), - o=Value(value="A delicious food item", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="https://schema.org/description"), + o=Term(type=LITERAL, value="A delicious food item") ) ] @@ -138,26 +138,26 @@ class TestEntityContextBuilding: """Test that contexts are created for multiple entities.""" triples = [ Triple( - s=Value(value="https://example.com/entity/entity1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Entity One", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/entity1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Entity One") ), Triple( - s=Value(value="https://example.com/entity/entity2", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Entity Two", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/entity2"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Entity Two") ), Triple( - s=Value(value="https://example.com/entity/entity3", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Entity Three", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/entity3"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Entity Three") ) ] contexts = processor.build_entity_contexts(triples) assert len(contexts) == 3, "Should create context for each entity" - entity_uris = [ctx.entity.value for ctx in contexts] + entity_uris = [ctx.entity.iri for ctx in contexts] assert "https://example.com/entity/entity1" in entity_uris assert "https://example.com/entity/entity2" in entity_uris assert "https://example.com/entity/entity3" in entity_uris @@ -166,9 +166,9 @@ class TestEntityContextBuilding: """Test that URI objects are ignored (only literal labels/definitions).""" triples = [ Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + 
o=Term(type=IRI, iri="https://example.com/some/uri") # URI, not literal ) ] @@ -181,14 +181,14 @@ class TestEntityContextBuilding: """Test that other predicates are ignored.""" triples = [ Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://example.com/Food", is_uri=True) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri="http://example.com/Food") ), Triple( - s=Value(value="https://example.com/entity/food1", is_uri=True), - p=Value(value="http://example.com/produces", is_uri=True), - o=Value(value="https://example.com/entity/food2", is_uri=True) + s=Term(type=IRI, iri="https://example.com/entity/food1"), + p=Term(type=IRI, iri="http://example.com/produces"), + o=Term(type=IRI, iri="https://example.com/entity/food2") ) ] @@ -205,29 +205,29 @@ class TestEntityContextBuilding: assert len(contexts) == 0, "Empty triple list should return empty contexts" - def test_entity_context_has_value_object(self, processor): - """Test that EntityContext.entity is a Value object.""" + def test_entity_context_has_term_object(self, processor): + """Test that EntityContext.entity is a Term object.""" triples = [ Triple( - s=Value(value="https://example.com/entity/test", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Test Entity", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/test"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Test Entity") ) ] contexts = processor.build_entity_contexts(triples) assert len(contexts) == 1 - assert isinstance(contexts[0].entity, Value), "Entity should be Value object" - assert contexts[0].entity.is_uri, "Entity should be marked as URI" + assert isinstance(contexts[0].entity, Term), "Entity should be Term object" + assert contexts[0].entity.type == IRI, "Entity should be IRI type" def test_entity_context_text_is_string(self, processor): """Test that EntityContext.context is a string.""" triples = [ Triple( - s=Value(value="https://example.com/entity/test", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Test Entity", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/test"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Test Entity") ) ] @@ -241,22 +241,22 @@ class TestEntityContextBuilding: triples = [ # Entity with label - should create context Triple( - s=Value(value="https://example.com/entity/entity1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Entity One", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/entity1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Entity One") ), # Entity with only rdf:type - should NOT create context Triple( - s=Value(value="https://example.com/entity/entity2", is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://example.com/Food", is_uri=True) + s=Term(type=IRI, iri="https://example.com/entity/entity2"), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri="http://example.com/Food") ) ] contexts = 
processor.build_entity_contexts(triples) assert len(contexts) == 1, "Should only create context for entity with label/definition" - assert contexts[0].entity.value == "https://example.com/entity/entity1" + assert contexts[0].entity.iri == "https://example.com/entity/entity1" class TestEntityContextEdgeCases: @@ -266,9 +266,9 @@ class TestEntityContextEdgeCases: """Test handling of unicode characters in labels.""" triples = [ Triple( - s=Value(value="https://example.com/entity/café", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Café Spécial", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/café"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Café Spécial") ) ] @@ -282,9 +282,9 @@ class TestEntityContextEdgeCases: long_def = "This is a very long definition " * 50 triples = [ Triple( - s=Value(value="https://example.com/entity/test", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value=long_def, is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/test"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value=long_def) ) ] @@ -297,9 +297,9 @@ class TestEntityContextEdgeCases: """Test handling of special characters in context text.""" triples = [ Triple( - s=Value(value="https://example.com/entity/test", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Test & Entity \"quotes\"", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/test"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Test & Entity \"quotes\"") ) ] @@ -313,27 +313,27 @@ class TestEntityContextEdgeCases: triples = [ # Label - relevant Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), - o=Value(value="Cornish Pasty Recipe", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"), + o=Term(type=LITERAL, value="Cornish Pasty Recipe") ), # Type - irrelevant Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://example.com/Recipe", is_uri=True) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri="http://example.com/Recipe") ), # Property - irrelevant Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://example.com/produces", is_uri=True), - o=Value(value="https://example.com/entity/pasty", is_uri=True) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://example.com/produces"), + o=Term(type=IRI, iri="https://example.com/entity/pasty") ), # Definition - relevant Triple( - s=Value(value="https://example.com/entity/recipe1", is_uri=True), - p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), - o=Value(value="Traditional British pastry recipe", is_uri=False) + s=Term(type=IRI, iri="https://example.com/entity/recipe1"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="Traditional British 
pastry recipe") ) ] diff --git a/tests/unit/test_extract/test_ontology/test_ontology_triples.py b/tests/unit/test_extract/test_ontology/test_ontology_triples.py index 70ade79d..50e2ef3b 100644 --- a/tests/unit/test_extract/test_ontology/test_ontology_triples.py +++ b/tests/unit/test_extract/test_ontology/test_ontology_triples.py @@ -9,7 +9,7 @@ the knowledge graph. import pytest from trustgraph.extract.kg.ontology.extract import Processor from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset -from trustgraph.schema.core.primitives import Triple, Value +from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL @pytest.fixture @@ -92,12 +92,12 @@ class TestOntologyTripleGeneration: # Find type triples for Recipe class recipe_type_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/Recipe" - and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + if t.s.iri == "http://purl.org/ontology/fo/Recipe" + and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" ] assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class" - assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \ + assert recipe_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#Class", \ "Class type should be owl:Class" def test_generates_class_labels(self, extractor, sample_ontology_subset): @@ -107,14 +107,14 @@ class TestOntologyTripleGeneration: # Find label triples for Recipe class recipe_label_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/Recipe" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label" + if t.s.iri == "http://purl.org/ontology/fo/Recipe" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label" ] assert len(recipe_label_triples) == 1, "Should generate label triple for class" assert recipe_label_triples[0].o.value == "Recipe", \ "Label should match class label from ontology" - assert not recipe_label_triples[0].o.is_uri, \ + assert recipe_label_triples[0].o.type == LITERAL, \ "Label should be a literal, not URI" def test_generates_class_comments(self, extractor, sample_ontology_subset): @@ -124,8 +124,8 @@ class TestOntologyTripleGeneration: # Find comment triples for Recipe class recipe_comment_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/Recipe" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment" + if t.s.iri == "http://purl.org/ontology/fo/Recipe" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#comment" ] assert len(recipe_comment_triples) == 1, "Should generate comment triple for class" @@ -139,13 +139,13 @@ class TestOntologyTripleGeneration: # Find type triples for ingredients property ingredients_type_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/ingredients" - and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + if t.s.iri == "http://purl.org/ontology/fo/ingredients" + and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" ] assert len(ingredients_type_triples) == 1, \ "Should generate exactly one type triple per object property" - assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \ + assert ingredients_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#ObjectProperty", \ "Object property type should be owl:ObjectProperty" def test_generates_object_property_labels(self, extractor, sample_ontology_subset): @@ -155,8 +155,8 @@ class TestOntologyTripleGeneration: 
# Find label triples for ingredients property ingredients_label_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/ingredients" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label" + if t.s.iri == "http://purl.org/ontology/fo/ingredients" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label" ] assert len(ingredients_label_triples) == 1, \ @@ -171,15 +171,15 @@ class TestOntologyTripleGeneration: # Find domain triples for ingredients property ingredients_domain_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/ingredients" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain" + if t.s.iri == "http://purl.org/ontology/fo/ingredients" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#domain" ] assert len(ingredients_domain_triples) == 1, \ "Should generate domain triple for object property" - assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \ + assert ingredients_domain_triples[0].o.iri == "http://purl.org/ontology/fo/Recipe", \ "Domain should be Recipe class URI" - assert ingredients_domain_triples[0].o.is_uri, \ + assert ingredients_domain_triples[0].o.type == IRI, \ "Domain should be a URI reference" def test_generates_object_property_range(self, extractor, sample_ontology_subset): @@ -189,13 +189,13 @@ class TestOntologyTripleGeneration: # Find range triples for produces property produces_range_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/produces" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range" + if t.s.iri == "http://purl.org/ontology/fo/produces" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range" ] assert len(produces_range_triples) == 1, \ "Should generate range triple for object property" - assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \ + assert produces_range_triples[0].o.iri == "http://purl.org/ontology/fo/Food", \ "Range should be Food class URI" def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset): @@ -205,13 +205,13 @@ class TestOntologyTripleGeneration: # Find type triples for serves property serves_type_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/serves" - and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + if t.s.iri == "http://purl.org/ontology/fo/serves" + and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" ] assert len(serves_type_triples) == 1, \ "Should generate exactly one type triple per datatype property" - assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \ + assert serves_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#DatatypeProperty", \ "Datatype property type should be owl:DatatypeProperty" def test_generates_datatype_property_range(self, extractor, sample_ontology_subset): @@ -221,13 +221,13 @@ class TestOntologyTripleGeneration: # Find range triples for serves property serves_range_triples = [ t for t in triples - if t.s.value == "http://purl.org/ontology/fo/serves" - and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range" + if t.s.iri == "http://purl.org/ontology/fo/serves" + and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range" ] assert len(serves_range_triples) == 1, \ "Should generate range triple for datatype property" - assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \ + assert serves_range_triples[0].o.iri == 
"http://www.w3.org/2001/XMLSchema#string", \ "Range should be XSD type URI (xsd:string expanded)" def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset): @@ -236,9 +236,9 @@ class TestOntologyTripleGeneration: # Count unique class subjects class_subjects = set( - t.s.value for t in triples - if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" - and t.o.value == "http://www.w3.org/2002/07/owl#Class" + t.s.iri for t in triples + if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + and t.o.iri == "http://www.w3.org/2002/07/owl#Class" ) assert len(class_subjects) == 3, \ @@ -250,9 +250,9 @@ class TestOntologyTripleGeneration: # Count unique property subjects (object + datatype properties) property_subjects = set( - t.s.value for t in triples - if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" - and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value) + t.s.iri for t in triples + if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + and ("ObjectProperty" in t.o.iri or "DatatypeProperty" in t.o.iri) ) assert len(property_subjects) == 3, \ @@ -276,7 +276,7 @@ class TestOntologyTripleGeneration: # Should still generate proper RDF triples despite dict field names label_triples = [ t for t in triples - if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label" + if t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label" ] assert len(label_triples) > 0, \ "Should generate rdfs:label triples from dict 'labels' field" diff --git a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py index e6d5bf36..9f9c8551 100644 --- a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py +++ b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py @@ -8,7 +8,7 @@ and extracts/validates triples from LLM responses. 
import pytest from trustgraph.extract.kg.ontology.extract import Processor from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset -from trustgraph.schema.core.primitives import Triple, Value +from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL @pytest.fixture @@ -248,9 +248,9 @@ class TestTripleParsing: validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) assert len(validated) == 1, "Should parse one valid triple" - assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty" - assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" - assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe" + assert validated[0].s.iri == "https://trustgraph.ai/food/cornish-pasty" + assert validated[0].p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe" def test_parse_multiple_triples(self, extractor, sample_ontology_subset): """Test parsing multiple triples.""" @@ -307,11 +307,11 @@ class TestTripleParsing: assert len(validated) == 1 # Subject should be expanded to entity URI - assert validated[0].s.value.startswith("https://trustgraph.ai/food/") + assert validated[0].s.iri.startswith("https://trustgraph.ai/food/") # Predicate should be expanded to ontology URI - assert validated[0].p.value == "http://purl.org/ontology/fo/produces" + assert validated[0].p.iri == "http://purl.org/ontology/fo/produces" # Object should be expanded to class URI - assert validated[0].o.value == "http://purl.org/ontology/fo/Food" + assert validated[0].o.iri == "http://purl.org/ontology/fo/Food" def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset): """Test that Triple objects are properly created.""" @@ -324,12 +324,12 @@ class TestTripleParsing: assert len(validated) == 1 triple = validated[0] assert isinstance(triple, Triple), "Should create Triple objects" - assert isinstance(triple.s, Value), "Subject should be Value object" - assert isinstance(triple.p, Value), "Predicate should be Value object" - assert isinstance(triple.o, Value), "Object should be Value object" - assert triple.s.is_uri, "Subject should be marked as URI" - assert triple.p.is_uri, "Predicate should be marked as URI" - assert not triple.o.is_uri, "Object literal should not be marked as URI" + assert isinstance(triple.s, Term), "Subject should be Term object" + assert isinstance(triple.p, Term), "Predicate should be Term object" + assert isinstance(triple.o, Term), "Object should be Term object" + assert triple.s.type == IRI, "Subject should be IRI type" + assert triple.p.type == IRI, "Predicate should be IRI type" + assert triple.o.type == LITERAL, "Object literal should be LITERAL type" class TestURIExpansionInExtraction: @@ -343,8 +343,8 @@ class TestURIExpansionInExtraction: validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) - assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe" - assert validated[0].o.is_uri, "Class reference should be URI" + assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe" + assert validated[0].o.type == IRI, "Class reference should be URI" def test_expands_property_names(self, extractor, sample_ontology_subset): """Test that property names are expanded to full URIs.""" @@ -354,7 +354,7 @@ class TestURIExpansionInExtraction: validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) - assert 
validated[0].p.value == "http://purl.org/ontology/fo/produces" + assert validated[0].p.iri == "http://purl.org/ontology/fo/produces" def test_expands_entity_instances(self, extractor, sample_ontology_subset): """Test that entity instances get constructed URIs.""" @@ -364,8 +364,8 @@ class TestURIExpansionInExtraction: validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) - assert validated[0].s.value.startswith("https://trustgraph.ai/food/") - assert "my-special-recipe" in validated[0].s.value + assert validated[0].s.iri.startswith("https://trustgraph.ai/food/") + assert "my-special-recipe" in validated[0].s.iri class TestEdgeCases: diff --git a/tests/unit/test_gateway/test_dispatch_serialize.py b/tests/unit/test_gateway/test_dispatch_serialize.py index e117629b..5d546adf 100644 --- a/tests/unit/test_gateway/test_dispatch_serialize.py +++ b/tests/unit/test_gateway/test_dispatch_serialize.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value -from trustgraph.schema import Value, Triple +from trustgraph.schema import Term, Triple, IRI, LITERAL class TestDispatchSerialize: @@ -14,55 +14,55 @@ class TestDispatchSerialize: def test_to_value_with_uri(self): """Test to_value function with URI""" - input_data = {"v": "http://example.com/resource", "e": True} - + input_data = {"t": "i", "i": "http://example.com/resource"} + result = to_value(input_data) - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI def test_to_value_with_literal(self): """Test to_value function with literal value""" - input_data = {"v": "literal string", "e": False} - + input_data = {"t": "l", "v": "literal string"} + result = to_value(input_data) - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "literal string" - assert result.is_uri is False + assert result.type == LITERAL def test_to_subgraph_with_multiple_triples(self): """Test to_subgraph function with multiple triples""" input_data = [ { - "s": {"v": "subject1", "e": True}, - "p": {"v": "predicate1", "e": True}, - "o": {"v": "object1", "e": False} + "s": {"t": "i", "i": "subject1"}, + "p": {"t": "i", "i": "predicate1"}, + "o": {"t": "l", "v": "object1"} }, { - "s": {"v": "subject2", "e": False}, - "p": {"v": "predicate2", "e": True}, - "o": {"v": "object2", "e": True} + "s": {"t": "l", "v": "subject2"}, + "p": {"t": "i", "i": "predicate2"}, + "o": {"t": "i", "i": "object2"} } ] - + result = to_subgraph(input_data) - + assert len(result) == 2 assert all(isinstance(triple, Triple) for triple in result) - + # Check first triple - assert result[0].s.value == "subject1" - assert result[0].s.is_uri is True - assert result[0].p.value == "predicate1" - assert result[0].p.is_uri is True + assert result[0].s.iri == "subject1" + assert result[0].s.type == IRI + assert result[0].p.iri == "predicate1" + assert result[0].p.type == IRI assert result[0].o.value == "object1" - assert result[0].o.is_uri is False - + assert result[0].o.type == LITERAL + # Check second triple assert result[1].s.value == "subject2" - assert result[1].s.is_uri is False + assert result[1].s.type == LITERAL def test_to_subgraph_with_empty_list(self): """Test to_subgraph function with empty input""" @@ -74,16 +74,16 @@ class 
TestDispatchSerialize: def test_serialize_value_with_uri(self): """Test serialize_value function with URI value""" - value = Value(value="http://example.com/test", is_uri=True) - - result = serialize_value(value) - - assert result == {"v": "http://example.com/test", "e": True} + term = Term(type=IRI, iri="http://example.com/test") + + result = serialize_value(term) + + assert result == {"t": "i", "i": "http://example.com/test"} def test_serialize_value_with_literal(self): """Test serialize_value function with literal value""" - value = Value(value="test literal", is_uri=False) - - result = serialize_value(value) - - assert result == {"v": "test literal", "e": False} \ No newline at end of file + term = Term(type=LITERAL, value="test literal") + + result = serialize_value(term) + + assert result == {"t": "l", "v": "test literal"} \ No newline at end of file diff --git a/tests/unit/test_gateway/test_objects_import_dispatcher.py b/tests/unit/test_gateway/test_rows_import_dispatcher.py similarity index 83% rename from tests/unit/test_gateway/test_objects_import_dispatcher.py rename to tests/unit/test_gateway/test_rows_import_dispatcher.py index 0332c1a1..ab72cae1 100644 --- a/tests/unit/test_gateway/test_objects_import_dispatcher.py +++ b/tests/unit/test_gateway/test_rows_import_dispatcher.py @@ -1,7 +1,7 @@ """ -Unit tests for objects import dispatcher. +Unit tests for rows import dispatcher. -Tests the business logic of objects import dispatcher +Tests the business logic of rows import dispatcher while mocking the Publisher and websocket components. """ @@ -11,7 +11,7 @@ import asyncio from unittest.mock import Mock, AsyncMock, patch, MagicMock from aiohttp import web -from trustgraph.gateway.dispatch.objects_import import ObjectsImport +from trustgraph.gateway.dispatch.rows_import import RowsImport from trustgraph.schema import Metadata, ExtractedObject @@ -92,16 +92,16 @@ def minimal_objects_message(): } -class TestObjectsImportInitialization: - """Test ObjectsImport initialization.""" +class TestRowsImportInitialization: + """Test RowsImport initialization.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): - """Test that ObjectsImport creates Publisher with correct parameters.""" + """Test that RowsImport creates Publisher with correct parameters.""" mock_publisher_instance = Mock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -116,28 +116,28 @@ class TestObjectsImportInitialization: ) # Verify instance variables are set correctly - assert objects_import.ws == mock_websocket - assert objects_import.running == mock_running - assert objects_import.publisher == mock_publisher_instance + assert rows_import.ws == mock_websocket + assert rows_import.running == mock_running + assert rows_import.publisher == mock_publisher_instance - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): - """Test that ObjectsImport stores all required references.""" - objects_import = ObjectsImport( + """Test that RowsImport stores all required references.""" + rows_import = RowsImport( 
ws=mock_websocket, running=mock_running, backend=mock_backend, queue="objects-queue" ) - assert objects_import.ws is mock_websocket - assert objects_import.running is mock_running + assert rows_import.ws is mock_websocket + assert rows_import.running is mock_running -class TestObjectsImportLifecycle: - """Test ObjectsImport lifecycle methods.""" +class TestRowsImportLifecycle: + """Test RowsImport lifecycle methods.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that start() calls publisher.start().""" @@ -145,18 +145,18 @@ class TestObjectsImportLifecycle: mock_publisher_instance.start = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.start() + await rows_import.start() mock_publisher_instance.start.assert_called_once() - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that destroy() properly stops publisher and closes websocket.""" @@ -164,21 +164,21 @@ class TestObjectsImportLifecycle: mock_publisher_instance.stop = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.destroy() + await rows_import.destroy() # Verify sequence of operations mock_running.stop.assert_called_once() mock_publisher_instance.stop.assert_called_once() mock_websocket.close.assert_called_once() - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running): """Test that destroy() handles None websocket gracefully.""" @@ -186,7 +186,7 @@ class TestObjectsImportLifecycle: mock_publisher_instance.stop = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=None, # None websocket running=mock_running, backend=mock_backend, @@ -194,16 +194,16 @@ class TestObjectsImportLifecycle: ) # Should not raise exception - await objects_import.destroy() + await rows_import.destroy() mock_running.stop.assert_called_once() mock_publisher_instance.stop.assert_called_once() -class TestObjectsImportMessageProcessing: - """Test ObjectsImport message processing.""" +class TestRowsImportMessageProcessing: + """Test RowsImport message processing.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() processes complete message correctly.""" @@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = 
mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = sample_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called mock_publisher_instance.send.assert_called_once() @@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing: assert sent_object.metadata.collection == "testcollection" assert len(sent_object.metadata.metadata) == 1 # One triple in metadata - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message): """Test that receive() handles message with minimal required fields.""" @@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = minimal_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called mock_publisher_instance.send.assert_called_once() @@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing: assert sent_object.source_span == "" # Default value assert len(sent_object.metadata.metadata) == 0 # Default empty list - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() uses appropriate default values for optional fields.""" @@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = message_data - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Get the sent object and verify defaults sent_object = mock_publisher_instance.send.call_args[0][1] @@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing: assert sent_object.source_span == "" -class TestObjectsImportRunMethod: - """Test ObjectsImport run method.""" +class TestRowsImportRunMethod: + """Test RowsImport run method.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep') @pytest.mark.asyncio async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that run() loops while running.get() returns True.""" @@ -331,14 +331,14 @@ class TestObjectsImportRunMethod: # Set up running state to return True twice, then False 
mock_running.get.side_effect = [True, True, False] - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.run() + await rows_import.run() # Verify sleep was called twice (for the two True iterations) assert mock_sleep.call_count == 2 @@ -348,10 +348,10 @@ class TestObjectsImportRunMethod: mock_websocket.close.assert_called_once() # Verify websocket was set to None - assert objects_import.ws is None + assert rows_import.ws is None - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep') @pytest.mark.asyncio async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running): """Test that run() handles None websocket gracefully.""" @@ -360,7 +360,7 @@ class TestObjectsImportRunMethod: mock_running.get.return_value = False # Exit immediately - objects_import = ObjectsImport( + rows_import = RowsImport( ws=None, # None websocket running=mock_running, backend=mock_backend, @@ -368,14 +368,14 @@ class TestObjectsImportRunMethod: ) # Should not raise exception - await objects_import.run() + await rows_import.run() # Verify websocket remains None - assert objects_import.ws is None + assert rows_import.ws is None -class TestObjectsImportBatchProcessing: - """Test ObjectsImport batch processing functionality.""" +class TestRowsImportBatchProcessing: + """Test RowsImport batch processing functionality.""" @pytest.fixture def batch_objects_message(self): @@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing: "source_span": "Multiple people found in document" } - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message): """Test that receive() processes batch message correctly.""" @@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing: mock_msg = Mock() mock_msg.json.return_value = batch_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called mock_publisher_instance.send.assert_called_once() @@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing: assert sent_object.confidence == 0.85 assert sent_object.source_span == "Multiple people found in document" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles empty batch correctly.""" @@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, 
backend=mock_backend, @@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing: mock_msg = Mock() mock_msg.json.return_value = empty_batch_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Should still send the message mock_publisher_instance.send.assert_called_once() @@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing: assert len(sent_object.values) == 0 -class TestObjectsImportErrorHandling: - """Test error handling in ObjectsImport.""" +class TestRowsImportErrorHandling: + """Test error handling in RowsImport.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() propagates publisher send errors.""" @@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling: mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error")) mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling: mock_msg.json.return_value = sample_objects_message with pytest.raises(Exception, match="Publisher error"): - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles malformed JSON appropriately.""" mock_publisher_class.return_value = Mock() - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling: mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) with pytest.raises(json.JSONDecodeError): - await objects_import.receive(mock_msg) \ No newline at end of file + await rows_import.receive(mock_msg) \ No newline at end of file diff --git a/tests/unit/test_knowledge_graph/conftest.py b/tests/unit/test_knowledge_graph/conftest.py index d4a83054..e7f83b58 100644 --- a/tests/unit/test_knowledge_graph/conftest.py +++ b/tests/unit/test_knowledge_graph/conftest.py @@ -6,11 +6,21 @@ import pytest from unittest.mock import Mock, AsyncMock # Mock schema classes for testing -class Value: - def __init__(self, value, is_uri, type): - self.value = value - self.is_uri = is_uri +# Term type constants +IRI = "i" +LITERAL = "l" +BLANK = "b" +TRIPLE = "t" + +class Term: + def __init__(self, type, iri=None, value=None, id=None, datatype=None, language=None, triple=None): self.type = type + self.iri = iri + self.value = value + self.id = id + self.datatype = datatype + self.language = language + self.triple = triple class Triple: def __init__(self, s, p, o): @@ -66,32 +76,30 @@ def sample_relationships(): @pytest.fixture -def sample_value_uri(): - """Sample URI Value object""" - return Value( - value="http://example.com/person/john-smith", - is_uri=True, - type="" +def sample_term_uri(): + """Sample URI Term object""" + return Term( + type=IRI, + iri="http://example.com/person/john-smith" ) @pytest.fixture -def sample_value_literal(): - """Sample literal Value object""" 
- return Value( - value="John Smith", - is_uri=False, - type="string" +def sample_term_literal(): + """Sample literal Term object""" + return Term( + type=LITERAL, + value="John Smith" ) @pytest.fixture -def sample_triple(sample_value_uri, sample_value_literal): +def sample_triple(sample_term_uri, sample_term_literal): """Sample Triple object""" return Triple( - s=sample_value_uri, - p=Value(value="http://schema.org/name", is_uri=True, type=""), - o=sample_value_literal + s=sample_term_uri, + p=Term(type=IRI, iri="http://schema.org/name"), + o=sample_term_literal ) diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction.py b/tests/unit/test_knowledge_graph/test_agent_extraction.py index be5553df..a3a0f9a7 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction.py @@ -11,7 +11,7 @@ import json from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor -from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error +from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from trustgraph.template.prompt_manager import PromptManager @@ -33,7 +33,7 @@ class TestAgentKgExtractor: # Set up the methods we want to test extractor.to_uri = real_extractor.to_uri - extractor.parse_json = real_extractor.parse_json + extractor.parse_jsonl = real_extractor.parse_jsonl extractor.process_extraction_data = real_extractor.process_extraction_data extractor.emit_triples = real_extractor.emit_triples extractor.emit_entity_contexts = real_extractor.emit_entity_contexts @@ -53,48 +53,49 @@ class TestAgentKgExtractor: id="doc123", metadata=[ Triple( - s=Value(value="doc123", is_uri=True), - p=Value(value="http://example.org/type", is_uri=True), - o=Value(value="document", is_uri=False) + s=Term(type=IRI, iri="doc123"), + p=Term(type=IRI, iri="http://example.org/type"), + o=Term(type=LITERAL, value="document") ) ] ) @pytest.fixture def sample_extraction_data(self): - """Sample extraction data in expected format""" - return { - "definitions": [ - { - "entity": "Machine Learning", - "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." - }, - { - "entity": "Neural Networks", - "definition": "Computing systems inspired by biological neural networks that process information." - } - ], - "relationships": [ - { - "subject": "Machine Learning", - "predicate": "is_subset_of", - "object": "Artificial Intelligence", - "object-entity": True - }, - { - "subject": "Neural Networks", - "predicate": "used_in", - "object": "Machine Learning", - "object-entity": True - }, - { - "subject": "Deep Learning", - "predicate": "accuracy", - "object": "95%", - "object-entity": False - } - ] - } + """Sample extraction data in JSONL format (list with type discriminators)""" + return [ + { + "type": "definition", + "entity": "Machine Learning", + "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." + }, + { + "type": "definition", + "entity": "Neural Networks", + "definition": "Computing systems inspired by biological neural networks that process information." 
+ }, + { + "type": "relationship", + "subject": "Machine Learning", + "predicate": "is_subset_of", + "object": "Artificial Intelligence", + "object-entity": True + }, + { + "type": "relationship", + "subject": "Neural Networks", + "predicate": "used_in", + "object": "Machine Learning", + "object-entity": True + }, + { + "type": "relationship", + "subject": "Deep Learning", + "predicate": "accuracy", + "object": "95%", + "object-entity": False + } + ] def test_to_uri_conversion(self, agent_extractor): """Test URI conversion for entities""" @@ -113,148 +114,147 @@ class TestAgentKgExtractor: expected = f"{TRUSTGRAPH_ENTITIES}" assert uri == expected - def test_parse_json_with_code_blocks(self, agent_extractor): - """Test JSON parsing from code blocks""" - # Test JSON in code blocks + def test_parse_jsonl_with_code_blocks(self, agent_extractor): + """Test JSONL parsing from code blocks""" + # Test JSONL in code blocks - note: JSON uses lowercase true/false response = '''```json - { - "definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}], - "relationships": [] - } - ```''' - - result = agent_extractor.parse_json(response) - - assert result["definitions"][0]["entity"] == "AI" - assert result["definitions"][0]["definition"] == "Artificial Intelligence" - assert result["relationships"] == [] +{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"} +{"type": "relationship", "subject": "AI", "predicate": "is", "object": "technology", "object-entity": false} +```''' - def test_parse_json_without_code_blocks(self, agent_extractor): - """Test JSON parsing without code blocks""" - response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}''' - - result = agent_extractor.parse_json(response) - - assert result["definitions"][0]["entity"] == "ML" - assert result["definitions"][0]["definition"] == "Machine Learning" + result = agent_extractor.parse_jsonl(response) - def test_parse_json_invalid_format(self, agent_extractor): - """Test JSON parsing with invalid format""" - invalid_response = "This is not JSON at all" - - with pytest.raises(json.JSONDecodeError): - agent_extractor.parse_json(invalid_response) + assert len(result) == 2 + assert result[0]["entity"] == "AI" + assert result[0]["definition"] == "Artificial Intelligence" + assert result[1]["type"] == "relationship" - def test_parse_json_malformed_code_blocks(self, agent_extractor): - """Test JSON parsing with malformed code blocks""" - # Missing closing backticks - response = '''```json - {"definitions": [], "relationships": []} - ''' - - # Should still parse the JSON content - with pytest.raises(json.JSONDecodeError): - agent_extractor.parse_json(response) + def test_parse_jsonl_without_code_blocks(self, agent_extractor): + """Test JSONL parsing without code blocks""" + response = '''{"type": "definition", "entity": "ML", "definition": "Machine Learning"} +{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}''' + + result = agent_extractor.parse_jsonl(response) + + assert len(result) == 2 + assert result[0]["entity"] == "ML" + assert result[1]["entity"] == "AI" + + def test_parse_jsonl_invalid_lines_skipped(self, agent_extractor): + """Test JSONL parsing skips invalid lines gracefully""" + response = '''{"type": "definition", "entity": "Valid", "definition": "Valid def"} +This is not JSON at all +{"type": "definition", "entity": "Also Valid", "definition": "Another def"}''' + + result = agent_extractor.parse_jsonl(response) + + # 
Should get 2 valid objects, skipping the invalid line + assert len(result) == 2 + assert result[0]["entity"] == "Valid" + assert result[1]["entity"] == "Also Valid" + + def test_parse_jsonl_truncation_resilience(self, agent_extractor): + """Test JSONL parsing handles truncated responses""" + # Simulates output cut off mid-line + response = '''{"type": "definition", "entity": "Complete", "definition": "Full def"} +{"type": "definition", "entity": "Trunca''' + + result = agent_extractor.parse_jsonl(response) + + # Should get 1 valid object, the truncated line is skipped + assert len(result) == 1 + assert result[0]["entity"] == "Complete" def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata): """Test processing of definition data""" - data = { - "definitions": [ - { - "entity": "Machine Learning", - "definition": "A subset of AI that enables learning from data." - } - ], - "relationships": [] - } - + data = [ + { + "type": "definition", + "entity": "Machine Learning", + "definition": "A subset of AI that enables learning from data." + } + ] + triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) # Check entity label triple - label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None) + label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None) assert label_triple is not None - assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" - assert label_triple.s.is_uri == True - assert label_triple.o.is_uri == False - + assert label_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" + assert label_triple.s.type == IRI + assert label_triple.o.type == LITERAL + # Check definition triple - def_triple = next((t for t in triples if t.p.value == DEFINITION), None) + def_triple = next((t for t in triples if t.p.iri == DEFINITION), None) assert def_triple is not None - assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" + assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" assert def_triple.o.value == "A subset of AI that enables learning from data." - + # Check subject-of triple - subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None) + subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None) assert subject_of_triple is not None - assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" - assert subject_of_triple.o.value == "doc123" - + assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" + assert subject_of_triple.o.iri == "doc123" + # Check entity context assert len(entity_contexts) == 1 - assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" + assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" assert entity_contexts[0].context == "A subset of AI that enables learning from data." 
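
For readers following the migration in these tests: the fixtures and assertions above imply a `Term` shape in which exactly one of `iri` or `value` is populated, selected by the `type` discriminator. A minimal sketch consistent with how the tests construct and inspect terms — the real `trustgraph.schema` definitions and the concrete values of the `IRI`/`LITERAL` constants may differ — looks like:

```python
from dataclasses import dataclass
from typing import Optional

# Discriminator constants as the tests import them; the actual values in
# trustgraph.schema are an assumption here.
IRI = "iri"
LITERAL = "literal"

@dataclass
class Term:
    """One RDF term: exactly one of iri/value is populated, per type."""
    type: str
    iri: Optional[str] = None       # set when type == IRI
    value: Optional[str] = None     # set when type == LITERAL
    datatype: Optional[str] = None  # e.g. "string", "integer" (literals only)

@dataclass
class Triple:
    s: Term  # subject
    p: Term  # predicate (normally an IRI term)
    o: Term  # object (IRI or literal)

# Usage mirroring the fixtures above:
name = Triple(
    s=Term(type=IRI, iri="http://example.org/x"),
    p=Term(type=IRI, iri="http://schema.org/name"),
    o=Term(type=LITERAL, value="X", datatype="string"),
)
```

Splitting `iri` and `value` into separate fields, rather than the old `Value(value=..., is_uri=...)` pair, makes it impossible to read a literal as an IRI by accident, which is what the `.iri`/`.value` assertions throughout these tests rely on.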
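The `parse_jsonl` tests above pin down a tolerant, line-oriented contract: fenced output is unwrapped, every line is decoded independently, and blank, malformed, or truncated lines are skipped instead of failing the whole response. A minimal sketch with that contract — illustrative only, not the actual implementation in `trustgraph.extract.kg.agent.extract` — could be:

```python
import json
import re

def parse_jsonl_sketch(text: str) -> list:
    """Tolerant JSONL parsing: one JSON object per line, bad lines skipped.

    Illustrative sketch of the behaviour the tests assert (fence stripping,
    blank-line skipping, truncation resilience); not the real trustgraph code.
    """
    # Unwrap a ```json / ```jsonl / bare ``` fence if the LLM added one.
    match = re.search(r"```(?:\w+)?\s*\n(.*?)```", text, re.DOTALL)
    if match:
        text = match.group(1)

    results = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        try:
            results.append(json.loads(line))
        except json.JSONDecodeError:
            continue  # skip invalid or truncated lines
    return results
```

Line-at-a-time decoding is what buys the truncation resilience tested here: a response cut off mid-object loses only its final line, where a single `json.loads` over the whole payload would lose everything.
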
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata): """Test processing of relationship data""" - data = { - "definitions": [], - "relationships": [ - { - "subject": "Machine Learning", - "predicate": "is_subset_of", - "object": "Artificial Intelligence", - "object-entity": True - } - ] - } - + data = [ + { + "type": "relationship", + "subject": "Machine Learning", + "predicate": "is_subset_of", + "object": "Artificial Intelligence", + "object-entity": True + } + ] + triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) # Check that subject, predicate, and object labels are created subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of" - + # Find label triples - subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None) + subject_label = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == RDF_LABEL), None) assert subject_label is not None assert subject_label.o.value == "Machine Learning" - - predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None) + + predicate_label = next((t for t in triples if t.s.iri == predicate_uri and t.p.iri == RDF_LABEL), None) assert predicate_label is not None assert predicate_label.o.value == "is_subset_of" - - # Check main relationship triple - # NOTE: Current implementation has bugs: - # 1. Uses data.get("object-entity") instead of rel.get("object-entity") - # 2. Sets object_value to predicate_uri instead of actual object URI - # This test documents the current buggy behavior - rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None) + + # Check main relationship triple + object_uri = f"{TRUSTGRAPH_ENTITIES}Artificial%20Intelligence" + rel_triple = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == predicate_uri), None) assert rel_triple is not None - # Due to bug, object value is set to predicate_uri - assert rel_triple.o.value == predicate_uri - + assert rel_triple.o.iri == object_uri + assert rel_triple.o.type == IRI + # Check subject-of relationships - subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"] + subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"] assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata): """Test processing of relationships with literal objects""" - data = { - "definitions": [], - "relationships": [ - { - "subject": "Deep Learning", - "predicate": "accuracy", - "object": "95%", - "object-entity": False - } - ] - } - + data = [ + { + "type": "relationship", + "subject": "Deep Learning", + "predicate": "accuracy", + "object": "95%", + "object-entity": False + } + ] + triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) - + # Check that object labels are not created for literal objects - object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"] + object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"] # Based on the code logic, it should not create object labels for non-entity objects # But there might be a bug in the original implementation @@ -263,75 +263,62 @@ class TestAgentKgExtractor: triples, entity_contexts = 
agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata) # Check that we have both definition and relationship triples - definition_triples = [t for t in triples if t.p.value == DEFINITION] + definition_triples = [t for t in triples if t.p.iri == DEFINITION] assert len(definition_triples) == 2 # Two definitions - + # Check entity contexts are created for definitions assert len(entity_contexts) == 2 - entity_uris = [ec.entity.value for ec in entity_contexts] + entity_uris = [ec.entity.iri for ec in entity_contexts] assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris def test_process_extraction_data_no_metadata_id(self, agent_extractor): """Test processing when metadata has no ID""" metadata = Metadata(id=None, metadata=[]) - data = { - "definitions": [ - {"entity": "Test Entity", "definition": "Test definition"} - ], - "relationships": [] - } - + data = [ + {"type": "definition", "entity": "Test Entity", "definition": "Test definition"} + ] + triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should not create subject-of relationships when no metadata ID - subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF] + subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF] assert len(subject_of_triples) == 0 - + # Should still create entity contexts assert len(entity_contexts) == 1 def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata): """Test processing of empty extraction data""" - data = {"definitions": [], "relationships": []} - - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) - - # Should only have metadata triples - assert len(entity_contexts) == 0 - # Triples should only contain metadata triples if any + data = [] - def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata): - """Test processing data with missing keys""" - # Test missing definitions key - data = {"relationships": []} triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) + + # Should have no entity contexts assert len(entity_contexts) == 0 - - # Test missing relationships key - data = {"definitions": []} + # Triples should be empty + assert len(triples) == 0 + + def test_process_extraction_data_unknown_types_ignored(self, agent_extractor, sample_metadata): + """Test processing data with unknown type values""" + data = [ + {"type": "definition", "entity": "Valid", "definition": "Valid def"}, + {"type": "unknown_type", "foo": "bar"}, # Unknown type - should be ignored + {"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True} + ] + triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) - assert len(entity_contexts) == 0 - - # Test completely missing keys - data = {} - triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata) - assert len(entity_contexts) == 0 + + # Should process valid items and ignore unknown types + assert len(entity_contexts) == 1 # Only the definition creates entity context def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata): """Test processing data with malformed entries""" - # Test definition missing required fields - data = { - "definitions": [ - {"entity": "Test"}, # Missing definition - {"definition": "Test def"} # Missing entity - ], - "relationships": [ - 
{"subject": "A", "predicate": "rel"}, # Missing object - {"subject": "B", "object": "C"} # Missing predicate - ] - } - + # Test items missing required fields - should raise KeyError + data = [ + {"type": "definition", "entity": "Test"}, # Missing definition + ] + # Should handle gracefully or raise appropriate errors with pytest.raises(KeyError): agent_extractor.process_extraction_data(data, sample_metadata) @@ -340,17 +327,17 @@ class TestAgentKgExtractor: async def test_emit_triples(self, agent_extractor, sample_metadata): """Test emitting triples to publisher""" mock_publisher = AsyncMock() - + test_triples = [ Triple( - s=Value(value="test:subject", is_uri=True), - p=Value(value="test:predicate", is_uri=True), - o=Value(value="test object", is_uri=False) + s=Term(type=IRI, iri="test:subject"), + p=Term(type=IRI, iri="test:predicate"), + o=Term(type=LITERAL, value="test object") ) ] - + await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples) - + mock_publisher.send.assert_called_once() sent_triples = mock_publisher.send.call_args[0][0] assert isinstance(sent_triples, Triples) @@ -361,22 +348,22 @@ class TestAgentKgExtractor: # Note: metadata.metadata is now empty array in the new implementation assert sent_triples.metadata.metadata == [] assert len(sent_triples.triples) == 1 - assert sent_triples.triples[0].s.value == "test:subject" + assert sent_triples.triples[0].s.iri == "test:subject" @pytest.mark.asyncio async def test_emit_entity_contexts(self, agent_extractor, sample_metadata): """Test emitting entity contexts to publisher""" mock_publisher = AsyncMock() - + test_contexts = [ EntityContext( - entity=Value(value="test:entity", is_uri=True), + entity=Term(type=IRI, iri="test:entity"), context="Test context" ) ] - + await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts) - + mock_publisher.send.assert_called_once() sent_contexts = mock_publisher.send.call_args[0][0] assert isinstance(sent_contexts, EntityContexts) @@ -387,7 +374,7 @@ class TestAgentKgExtractor: # Note: metadata.metadata is now empty array in the new implementation assert sent_contexts.metadata.metadata == [] assert len(sent_contexts.entities) == 1 - assert sent_contexts.entities[0].entity.value == "test:entity" + assert sent_contexts.entities[0].entity.iri == "test:entity" def test_agent_extractor_initialization_params(self): """Test agent extractor parameter validation""" diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py b/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py index c69df8c4..f66e5da6 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction_edge_cases.py @@ -11,7 +11,7 @@ import urllib.parse from unittest.mock import AsyncMock, MagicMock from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor -from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value +from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF @@ -32,11 +32,11 @@ class TestAgentKgExtractionEdgeCases: # Set up the methods we want to test extractor.to_uri = real_extractor.to_uri - extractor.parse_json = real_extractor.parse_json + extractor.parse_jsonl = real_extractor.parse_jsonl extractor.process_extraction_data = 
real_extractor.process_extraction_data extractor.emit_triples = real_extractor.emit_triples extractor.emit_entity_contexts = real_extractor.emit_entity_contexts - + return extractor def test_to_uri_special_characters(self, agent_extractor): @@ -85,146 +85,116 @@ class TestAgentKgExtractionEdgeCases: # Verify the URI is properly encoded assert unicode_text not in uri # Original unicode should be encoded - def test_parse_json_whitespace_variations(self, agent_extractor): - """Test JSON parsing with various whitespace patterns""" - # Test JSON with different whitespace patterns + def test_parse_jsonl_whitespace_variations(self, agent_extractor): + """Test JSONL parsing with various whitespace patterns""" + # Test JSONL with different whitespace patterns test_cases = [ # Extra whitespace around code blocks - " ```json\n{\"test\": true}\n``` ", - # Tabs and mixed whitespace - "\t\t```json\n\t{\"test\": true}\n\t```\t", - # Multiple newlines - "\n\n\n```json\n\n{\"test\": true}\n\n```\n\n", - # JSON without code blocks but with whitespace - " {\"test\": true} ", - # Mixed line endings - "```json\r\n{\"test\": true}\r\n```", + ' ```json\n{"type": "definition", "entity": "test", "definition": "def"}\n``` ', + # Multiple newlines between lines + '{"type": "definition", "entity": "A", "definition": "def A"}\n\n\n{"type": "definition", "entity": "B", "definition": "def B"}', + # JSONL without code blocks but with whitespace + ' {"type": "definition", "entity": "test", "definition": "def"} ', ] - - for response in test_cases: - result = agent_extractor.parse_json(response) - assert result == {"test": True} - def test_parse_json_code_block_variations(self, agent_extractor): - """Test JSON parsing with different code block formats""" + for response in test_cases: + result = agent_extractor.parse_jsonl(response) + assert len(result) >= 1 + assert result[0].get("type") == "definition" + + def test_parse_jsonl_code_block_variations(self, agent_extractor): + """Test JSONL parsing with different code block formats""" test_cases = [ # Standard json code block - "```json\n{\"valid\": true}\n```", + '```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```', + # jsonl code block + '```jsonl\n{"type": "definition", "entity": "A", "definition": "def"}\n```', # Code block without language - "```\n{\"valid\": true}\n```", - # Uppercase JSON - "```JSON\n{\"valid\": true}\n```", - # Mixed case - "```Json\n{\"valid\": true}\n```", - # Multiple code blocks (should take first one) - "```json\n{\"first\": true}\n```\n```json\n{\"second\": true}\n```", - # Code block with extra content - "Here's the result:\n```json\n{\"valid\": true}\n```\nDone!", + '```\n{"type": "definition", "entity": "A", "definition": "def"}\n```', + # Code block with extra content before/after + 'Here\'s the result:\n```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```\nDone!', ] - + for i, response in enumerate(test_cases): - try: - result = agent_extractor.parse_json(response) - assert result.get("valid") == True or result.get("first") == True - except json.JSONDecodeError: - # Some cases may fail due to regex extraction issues - # This documents current behavior - the regex may not match all cases - print(f"Case {i} failed JSON parsing: {response[:50]}...") - pass + result = agent_extractor.parse_jsonl(response) + assert len(result) >= 1, f"Case {i} failed" + assert result[0].get("entity") == "A" - def test_parse_json_malformed_code_blocks(self, agent_extractor): - """Test JSON parsing with malformed code 
block formats""" - # These should still work by falling back to treating entire text as JSON - test_cases = [ - # Unclosed code block - "```json\n{\"test\": true}", - # No opening backticks - "{\"test\": true}\n```", - # Wrong number of backticks - "`json\n{\"test\": true}\n`", - # Nested backticks (should handle gracefully) - "```json\n{\"code\": \"```\", \"test\": true}\n```", - ] - - for response in test_cases: - try: - result = agent_extractor.parse_json(response) - assert "test" in result # Should successfully parse - except json.JSONDecodeError: - # This is also acceptable for malformed cases - pass + def test_parse_jsonl_truncation_resilience(self, agent_extractor): + """Test JSONL parsing with truncated responses""" + # Simulates LLM output being cut off mid-line + response = '''{"type": "definition", "entity": "Complete1", "definition": "Full definition"} +{"type": "definition", "entity": "Complete2", "definition": "Another full def"} +{"type": "definition", "entity": "Trunca''' - def test_parse_json_large_responses(self, agent_extractor): - """Test JSON parsing with very large responses""" - # Create a large JSON structure - large_data = { - "definitions": [ - { - "entity": f"Entity {i}", - "definition": f"Definition {i} " + "with more content " * 100 - } - for i in range(100) - ], - "relationships": [ - { - "subject": f"Subject {i}", - "predicate": f"predicate_{i}", - "object": f"Object {i}", - "object-entity": i % 2 == 0 - } - for i in range(50) - ] - } - - large_json_str = json.dumps(large_data) - response = f"```json\n{large_json_str}\n```" - - result = agent_extractor.parse_json(response) - - assert len(result["definitions"]) == 100 - assert len(result["relationships"]) == 50 - assert result["definitions"][0]["entity"] == "Entity 0" + result = agent_extractor.parse_jsonl(response) + + # Should get 2 valid objects, the truncated line is skipped + assert len(result) == 2 + assert result[0]["entity"] == "Complete1" + assert result[1]["entity"] == "Complete2" + + def test_parse_jsonl_large_responses(self, agent_extractor): + """Test JSONL parsing with very large responses""" + # Create a large JSONL response + lines = [] + for i in range(100): + lines.append(json.dumps({ + "type": "definition", + "entity": f"Entity {i}", + "definition": f"Definition {i} " + "with more content " * 100 + })) + for i in range(50): + lines.append(json.dumps({ + "type": "relationship", + "subject": f"Subject {i}", + "predicate": f"predicate_{i}", + "object": f"Object {i}", + "object-entity": i % 2 == 0 + })) + + response = f"```json\n{chr(10).join(lines)}\n```" + + result = agent_extractor.parse_jsonl(response) + + definitions = [r for r in result if r.get("type") == "definition"] + relationships = [r for r in result if r.get("type") == "relationship"] + + assert len(definitions) == 100 + assert len(relationships) == 50 + assert definitions[0]["entity"] == "Entity 0" def test_process_extraction_data_empty_metadata(self, agent_extractor): """Test processing with empty or minimal metadata""" # Test with None metadata - may not raise AttributeError depending on implementation try: - triples, contexts = agent_extractor.process_extraction_data( - {"definitions": [], "relationships": []}, - None - ) + triples, contexts = agent_extractor.process_extraction_data([], None) # If it doesn't raise, check the results assert len(triples) == 0 assert len(contexts) == 0 except (AttributeError, TypeError): # This is expected behavior when metadata is None pass - + # Test with metadata without ID metadata = 
Metadata(id=None, metadata=[]) - triples, contexts = agent_extractor.process_extraction_data( - {"definitions": [], "relationships": []}, - metadata - ) + triples, contexts = agent_extractor.process_extraction_data([], metadata) assert len(triples) == 0 assert len(contexts) == 0 - + # Test with metadata with empty string ID metadata = Metadata(id="", metadata=[]) - data = { - "definitions": [{"entity": "Test", "definition": "Test def"}], - "relationships": [] - } + data = [{"type": "definition", "entity": "Test", "definition": "Test def"}] triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should not create subject-of triples when ID is empty string - subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF] + subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF] assert len(subject_of_triples) == 0 def test_process_extraction_data_special_entity_names(self, agent_extractor): """Test processing with special characters in entity names""" metadata = Metadata(id="doc123", metadata=[]) - + special_entities = [ "Entity with spaces", "Entity & Co.", @@ -237,71 +207,62 @@ class TestAgentKgExtractionEdgeCases: "Quotes: \"test\"", "Parentheses: (test)", ] - - data = { - "definitions": [ - {"entity": entity, "definition": f"Definition for {entity}"} - for entity in special_entities - ], - "relationships": [] - } - + + data = [ + {"type": "definition", "entity": entity, "definition": f"Definition for {entity}"} + for entity in special_entities + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Verify all entities were processed assert len(contexts) == len(special_entities) - + # Verify URIs were properly encoded for i, entity in enumerate(special_entities): expected_uri = f"{TRUSTGRAPH_ENTITIES}{urllib.parse.quote(entity)}" - assert contexts[i].entity.value == expected_uri + assert contexts[i].entity.iri == expected_uri def test_process_extraction_data_very_long_definitions(self, agent_extractor): """Test processing with very long entity definitions""" metadata = Metadata(id="doc123", metadata=[]) - + # Create very long definition long_definition = "This is a very long definition. 
" * 1000 - - data = { - "definitions": [ - {"entity": "Test Entity", "definition": long_definition} - ], - "relationships": [] - } - + + data = [ + {"type": "definition", "entity": "Test Entity", "definition": long_definition} + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should handle long definitions without issues assert len(contexts) == 1 assert contexts[0].context == long_definition - + # Find definition triple - def_triple = next((t for t in triples if t.p.value == DEFINITION), None) + def_triple = next((t for t in triples if t.p.iri == DEFINITION), None) assert def_triple is not None assert def_triple.o.value == long_definition def test_process_extraction_data_duplicate_entities(self, agent_extractor): """Test processing with duplicate entity names""" metadata = Metadata(id="doc123", metadata=[]) - - data = { - "definitions": [ - {"entity": "Machine Learning", "definition": "First definition"}, - {"entity": "Machine Learning", "definition": "Second definition"}, # Duplicate - {"entity": "AI", "definition": "AI definition"}, - {"entity": "AI", "definition": "Another AI definition"}, # Duplicate - ], - "relationships": [] - } - + + data = [ + {"type": "definition", "entity": "Machine Learning", "definition": "First definition"}, + {"type": "definition", "entity": "Machine Learning", "definition": "Second definition"}, # Duplicate + {"type": "definition", "entity": "AI", "definition": "AI definition"}, + {"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should process all entries (including duplicates) assert len(contexts) == 4 - + # Check that both definitions for "Machine Learning" are present - ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.value] + ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.iri] assert len(ml_contexts) == 2 assert ml_contexts[0].context == "First definition" assert ml_contexts[1].context == "Second definition" @@ -309,49 +270,44 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_empty_strings(self, agent_extractor): """Test processing with empty strings in data""" metadata = Metadata(id="doc123", metadata=[]) - - data = { - "definitions": [ - {"entity": "", "definition": "Definition for empty entity"}, - {"entity": "Valid Entity", "definition": ""}, - {"entity": " ", "definition": " "}, # Whitespace only - ], - "relationships": [ - {"subject": "", "predicate": "test", "object": "test", "object-entity": True}, - {"subject": "test", "predicate": "", "object": "test", "object-entity": True}, - {"subject": "test", "predicate": "test", "object": "", "object-entity": True}, - ] - } - + + data = [ + {"type": "definition", "entity": "", "definition": "Definition for empty entity"}, + {"type": "definition", "entity": "Valid Entity", "definition": ""}, + {"type": "definition", "entity": " ", "definition": " "}, # Whitespace only + {"type": "relationship", "subject": "", "predicate": "test", "object": "test", "object-entity": True}, + {"type": "relationship", "subject": "test", "predicate": "", "object": "test", "object-entity": True}, + {"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True}, + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should handle empty strings by creating URIs (even if empty) assert len(contexts) == 3 - + # Empty entity 
should create empty URI after encoding - empty_entity_context = next((ec for ec in contexts if ec.entity.value == TRUSTGRAPH_ENTITIES), None) + empty_entity_context = next((ec for ec in contexts if ec.entity.iri == TRUSTGRAPH_ENTITIES), None) assert empty_entity_context is not None def test_process_extraction_data_nested_json_in_strings(self, agent_extractor): """Test processing when definitions contain JSON-like strings""" metadata = Metadata(id="doc123", metadata=[]) - - data = { - "definitions": [ - { - "entity": "JSON Entity", - "definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}' - }, - { - "entity": "Array Entity", - "definition": 'Contains array: [1, 2, 3, "string"]' - } - ], - "relationships": [] - } - + + data = [ + { + "type": "definition", + "entity": "JSON Entity", + "definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}' + }, + { + "type": "definition", + "entity": "Array Entity", + "definition": 'Contains array: [1, 2, 3, "string"]' + } + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should handle JSON strings in definitions without parsing them assert len(contexts) == 2 assert '{"key": "value"' in contexts[0].context @@ -360,32 +316,29 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor): """Test processing with various boolean values for object-entity""" metadata = Metadata(id="doc123", metadata=[]) - - data = { - "definitions": [], - "relationships": [ - # Explicit True - {"subject": "A", "predicate": "rel1", "object": "B", "object-entity": True}, - # Explicit False - {"subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False}, - # Missing object-entity (should default to True based on code) - {"subject": "A", "predicate": "rel3", "object": "C"}, - # String "true" (should be treated as truthy) - {"subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"}, - # String "false" (should be treated as truthy in Python) - {"subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"}, - # Number 0 (falsy) - {"subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0}, - # Number 1 (truthy) - {"subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1}, - ] - } - + + data = [ + # Explicit True + {"type": "relationship", "subject": "A", "predicate": "rel1", "object": "B", "object-entity": True}, + # Explicit False + {"type": "relationship", "subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False}, + # Missing object-entity (should default to True based on code) + {"type": "relationship", "subject": "A", "predicate": "rel3", "object": "C"}, + # String "true" (should be treated as truthy) + {"type": "relationship", "subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"}, + # String "false" (should be treated as truthy in Python) + {"type": "relationship", "subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"}, + # Number 0 (falsy) + {"type": "relationship", "subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0}, + # Number 1 (truthy) + {"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1}, + ] + triples, contexts = agent_extractor.process_extraction_data(data, metadata) - + # Should process all relationships # Note: The current implementation has some logic issues that these tests document - 
assert len([t for t in triples if t.p.value != RDF_LABEL and t.p.value != SUBJECT_OF]) >= 7 + assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7 @pytest.mark.asyncio async def test_emit_empty_collections(self, agent_extractor): @@ -437,41 +390,40 @@ class TestAgentKgExtractionEdgeCases: def test_process_extraction_data_performance_large_dataset(self, agent_extractor): """Test performance with large extraction datasets""" metadata = Metadata(id="large-doc", metadata=[]) - - # Create large dataset + + # Create large dataset in JSONL format num_definitions = 1000 num_relationships = 2000 - - large_data = { - "definitions": [ - { - "entity": f"Entity_{i:04d}", - "definition": f"Definition for entity {i} with some detailed explanation." - } - for i in range(num_definitions) - ], - "relationships": [ - { - "subject": f"Entity_{i % num_definitions:04d}", - "predicate": f"predicate_{i % 10}", - "object": f"Entity_{(i + 1) % num_definitions:04d}", - "object-entity": True - } - for i in range(num_relationships) - ] - } - + + large_data = [ + { + "type": "definition", + "entity": f"Entity_{i:04d}", + "definition": f"Definition for entity {i} with some detailed explanation." + } + for i in range(num_definitions) + ] + [ + { + "type": "relationship", + "subject": f"Entity_{i % num_definitions:04d}", + "predicate": f"predicate_{i % 10}", + "object": f"Entity_{(i + 1) % num_definitions:04d}", + "object-entity": True + } + for i in range(num_relationships) + ] + import time start_time = time.time() - + triples, contexts = agent_extractor.process_extraction_data(large_data, metadata) - + end_time = time.time() processing_time = end_time - start_time - + # Should complete within reasonable time (adjust threshold as needed) assert processing_time < 10.0 # 10 seconds threshold - + # Verify results assert len(contexts) == num_definitions # Triples include labels, definitions, relationships, and subject-of relations diff --git a/tests/unit/test_knowledge_graph/test_graph_validation.py b/tests/unit/test_knowledge_graph/test_graph_validation.py index fd6e12cf..e9e2750b 100644 --- a/tests/unit/test_knowledge_graph/test_graph_validation.py +++ b/tests/unit/test_knowledge_graph/test_graph_validation.py @@ -7,7 +7,7 @@ processing graph structures, and performing graph operations. 
import pytest from unittest.mock import Mock -from .conftest import Triple, Value, Metadata +from .conftest import Triple, Metadata from collections import defaultdict, deque diff --git a/tests/unit/test_knowledge_graph/test_object_validation.py b/tests/unit/test_knowledge_graph/test_object_validation.py index b2ac28aa..47d2e4d7 100644 --- a/tests/unit/test_knowledge_graph/test_object_validation.py +++ b/tests/unit/test_knowledge_graph/test_object_validation.py @@ -76,7 +76,7 @@ def cities_schema(): def validator(): """Create a mock processor with just the validation method""" from unittest.mock import MagicMock - from trustgraph.extract.kg.objects.processor import Processor + from trustgraph.extract.kg.rows.processor import Processor # Create a mock processor mock_processor = MagicMock() diff --git a/tests/unit/test_knowledge_graph/test_triple_construction.py b/tests/unit/test_knowledge_graph/test_triple_construction.py index b1cf1274..10bae2e7 100644 --- a/tests/unit/test_knowledge_graph/test_triple_construction.py +++ b/tests/unit/test_knowledge_graph/test_triple_construction.py @@ -2,13 +2,13 @@ Unit tests for triple construction logic Tests the core business logic for constructing RDF triples from extracted -entities and relationships, including URI generation, Value object creation, +entities and relationships, including URI generation, Term object creation, and triple validation. """ import pytest from unittest.mock import Mock -from .conftest import Triple, Triples, Value, Metadata +from .conftest import Triple, Triples, Term, Metadata, IRI, LITERAL import re import hashlib @@ -48,80 +48,82 @@ class TestTripleConstructionLogic: generated_uri = generate_uri(text, entity_type) assert generated_uri == expected_uri, f"URI generation failed for '{text}'" - def test_value_object_creation(self): - """Test creation of Value objects for subjects, predicates, and objects""" + def test_term_object_creation(self): + """Test creation of Term objects for subjects, predicates, and objects""" # Arrange - def create_value_object(text, is_uri, value_type=""): - return Value( - value=text, - is_uri=is_uri, - type=value_type - ) - + def create_term_object(text, is_uri, datatype=""): + if is_uri: + return Term(type=IRI, iri=text) + else: + return Term(type=LITERAL, value=text, datatype=datatype if datatype else None) + test_cases = [ ("http://trustgraph.ai/kg/person/john-smith", True, ""), ("John Smith", False, "string"), ("42", False, "integer"), ("http://schema.org/worksFor", True, "") ] - + # Act & Assert - for value_text, is_uri, value_type in test_cases: - value_obj = create_value_object(value_text, is_uri, value_type) - - assert isinstance(value_obj, Value) - assert value_obj.value == value_text - assert value_obj.is_uri == is_uri - assert value_obj.type == value_type + for value_text, is_uri, datatype in test_cases: + term_obj = create_term_object(value_text, is_uri, datatype) + + assert isinstance(term_obj, Term) + if is_uri: + assert term_obj.type == IRI + assert term_obj.iri == value_text + else: + assert term_obj.type == LITERAL + assert term_obj.value == value_text def test_triple_construction_from_relationship(self): """Test constructing Triple objects from relationships""" # Arrange relationship = { "subject": "John Smith", - "predicate": "works_for", + "predicate": "works_for", "object": "OpenAI", "subject_type": "PERSON", "object_type": "ORG" } - + def construct_triple(relationship, uri_base="http://trustgraph.ai/kg"): # Generate URIs subject_uri = 
f"{uri_base}/person/{relationship['subject'].lower().replace(' ', '-')}" object_uri = f"{uri_base}/org/{relationship['object'].lower().replace(' ', '-')}" - + # Map predicate to schema.org URI predicate_mappings = { "works_for": "http://schema.org/worksFor", "located_in": "http://schema.org/location", "developed": "http://schema.org/creator" } - predicate_uri = predicate_mappings.get(relationship["predicate"], + predicate_uri = predicate_mappings.get(relationship["predicate"], f"{uri_base}/predicate/{relationship['predicate']}") - - # Create Value objects - subject_value = Value(value=subject_uri, is_uri=True, type="") - predicate_value = Value(value=predicate_uri, is_uri=True, type="") - object_value = Value(value=object_uri, is_uri=True, type="") - + + # Create Term objects + subject_term = Term(type=IRI, iri=subject_uri) + predicate_term = Term(type=IRI, iri=predicate_uri) + object_term = Term(type=IRI, iri=object_uri) + # Create Triple return Triple( - s=subject_value, - p=predicate_value, - o=object_value + s=subject_term, + p=predicate_term, + o=object_term ) - + # Act triple = construct_triple(relationship) - + # Assert assert isinstance(triple, Triple) - assert triple.s.value == "http://trustgraph.ai/kg/person/john-smith" - assert triple.s.is_uri is True - assert triple.p.value == "http://schema.org/worksFor" - assert triple.p.is_uri is True - assert triple.o.value == "http://trustgraph.ai/kg/org/openai" - assert triple.o.is_uri is True + assert triple.s.iri == "http://trustgraph.ai/kg/person/john-smith" + assert triple.s.type == IRI + assert triple.p.iri == "http://schema.org/worksFor" + assert triple.p.type == IRI + assert triple.o.iri == "http://trustgraph.ai/kg/org/openai" + assert triple.o.type == IRI def test_literal_value_handling(self): """Test handling of literal values vs URI values""" @@ -132,10 +134,10 @@ class TestTripleConstructionLogic: ("John Smith", "email", "john@example.com", False), # Literal email ("John Smith", "worksFor", "http://trustgraph.ai/kg/org/openai", True) # URI reference ] - + def create_triple_with_literal(subject_uri, predicate, object_value, object_is_uri): - subject_val = Value(value=subject_uri, is_uri=True, type="") - + subject_term = Term(type=IRI, iri=subject_uri) + # Determine predicate URI predicate_mappings = { "name": "http://schema.org/name", @@ -144,32 +146,37 @@ class TestTripleConstructionLogic: "worksFor": "http://schema.org/worksFor" } predicate_uri = predicate_mappings.get(predicate, f"http://trustgraph.ai/kg/predicate/{predicate}") - predicate_val = Value(value=predicate_uri, is_uri=True, type="") - - # Create object value with appropriate type - object_type = "" - if not object_is_uri: + predicate_term = Term(type=IRI, iri=predicate_uri) + + # Create object term with appropriate type + if object_is_uri: + object_term = Term(type=IRI, iri=object_value) + else: + datatype = None if predicate == "age": - object_type = "integer" + datatype = "integer" elif predicate in ["name", "email"]: - object_type = "string" - - object_val = Value(value=object_value, is_uri=object_is_uri, type=object_type) - - return Triple(s=subject_val, p=predicate_val, o=object_val) - + datatype = "string" + object_term = Term(type=LITERAL, value=object_value, datatype=datatype) + + return Triple(s=subject_term, p=predicate_term, o=object_term) + # Act & Assert for subject_uri, predicate, object_value, object_is_uri in test_data: subject_full_uri = "http://trustgraph.ai/kg/person/john-smith" triple = create_triple_with_literal(subject_full_uri, predicate, 
object_value, object_is_uri) - - assert triple.o.is_uri == object_is_uri - assert triple.o.value == object_value - + + if object_is_uri: + assert triple.o.type == IRI + assert triple.o.iri == object_value + else: + assert triple.o.type == LITERAL + assert triple.o.value == object_value + if predicate == "age": - assert triple.o.type == "integer" + assert triple.o.datatype == "integer" elif predicate in ["name", "email"]: - assert triple.o.type == "string" + assert triple.o.datatype == "string" def test_namespace_management(self): """Test namespace prefix management and expansion""" @@ -216,63 +223,74 @@ class TestTripleConstructionLogic: def test_triple_validation(self): """Test triple validation rules""" # Arrange + def get_term_value(term): + """Extract value from a Term""" + if term.type == IRI: + return term.iri + else: + return term.value + def validate_triple(triple): errors = [] - + # Check required components - if not triple.s or not triple.s.value: + s_val = get_term_value(triple.s) if triple.s else None + p_val = get_term_value(triple.p) if triple.p else None + o_val = get_term_value(triple.o) if triple.o else None + + if not triple.s or not s_val: errors.append("Missing or empty subject") - - if not triple.p or not triple.p.value: + + if not triple.p or not p_val: errors.append("Missing or empty predicate") - - if not triple.o or not triple.o.value: + + if not triple.o or not o_val: errors.append("Missing or empty object") - + # Check URI validity for URI values uri_pattern = r'^https?://[^\s/$.?#].[^\s]*$' - - if triple.s.is_uri and not re.match(uri_pattern, triple.s.value): + + if triple.s.type == IRI and not re.match(uri_pattern, triple.s.iri or ""): errors.append("Invalid subject URI format") - - if triple.p.is_uri and not re.match(uri_pattern, triple.p.value): + + if triple.p.type == IRI and not re.match(uri_pattern, triple.p.iri or ""): errors.append("Invalid predicate URI format") - - if triple.o.is_uri and not re.match(uri_pattern, triple.o.value): + + if triple.o.type == IRI and not re.match(uri_pattern, triple.o.iri or ""): errors.append("Invalid object URI format") - + # Predicates should typically be URIs - if not triple.p.is_uri: + if triple.p.type != IRI: errors.append("Predicate should be a URI") - + return len(errors) == 0, errors - + # Test valid triple valid_triple = Triple( - s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""), - p=Value(value="http://schema.org/name", is_uri=True, type=""), - o=Value(value="John Smith", is_uri=False, type="string") + s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"), + p=Term(type=IRI, iri="http://schema.org/name"), + o=Term(type=LITERAL, value="John Smith", datatype="string") ) - + # Test invalid triples invalid_triples = [ - Triple(s=Value(value="", is_uri=True, type=""), - p=Value(value="http://schema.org/name", is_uri=True, type=""), - o=Value(value="John", is_uri=False, type="")), # Empty subject - - Triple(s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""), - p=Value(value="name", is_uri=False, type=""), # Non-URI predicate - o=Value(value="John", is_uri=False, type="")), - - Triple(s=Value(value="invalid-uri", is_uri=True, type=""), - p=Value(value="http://schema.org/name", is_uri=True, type=""), - o=Value(value="John", is_uri=False, type="")) # Invalid URI format + Triple(s=Term(type=IRI, iri=""), + p=Term(type=IRI, iri="http://schema.org/name"), + o=Term(type=LITERAL, value="John")), # Empty subject + + Triple(s=Term(type=IRI, 
iri="http://trustgraph.ai/kg/person/john"), + p=Term(type=LITERAL, value="name"), # Non-URI predicate + o=Term(type=LITERAL, value="John")), + + Triple(s=Term(type=IRI, iri="invalid-uri"), + p=Term(type=IRI, iri="http://schema.org/name"), + o=Term(type=LITERAL, value="John")) # Invalid URI format ] - + # Act & Assert is_valid, errors = validate_triple(valid_triple) assert is_valid, f"Valid triple failed validation: {errors}" - + for invalid_triple in invalid_triples: is_valid, errors = validate_triple(invalid_triple) assert not is_valid, f"Invalid triple passed validation: {invalid_triple}" @@ -286,97 +304,97 @@ class TestTripleConstructionLogic: {"text": "OpenAI", "type": "ORG"}, {"text": "San Francisco", "type": "PLACE"} ] - + relationships = [ {"subject": "John Smith", "predicate": "works_for", "object": "OpenAI"}, {"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"} ] - + def construct_triple_batch(entities, relationships, document_id="doc-1"): triples = [] - + # Create type triples for entities for entity in entities: entity_uri = f"http://trustgraph.ai/kg/{entity['type'].lower()}/{entity['text'].lower().replace(' ', '-')}" type_uri = f"http://trustgraph.ai/kg/type/{entity['type']}" - + type_triple = Triple( - s=Value(value=entity_uri, is_uri=True, type=""), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True, type=""), - o=Value(value=type_uri, is_uri=True, type="") + s=Term(type=IRI, iri=entity_uri), + p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), + o=Term(type=IRI, iri=type_uri) ) triples.append(type_triple) - + # Create relationship triples for rel in relationships: subject_uri = f"http://trustgraph.ai/kg/entity/{rel['subject'].lower().replace(' ', '-')}" object_uri = f"http://trustgraph.ai/kg/entity/{rel['object'].lower().replace(' ', '-')}" predicate_uri = f"http://schema.org/{rel['predicate'].replace('_', '')}" - + rel_triple = Triple( - s=Value(value=subject_uri, is_uri=True, type=""), - p=Value(value=predicate_uri, is_uri=True, type=""), - o=Value(value=object_uri, is_uri=True, type="") + s=Term(type=IRI, iri=subject_uri), + p=Term(type=IRI, iri=predicate_uri), + o=Term(type=IRI, iri=object_uri) ) triples.append(rel_triple) - + return triples - + # Act triples = construct_triple_batch(entities, relationships) - + # Assert assert len(triples) == len(entities) + len(relationships) # Type triples + relationship triples - + # Check that all triples are valid Triple objects for triple in triples: assert isinstance(triple, Triple) - assert triple.s.value != "" - assert triple.p.value != "" - assert triple.o.value != "" + assert triple.s.iri != "" + assert triple.p.iri != "" + assert triple.o.iri != "" def test_triples_batch_object_creation(self): """Test creating Triples batch objects with metadata""" # Arrange sample_triples = [ Triple( - s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""), - p=Value(value="http://schema.org/name", is_uri=True, type=""), - o=Value(value="John Smith", is_uri=False, type="string") + s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"), + p=Term(type=IRI, iri="http://schema.org/name"), + o=Term(type=LITERAL, value="John Smith", datatype="string") ), Triple( - s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""), - p=Value(value="http://schema.org/worksFor", is_uri=True, type=""), - o=Value(value="http://trustgraph.ai/kg/org/openai", is_uri=True, type="") + s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"), + 
p=Term(type=IRI, iri="http://schema.org/worksFor"), + o=Term(type=IRI, iri="http://trustgraph.ai/kg/org/openai") ) ] - + metadata = Metadata( id="test-doc-123", - user="test_user", + user="test_user", collection="test_collection", metadata=[] ) - + # Act triples_batch = Triples( metadata=metadata, triples=sample_triples ) - + # Assert assert isinstance(triples_batch, Triples) assert triples_batch.metadata.id == "test-doc-123" assert triples_batch.metadata.user == "test_user" assert triples_batch.metadata.collection == "test_collection" assert len(triples_batch.triples) == 2 - + # Check that triples are properly embedded for triple in triples_batch.triples: assert isinstance(triple, Triple) - assert isinstance(triple.s, Value) - assert isinstance(triple.p, Value) - assert isinstance(triple.o, Value) + assert isinstance(triple.s, Term) + assert isinstance(triple.p, Term) + assert isinstance(triple.o, Term) def test_uri_collision_handling(self): """Test handling of URI collisions and duplicate detection""" diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 026791d0..3e73ab9c 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -339,7 +339,250 @@ class TestPromptManager: """Test PromptManager with minimal configuration""" pm = PromptManager() pm.load_config({}) # Empty config - + assert pm.config.system_template == "Be helpful." # Default system assert pm.terms == {} # Default empty terms - assert len(pm.prompts) == 0 \ No newline at end of file + assert len(pm.prompts) == 0 + + +@pytest.mark.unit +class TestPromptManagerJsonl: + """Unit tests for PromptManager JSONL functionality""" + + @pytest.fixture + def jsonl_config(self): + """Configuration with JSONL response type prompts""" + return { + "system": json.dumps("You are an extraction assistant."), + "template-index": json.dumps(["extract_simple", "extract_with_schema", "extract_mixed"]), + "template.extract_simple": json.dumps({ + "prompt": "Extract entities from: {{ text }}", + "response-type": "jsonl" + }), + "template.extract_with_schema": json.dumps({ + "prompt": "Extract definitions from: {{ text }}", + "response-type": "jsonl", + "schema": { + "type": "object", + "properties": { + "entity": {"type": "string"}, + "definition": {"type": "string"} + }, + "required": ["entity", "definition"] + } + }), + "template.extract_mixed": json.dumps({ + "prompt": "Extract knowledge from: {{ text }}", + "response-type": "jsonl", + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": {"const": "definition"}, + "entity": {"type": "string"}, + "definition": {"type": "string"} + }, + "required": ["type", "entity", "definition"] + }, + { + "type": "object", + "properties": { + "type": {"const": "relationship"}, + "subject": {"type": "string"}, + "predicate": {"type": "string"}, + "object": {"type": "string"} + }, + "required": ["type", "subject", "predicate", "object"] + } + ] + } + }) + } + + @pytest.fixture + def prompt_manager(self, jsonl_config): + """Create a PromptManager with JSONL configuration""" + pm = PromptManager() + pm.load_config(jsonl_config) + return pm + + def test_parse_jsonl_basic(self, prompt_manager): + """Test basic JSONL parsing""" + text = '{"entity": "cat", "definition": "A small furry animal"}\n{"entity": "dog", "definition": "A loyal pet"}' + + result = prompt_manager.parse_jsonl(text) + + assert len(result) == 2 + assert result[0]["entity"] == "cat" + assert result[1]["entity"] == "dog" + + def 
test_parse_jsonl_with_empty_lines(self, prompt_manager): + """Test JSONL parsing skips empty lines""" + text = '{"entity": "cat"}\n\n\n{"entity": "dog"}\n' + + result = prompt_manager.parse_jsonl(text) + + assert len(result) == 2 + + def test_parse_jsonl_with_markdown_fences(self, prompt_manager): + """Test JSONL parsing strips markdown code fences""" + text = '''```json +{"entity": "cat", "definition": "A furry animal"} +{"entity": "dog", "definition": "A loyal pet"} +```''' + + result = prompt_manager.parse_jsonl(text) + + assert len(result) == 2 + assert result[0]["entity"] == "cat" + assert result[1]["entity"] == "dog" + + def test_parse_jsonl_with_jsonl_fence(self, prompt_manager): + """Test JSONL parsing strips jsonl-marked code fences""" + text = '''```jsonl +{"entity": "cat"} +{"entity": "dog"} +```''' + + result = prompt_manager.parse_jsonl(text) + + assert len(result) == 2 + + def test_parse_jsonl_truncation_resilience(self, prompt_manager): + """Test JSONL parsing handles truncated final line""" + text = '{"entity": "cat", "definition": "Complete"}\n{"entity": "dog", "defi' + + result = prompt_manager.parse_jsonl(text) + + # Should get the first valid object, skip the truncated one + assert len(result) == 1 + assert result[0]["entity"] == "cat" + + def test_parse_jsonl_invalid_lines_skipped(self, prompt_manager): + """Test JSONL parsing skips invalid JSON lines""" + text = '''{"entity": "valid1"} +not json at all +{"entity": "valid2"} +{broken json +{"entity": "valid3"}''' + + result = prompt_manager.parse_jsonl(text) + + assert len(result) == 3 + assert result[0]["entity"] == "valid1" + assert result[1]["entity"] == "valid2" + assert result[2]["entity"] == "valid3" + + def test_parse_jsonl_empty_input(self, prompt_manager): + """Test JSONL parsing with empty input""" + result = prompt_manager.parse_jsonl("") + assert result == [] + + result = prompt_manager.parse_jsonl("\n\n\n") + assert result == [] + + @pytest.mark.asyncio + async def test_invoke_jsonl_response(self, prompt_manager): + """Test invoking a prompt with JSONL response""" + mock_llm = AsyncMock() + mock_llm.return_value = '{"entity": "photosynthesis", "definition": "Plant process"}\n{"entity": "mitosis", "definition": "Cell division"}' + + result = await prompt_manager.invoke( + "extract_simple", + {"text": "Biology text"}, + mock_llm + ) + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["entity"] == "photosynthesis" + assert result[1]["entity"] == "mitosis" + + @pytest.mark.asyncio + async def test_invoke_jsonl_with_schema_validation(self, prompt_manager): + """Test JSONL response with schema validation""" + mock_llm = AsyncMock() + mock_llm.return_value = '{"entity": "cat", "definition": "A pet"}\n{"entity": "dog", "definition": "Another pet"}' + + result = await prompt_manager.invoke( + "extract_with_schema", + {"text": "Animal text"}, + mock_llm + ) + + assert len(result) == 2 + assert all("entity" in obj and "definition" in obj for obj in result) + + @pytest.mark.asyncio + async def test_invoke_jsonl_schema_filters_invalid(self, prompt_manager): + """Test JSONL schema validation filters out invalid objects""" + mock_llm = AsyncMock() + # Second object is missing required 'definition' field + mock_llm.return_value = '{"entity": "valid", "definition": "Has both fields"}\n{"entity": "invalid_missing_definition"}\n{"entity": "also_valid", "definition": "Complete"}' + + result = await prompt_manager.invoke( + "extract_with_schema", + {"text": "Test text"}, + mock_llm + ) + + # Only 
the two valid objects should be returned + assert len(result) == 2 + assert result[0]["entity"] == "valid" + assert result[1]["entity"] == "also_valid" + + @pytest.mark.asyncio + async def test_invoke_jsonl_mixed_types(self, prompt_manager): + """Test JSONL with discriminated union schema (oneOf)""" + mock_llm = AsyncMock() + mock_llm.return_value = '''{"type": "definition", "entity": "DNA", "definition": "Genetic material"} +{"type": "relationship", "subject": "DNA", "predicate": "found_in", "object": "nucleus"} +{"type": "definition", "entity": "RNA", "definition": "Messenger molecule"}''' + + result = await prompt_manager.invoke( + "extract_mixed", + {"text": "Biology text"}, + mock_llm + ) + + assert len(result) == 3 + + # Check definitions + definitions = [r for r in result if r.get("type") == "definition"] + assert len(definitions) == 2 + + # Check relationships + relationships = [r for r in result if r.get("type") == "relationship"] + assert len(relationships) == 1 + assert relationships[0]["subject"] == "DNA" + + @pytest.mark.asyncio + async def test_invoke_jsonl_empty_result(self, prompt_manager): + """Test JSONL response that yields no valid objects""" + mock_llm = AsyncMock() + mock_llm.return_value = "No JSON here at all" + + result = await prompt_manager.invoke( + "extract_simple", + {"text": "Test"}, + mock_llm + ) + + assert result == [] + + @pytest.mark.asyncio + async def test_invoke_jsonl_without_schema(self, prompt_manager): + """Test JSONL response without schema validation""" + mock_llm = AsyncMock() + mock_llm.return_value = '{"any": "structure"}\n{"completely": "different"}' + + result = await prompt_manager.invoke( + "extract_simple", + {"text": "Test"}, + mock_llm + ) + + assert len(result) == 2 + assert result[0] == {"any": "structure"} + assert result[1] == {"completely": "different"} \ No newline at end of file diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py index f86ae3da..80443a0c 100644 --- a/tests/unit/test_python_api_client.py +++ b/tests/unit/test_python_api_client.py @@ -167,7 +167,7 @@ class TestFlowClient: expected_methods = [ 'text_completion', 'agent', 'graph_rag', 'document_rag', 'graph_embeddings_query', 'embeddings', 'prompt', - 'triples_query', 'objects_query' + 'triples_query', 'rows_query' ] for method in expected_methods: @@ -216,7 +216,7 @@ class TestSocketClient: expected_methods = [ 'agent', 'text_completion', 'graph_rag', 'document_rag', 'prompt', 'graph_embeddings_query', 'embeddings', - 'triples_query', 'objects_query', 'mcp_tool' + 'triples_query', 'rows_query', 'mcp_tool' ] for method in expected_methods: @@ -243,7 +243,7 @@ class TestBulkClient: 'import_graph_embeddings', 'import_document_embeddings', 'import_entity_contexts', - 'import_objects' + 'import_rows' ] for method in import_methods: diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index ebacfaaf..21b6e1bf 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.graph_embeddings.milvus.service import Processor -from trustgraph.schema import Value, GraphEmbeddingsRequest +from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL class TestMilvusGraphEmbeddingsQueryProcessor: @@ -68,50 +68,50 @@ class TestMilvusGraphEmbeddingsQueryProcessor: def 
test_create_value_with_http_uri(self, processor): """Test create_value with HTTP URI""" result = processor.create_value("http://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI def test_create_value_with_https_uri(self, processor): """Test create_value with HTTPS URI""" result = processor.create_value("https://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "https://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "https://example.com/resource" + assert result.type == IRI def test_create_value_with_literal(self, processor): """Test create_value with literal value""" result = processor.create_value("just a literal string") - assert isinstance(result, Value) + assert isinstance(result, Term) assert result.value == "just a literal string" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_empty_string(self, processor): """Test create_value with empty string""" result = processor.create_value("") - assert isinstance(result, Value) + assert isinstance(result, Term) assert result.value == "" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_partial_uri(self, processor): """Test create_value with string that looks like URI but isn't complete""" result = processor.create_value("http") - assert isinstance(result, Value) + assert isinstance(result, Term) assert result.value == "http" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_ftp_uri(self, processor): """Test create_value with FTP URI (should not be detected as URI)""" result = processor.create_value("ftp://example.com/file") - assert isinstance(result, Value) + assert isinstance(result, Term) assert result.value == "ftp://example.com/file" - assert result.is_uri is False + assert result.type == LITERAL @pytest.mark.asyncio async def test_query_graph_embeddings_single_vector(self, processor): @@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 ) - # Verify results are converted to Value objects + # Verify results are converted to Term objects assert len(result) == 3 - assert isinstance(result[0], Value) - assert result[0].value == "http://example.com/entity1" - assert result[0].is_uri is True - assert isinstance(result[1], Value) - assert result[1].value == "http://example.com/entity2" - assert result[1].is_uri is True - assert isinstance(result[2], Value) + assert isinstance(result[0], Term) + assert result[0].iri == "http://example.com/entity1" + assert result[0].type == IRI + assert isinstance(result[1], Term) + assert result[1].iri == "http://example.com/entity2" + assert result[1].type == IRI + assert isinstance(result[2], Term) assert result[2].value == "literal entity" - assert result[2].is_uri is False + assert result[2].type == LITERAL @pytest.mark.asyncio async def test_query_graph_embeddings_multiple_vectors(self, processor): @@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify results are deduplicated and limited assert len(result) == 3 - entity_values = [r.value for r in result] + entity_values = [r.iri if r.type == IRI else r.value for r in result] assert "http://example.com/entity1" in entity_values 
assert "http://example.com/entity2" in entity_values assert "http://example.com/entity3" in entity_values
@@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify duplicates are removed assert len(result) == 3 - entity_values = [r.value for r in result] + entity_values = [r.iri if r.type == IRI else r.value for r in result] assert len(set(entity_values)) == 3 # All unique assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values
@@ -346,14 +346,14 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 4 # Check URI entities - uri_results = [r for r in result if r.is_uri] + uri_results = [r for r in result if r.type == IRI] assert len(uri_results) == 2 - uri_values = [r.value for r in uri_results] + uri_values = [r.iri for r in uri_results] assert "http://example.com/uri_entity" in uri_values assert "https://example.com/another_uri" in uri_values # Check literal entities - literal_results = [r for r in result if not r.is_uri] + literal_results = [r for r in result if r.type != IRI] assert len(literal_results) == 2 literal_values = [r.value for r in literal_results] assert "literal entity text" in literal_values
@@ -486,7 +486,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify results from all dimensions assert len(result) == 3 - entity_values = [r.value for r in result] + entity_values = [r.iri if r.type == IRI else r.value for r in result] assert "entity_2d" in entity_values assert "entity_4d" in entity_values assert "entity_3d" in entity_values \ No newline at end of file
diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 0c13e9c9..1b243113 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) from trustgraph.query.graph_embeddings.pinecone.service import Processor -from trustgraph.schema import Value +from trustgraph.schema import Term, IRI, LITERAL class TestPineconeGraphEmbeddingsQueryProcessor:
@@ -105,27 +105,27 @@ class TestPineconeGraphEmbeddingsQueryProcessor: uri_entity = "http://example.org/entity" value = processor.create_value(uri_entity) - assert isinstance(value, Value) - assert value.value == uri_entity - assert value.is_uri == True + assert isinstance(value, Term) + assert value.iri == uri_entity + assert value.type == IRI def test_create_value_https_uri(self, processor): """Test create_value method for HTTPS URI entities""" uri_entity = "https://example.org/entity" value = processor.create_value(uri_entity) - assert isinstance(value, Value) - assert value.value == uri_entity - assert value.is_uri == True + assert isinstance(value, Term) + assert value.iri == uri_entity + assert value.type == IRI def test_create_value_literal(self, processor): """Test create_value method for literal entities""" literal_entity = "literal_entity" value = processor.create_value(literal_entity) - assert isinstance(value, Value) + assert isinstance(value, Term) assert value.value == literal_entity - assert value.is_uri == False + assert value.type == LITERAL @pytest.mark.asyncio async def test_query_graph_embeddings_single_vector(self, processor):
@@ -165,11 +165,11 @@ class TestPineconeGraphEmbeddingsQueryProcessor: # Verify results assert len(entities) == 3 - assert entities[0].value == 'http://example.org/entity1' - assert entities[0].is_uri == True + assert entities[0].iri == 'http://example.org/entity1' + assert entities[0].type == IRI assert entities[1].value == 'entity2' - assert entities[1].is_uri == False + assert entities[1].type == LITERAL - assert entities[2].value == 'http://example.org/entity3' - assert entities[2].is_uri == True + assert entities[2].iri == 'http://example.org/entity3' + assert entities[2].type == IRI @pytest.mark.asyncio async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index ab22c9df..1760c4c1 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py
@@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.graph_embeddings.qdrant.service import Processor +from trustgraph.schema import IRI, LITERAL class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@@ -85,10 +86,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): value = processor.create_value('http://example.com/entity') # Assert - assert hasattr(value, 'value') - assert value.value == 'http://example.com/entity' - assert hasattr(value, 'is_uri') - assert value.is_uri == True + assert hasattr(value, 'iri') + assert value.iri == 'http://example.com/entity' + assert hasattr(value, 'type') + assert value.type == IRI @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -109,10 +110,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): value = processor.create_value('https://secure.example.com/entity') # Assert - assert hasattr(value, 'value') - assert value.value == 'https://secure.example.com/entity' - assert hasattr(value, 'is_uri') - assert value.is_uri == True + assert hasattr(value, 'iri') + assert value.iri == 'https://secure.example.com/entity' + assert hasattr(value, 'type') + assert value.type == IRI @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -135,8 +136,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert hasattr(value, 'value') assert value.value == 'regular entity name' - assert hasattr(value, 'is_uri') - assert value.is_uri == False + assert hasattr(value, 'type') + assert value.type == LITERAL @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -428,14 +429,14 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): assert len(result) == 3 # Check URI entities - uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri] + uri_entities = [entity for entity in result if entity.type == IRI] assert len(uri_entities) == 2 - uri_values = [entity.value for entity in uri_entities] + uri_values = [entity.iri for entity in uri_entities] assert 'http://example.com/entity1' in uri_values assert 'https://secure.example.com/entity2' in uri_values - + # Check regular entities - regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri] + regular_entities = [entity for entity in result if entity.type == LITERAL] assert len(regular_entities) == 1 assert regular_entities[0].value == 'regular entity'
diff --git a/tests/unit/test_query/test_memgraph_user_collection_query.py
b/tests/unit/test_query/test_memgraph_user_collection_query.py index 772d4f84..038fb438 100644 --- a/tests/unit/test_query/test_memgraph_user_collection_query.py +++ b/tests/unit/test_query/test_memgraph_user_collection_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.triples.memgraph.service import Processor -from trustgraph.schema import TriplesQueryRequest, Value +from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL class TestMemgraphQueryUserCollectionIsolation: @@ -24,9 +24,9 @@ class TestMemgraphQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="test_object", is_uri=False), + s=Term(type=IRI, iri="http://example.com/s"), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="test_object"), limit=1000 ) @@ -65,8 +65,8 @@ class TestMemgraphQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), + p=Term(type=IRI, iri="http://example.com/p"), o=None, limit=1000 ) @@ -105,9 +105,9 @@ class TestMemgraphQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, - o=Value(value="http://example.com/o", is_uri=True), + o=Term(type=IRI, iri="http://example.com/o"), limit=1000 ) @@ -145,7 +145,7 @@ class TestMemgraphQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, o=None, limit=1000 @@ -185,8 +185,8 @@ class TestMemgraphQueryUserCollectionIsolation: user="test_user", collection="test_collection", s=None, - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="literal", is_uri=False), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="literal"), limit=1000 ) @@ -225,7 +225,7 @@ class TestMemgraphQueryUserCollectionIsolation: user="test_user", collection="test_collection", s=None, - p=Value(value="http://example.com/p", is_uri=True), + p=Term(type=IRI, iri="http://example.com/p"), o=None, limit=1000 ) @@ -265,7 +265,7 @@ class TestMemgraphQueryUserCollectionIsolation: collection="test_collection", s=None, p=None, - o=Value(value="test_value", is_uri=False), + o=Term(type=LITERAL, value="test_value"), limit=1000 ) @@ -355,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation: # Query without user/collection fields query = TriplesQueryRequest( - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, o=None, limit=1000 @@ -385,7 +385,7 @@ class TestMemgraphQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, o=None, limit=1000 @@ -416,17 +416,17 @@ class TestMemgraphQueryUserCollectionIsolation: assert len(result) == 2 # First triple (literal object) - assert result[0].s.value == "http://example.com/s" - assert result[0].s.is_uri == True - assert 
result[0].p.value == "http://example.com/p1" - assert result[0].p.is_uri == True + assert result[0].s.iri == "http://example.com/s" + assert result[0].s.type == IRI + assert result[0].p.iri == "http://example.com/p1" + assert result[0].p.type == IRI assert result[0].o.value == "literal_value" - assert result[0].o.is_uri == False - + assert result[0].o.type == LITERAL + # Second triple (URI object) - assert result[1].s.value == "http://example.com/s" - assert result[1].s.is_uri == True - assert result[1].p.value == "http://example.com/p2" - assert result[1].p.is_uri == True - assert result[1].o.value == "http://example.com/o" - assert result[1].o.is_uri == True \ No newline at end of file + assert result[1].s.iri == "http://example.com/s" + assert result[1].s.type == IRI + assert result[1].p.iri == "http://example.com/p2" + assert result[1].p.type == IRI + assert result[1].o.iri == "http://example.com/o" + assert result[1].o.type == IRI \ No newline at end of file diff --git a/tests/unit/test_query/test_neo4j_user_collection_query.py b/tests/unit/test_query/test_neo4j_user_collection_query.py index bf23680c..d9cf1eb4 100644 --- a/tests/unit/test_query/test_neo4j_user_collection_query.py +++ b/tests/unit/test_query/test_neo4j_user_collection_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.triples.neo4j.service import Processor -from trustgraph.schema import TriplesQueryRequest, Value +from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL class TestNeo4jQueryUserCollectionIsolation: @@ -24,21 +24,23 @@ class TestNeo4jQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="test_object", is_uri=False) + s=Term(type=IRI, iri="http://example.com/s"), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="test_object"), + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify SPO query for literal includes user/collection expected_query = ( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN $src as src" + "RETURN $src as src " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -63,23 +65,25 @@ class TestNeo4jQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), - o=None + s=Term(type=IRI, iri="http://example.com/s"), + p=Term(type=IRI, iri="http://example.com/p"), + o=None, + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify SP query for literals includes user/collection expected_literal_query = ( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN dest.value as dest" + "RETURN dest.value as dest " + "LIMIT 10" ) - + mock_driver.execute_query.assert_any_call( expected_literal_query, src="http://example.com/s", @@ -88,13 +92,14 @@ class TestNeo4jQueryUserCollectionIsolation: 
collection="test_collection", database_='neo4j' ) - + # Verify SP query for nodes includes user/collection expected_node_query = ( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Node {user: $user, collection: $collection}) " - "RETURN dest.uri as dest" + "RETURN dest.uri as dest " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -118,21 +123,23 @@ class TestNeo4jQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, - o=Value(value="http://example.com/o", is_uri=True) + o=Term(type=IRI, iri="http://example.com/o"), + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify SO query for nodes includes user/collection expected_query = ( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Node {uri: $uri, user: $user, collection: $collection}) " - "RETURN rel.uri as rel" + "RETURN rel.uri as rel " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -156,23 +163,25 @@ class TestNeo4jQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, - o=None + o=None, + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify S query includes user/collection expected_query = ( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN rel.uri as rel, dest.value as dest" + "RETURN rel.uri as rel, dest.value as dest " + "LIMIT 10" ) - + mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", @@ -194,20 +203,22 @@ class TestNeo4jQueryUserCollectionIsolation: user="test_user", collection="test_collection", s=None, - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="literal", is_uri=False) + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="literal"), + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify PO query for literals includes user/collection expected_query = ( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN src.uri as src" + "RETURN src.uri as src " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -232,20 +243,22 @@ class TestNeo4jQueryUserCollectionIsolation: user="test_user", collection="test_collection", s=None, - p=Value(value="http://example.com/p", is_uri=True), - o=None + p=Term(type=IRI, iri="http://example.com/p"), + o=None, + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify P query includes user/collection expected_query = ( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Literal {user: $user, 
collection: $collection}) " - "RETURN src.uri as src, dest.value as dest" + "RETURN src.uri as src, dest.value as dest " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -270,19 +283,21 @@ class TestNeo4jQueryUserCollectionIsolation: collection="test_collection", s=None, p=None, - o=Value(value="test_value", is_uri=False) + o=Term(type=LITERAL, value="test_value"), + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify O query for literals includes user/collection expected_query = ( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel" + "RETURN src.uri as src, rel.uri as rel " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -307,34 +322,37 @@ class TestNeo4jQueryUserCollectionIsolation: collection="test_collection", s=None, p=None, - o=None + o=None, + limit=10 ) - + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - + await processor.query_triples(query) - + # Verify wildcard query for literals includes user/collection expected_literal_query = ( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel, dest.value as dest" + "RETURN src.uri as src, rel.uri as rel, dest.value as dest " + "LIMIT 10" ) - + mock_driver.execute_query.assert_any_call( expected_literal_query, user="test_user", collection="test_collection", database_='neo4j' ) - + # Verify wildcard query for nodes includes user/collection expected_node_query = ( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Node {user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel, dest.uri as dest" + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " + "LIMIT 10" ) mock_driver.execute_query.assert_any_call( @@ -355,9 +373,10 @@ class TestNeo4jQueryUserCollectionIsolation: # Query without user/collection fields query = TriplesQueryRequest( - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, - o=None + o=None, + limit=10 ) mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) @@ -384,47 +403,48 @@ class TestNeo4jQueryUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/s", is_uri=True), + s=Term(type=IRI, iri="http://example.com/s"), p=None, - o=None + o=None, + limit=10 ) - + # Mock some results mock_record1 = MagicMock() mock_record1.data.return_value = { "rel": "http://example.com/p1", "dest": "literal_value" } - + mock_record2 = MagicMock() mock_record2.data.return_value = { "rel": "http://example.com/p2", "dest": "http://example.com/o" } - + # Return results for literal query, empty for node query mock_driver.execute_query.side_effect = [ ([mock_record1], MagicMock(), MagicMock()), # Literal query ([mock_record2], MagicMock(), MagicMock()) # Node query ] - + result = await processor.query_triples(query) # Verify results are proper Triple objects assert len(result) == 2 # First triple (literal object) - assert result[0].s.value == "http://example.com/s" - assert result[0].s.is_uri == True - assert result[0].p.value == 
"http://example.com/p1" - assert result[0].p.is_uri == True + assert result[0].s.iri == "http://example.com/s" + assert result[0].s.type == IRI + assert result[0].p.iri == "http://example.com/p1" + assert result[0].p.type == IRI assert result[0].o.value == "literal_value" - assert result[0].o.is_uri == False - + assert result[0].o.type == LITERAL + # Second triple (URI object) - assert result[1].s.value == "http://example.com/s" - assert result[1].s.is_uri == True - assert result[1].p.value == "http://example.com/p2" - assert result[1].p.is_uri == True - assert result[1].o.value == "http://example.com/o" - assert result[1].o.is_uri == True \ No newline at end of file + assert result[1].s.iri == "http://example.com/s" + assert result[1].s.type == IRI + assert result[1].p.iri == "http://example.com/p2" + assert result[1].p.type == IRI + assert result[1].o.iri == "http://example.com/o" + assert result[1].o.type == IRI \ No newline at end of file diff --git a/tests/unit/test_query/test_objects_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py similarity index 52% rename from tests/unit/test_query/test_objects_cassandra_query.py rename to tests/unit/test_query/test_rows_cassandra_query.py index ab11d5a1..879a81c5 100644 --- a/tests/unit/test_query/test_objects_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -1,10 +1,11 @@ """ -Unit tests for Cassandra Objects GraphQL Query Processor +Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation) Tests the business logic of the GraphQL query processor including: -- GraphQL schema generation from RowSchema -- Query execution and validation -- CQL translation logic +- Schema configuration handling +- Query execution using unified rows table +- Name sanitization +- GraphQL query execution - Message processing logic """ @@ -12,119 +13,91 @@ import pytest from unittest.mock import MagicMock, AsyncMock, patch import json -import strawberry -from strawberry import Schema - -from trustgraph.query.objects.cassandra.service import Processor -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError +from trustgraph.query.rows.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from trustgraph.schema import RowSchema, Field -class TestObjectsGraphQLQueryLogic: - """Test business logic without external dependencies""" - - def test_get_python_type_mapping(self): - """Test schema field type conversion to Python types""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - - # Basic type mappings - assert processor.get_python_type("string") == str - assert processor.get_python_type("integer") == int - assert processor.get_python_type("float") == float - assert processor.get_python_type("boolean") == bool - assert processor.get_python_type("timestamp") == str - assert processor.get_python_type("date") == str - assert processor.get_python_type("time") == str - assert processor.get_python_type("uuid") == str - - # Unknown type defaults to str - assert processor.get_python_type("unknown_type") == str - - def test_create_graphql_type_basic_fields(self): - """Test GraphQL type creation for basic field types""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create test schema - 
schema = RowSchema( - name="test_table", - description="Test table", - fields=[ - Field( - name="id", - type="string", - primary=True, - required=True, - description="Primary key" - ), - Field( - name="name", - type="string", - required=True, - description="Name field" - ), - Field( - name="age", - type="integer", - required=False, - description="Optional age" - ), - Field( - name="active", - type="boolean", - required=False, - description="Status flag" - ) - ] - ) - - # Create GraphQL type - graphql_type = processor.create_graphql_type("test_table", schema) - - # Verify type was created - assert graphql_type is not None - assert hasattr(graphql_type, '__name__') - assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower() +class TestRowsGraphQLQueryLogic: + """Test business logic for unified table query implementation""" def test_sanitize_name_cassandra_compatibility(self): """Test name sanitization for Cassandra field names""" processor = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - - # Test field name sanitization (matches storage processor) + + # Test field name sanitization (uses r_ prefix like storage processor) assert processor.sanitize_name("simple_field") == "simple_field" assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes" assert processor.sanitize_name("field.with.dots") == "field_with_dots" - assert processor.sanitize_name("123_field") == "o_123_field" + assert processor.sanitize_name("123_field") == "r_123_field" assert processor.sanitize_name("field with spaces") == "field_with_spaces" assert processor.sanitize_name("special!@#chars") == "special___chars" assert processor.sanitize_name("UPPERCASE") == "uppercase" assert processor.sanitize_name("CamelCase") == "camelcase" - def test_sanitize_table_name(self): - """Test table name sanitization (always gets o_ prefix)""" + def test_get_index_names(self): + """Test extraction of index names from schema""" processor = MagicMock() - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - - # Table names always get o_ prefix - assert processor.sanitize_table("simple_table") == "o_simple_table" - assert processor.sanitize_table("Table-Name") == "o_table_name" - assert processor.sanitize_table("123table") == "o_123table" - assert processor.sanitize_table("") == "o_" + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + schema = RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string"), # Not indexed + Field(name="status", type="string", indexed=True) + ] + ) + + index_names = processor.get_index_names(schema) + + assert "id" in index_names + assert "category" in index_names + assert "status" in index_names + assert "name" not in index_names + assert len(index_names) == 3 + + def test_find_matching_index_exact_match(self): + """Test finding matching index for exact match query""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + + schema = RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string") # Not indexed + ] + ) + + # Filter on indexed field should return match 
+ filters = {"category": "electronics"} + result = processor.find_matching_index(schema, filters) + assert result is not None + assert result[0] == "category" + assert result[1] == ["electronics"] + + # Filter on non-indexed field should return None + filters = {"name": "test"} + result = processor.find_matching_index(schema, filters) + assert result is None @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" processor = MagicMock() processor.schemas = {} - processor.graphql_types = {} - processor.graphql_schema = None - processor.config_key = "schema" # Set the config key - processor.generate_graphql_schema = AsyncMock() + processor.config_key = "schema" + processor.schema_builder = MagicMock() + processor.schema_builder.clear = MagicMock() + processor.schema_builder.add_schema = MagicMock() + processor.schema_builder.build = MagicMock(return_value=MagicMock()) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - + # Create test config schema_config = { "schema": { @@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic: }) } } - + # Process config await processor.on_schema_config(schema_config, version=1) - + # Verify schema was loaded assert "customer" in processor.schemas schema = processor.schemas["customer"] assert schema.name == "customer" assert len(schema.fields) == 3 - + # Verify fields id_field = next(f for f in schema.fields if f.name == "id") assert id_field.primary is True - # The field should have been created correctly from JSON - # Let's test what we can verify - that the field has the right attributes - assert hasattr(id_field, 'required') # Has the required attribute - assert hasattr(id_field, 'primary') # Has the primary attribute - + email_field = next(f for f in schema.fields if f.name == "email") assert email_field.indexed is True - + status_field = next(f for f in schema.fields if f.name == "status") assert status_field.enum_values == ["active", "inactive"] - - # Verify GraphQL schema regeneration was called - processor.generate_graphql_schema.assert_called_once() - def test_cql_query_building_basic(self): - """Test basic CQL query construction""" - processor = MagicMock() - processor.session = MagicMock() - processor.connect_cassandra = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor) - processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) - - # Mock session execute to capture the query - mock_result = [] - processor.session.execute.return_value = mock_result - - # Create test schema - schema = RowSchema( - name="test_table", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string", indexed=True), - Field(name="status", type="string") - ] - ) - - # Test query building - asyncio = pytest.importorskip("asyncio") - - async def run_test(): - await processor.query_cassandra( - user="test_user", - collection="test_collection", - schema_name="test_table", - row_schema=schema, - filters={"name": "John", "invalid_filter": "ignored"}, - limit=10 - ) - - # Run the async test - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(run_test()) - finally: - loop.close() - - # Verify Cassandra connection and query execution - 
processor.connect_cassandra.assert_called_once() - processor.session.execute.assert_called_once() - - # Verify the query structure (can't easily test exact query without complex mocking) - call_args = processor.session.execute.call_args - query = call_args[0][0] # First positional argument is the query - params = call_args[0][1] # Second positional argument is parameters - - # Basic query structure checks - assert "SELECT * FROM test_user.o_test_table" in query - assert "WHERE" in query - assert "collection = %s" in query - assert "LIMIT 10" in query - - # Parameters should include collection and name filter - assert "test_collection" in params - assert "John" in params + # Verify schema builder was called + processor.schema_builder.add_schema.assert_called_once() + processor.schema_builder.build.assert_called_once() @pytest.mark.asyncio async def test_graphql_context_handling(self): @@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.graphql_schema = AsyncMock() processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) - + # Mock schema execution mock_result = MagicMock() mock_result.data = {"customers": [{"id": "1", "name": "Test"}]} mock_result.errors = None processor.graphql_schema.execute.return_value = mock_result - + result = await processor.execute_graphql_query( query='{ customers { id name } }', variables={}, @@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic: user="test_user", collection="test_collection" ) - + # Verify schema.execute was called with correct context processor.graphql_schema.execute.assert_called_once() call_args = processor.graphql_schema.execute.call_args - + # Verify context was passed - context = call_args[1]['context_value'] # keyword argument + context = call_args[1]['context_value'] assert context["processor"] == processor assert context["user"] == "test_user" assert context["collection"] == "test_collection" - + # Verify result structure assert "data" in result assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]} @@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.graphql_schema = AsyncMock() processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) - - # Create a simple object to simulate GraphQL error instead of MagicMock + + # Create a simple object to simulate GraphQL error class MockError: def __init__(self, message, path, extensions): self.message = message self.path = path self.extensions = extensions - + def __str__(self): return self.message - + mock_error = MockError( message="Field 'invalid_field' doesn't exist", path=["customers", "0", "invalid_field"], extensions={"code": "FIELD_NOT_FOUND"} ) - + mock_result = MagicMock() mock_result.data = None mock_result.errors = [mock_error] processor.graphql_schema.execute.return_value = mock_result - + result = await processor.execute_graphql_query( query='{ customers { invalid_field } }', variables={}, operation_name=None, - user="test_user", + user="test_user", collection="test_collection" ) - + # Verify error handling assert "errors" in result assert len(result["errors"]) == 1 - + error = result["errors"][0] assert error["message"] == "Field 'invalid_field' doesn't exist" - assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path + assert error["path"] == ["customers", "0", "invalid_field"] assert error["extensions"] == {"code": "FIELD_NOT_FOUND"} - def 
test_schema_generation_basic_structure(self): - """Test basic GraphQL schema generation structure""" - processor = MagicMock() - processor.schemas = { - "customer": RowSchema( - name="customer", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string") - ] - ) - } - processor.graphql_types = {} - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Test individual type creation (avoiding the full schema generation which has annotation issues) - graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"]) - processor.graphql_types["customer"] = graphql_type - - # Verify type was created - assert len(processor.graphql_types) == 1 - assert "customer" in processor.graphql_types - assert processor.graphql_types["customer"] is not None - @pytest.mark.asyncio async def test_message_processing_success(self): """Test successful message processing flow""" processor = MagicMock() processor.execute_graphql_query = AsyncMock() processor.on_message = Processor.on_message.__get__(processor, Processor) - + # Mock successful query result processor.execute_graphql_query.return_value = { "data": {"customers": [{"id": "1", "name": "John"}]}, "errors": [], - "extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String()) + "extensions": {} } - + # Create mock message mock_msg = MagicMock() - mock_request = ObjectsQueryRequest( + mock_request = RowsQueryRequest( user="test_user", - collection="test_collection", + collection="test_collection", query='{ customers { id name } }', variables={}, operation_name=None ) mock_msg.value.return_value = mock_request mock_msg.properties.return_value = {"id": "test-123"} - + # Mock flow mock_flow = MagicMock() mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow - + # Process message await processor.on_message(mock_msg, None, mock_flow) - + # Verify query was executed processor.execute_graphql_query.assert_called_once_with( query='{ customers { id name } }', @@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic: user="test_user", collection="test_collection" ) - + # Verify response was sent mock_response_flow.send.assert_called_once() response_call = mock_response_flow.send.call_args[0][0] - + # Verify response structure - assert isinstance(response_call, ObjectsQueryResponse) + assert isinstance(response_call, RowsQueryResponse) assert response_call.error is None assert '"customers"' in response_call.data # JSON encoded assert len(response_call.errors) == 0 @@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.execute_graphql_query = AsyncMock() processor.on_message = Processor.on_message.__get__(processor, Processor) - + # Mock query execution error processor.execute_graphql_query.side_effect = RuntimeError("No schema available") - + # Create mock message mock_msg = MagicMock() - mock_request = ObjectsQueryRequest( + mock_request = RowsQueryRequest( user="test_user", collection="test_collection", query='{ invalid_query }', @@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic: ) mock_msg.value.return_value = mock_request mock_msg.properties.return_value = {"id": "test-456"} - + # Mock flow mock_flow = MagicMock() mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow - + # Process message await processor.on_message(mock_msg, None, mock_flow) - + # Verify error response 
was sent mock_response_flow.send.assert_called_once() response_call = mock_response_flow.send.call_args[0][0] - + # Verify error response structure - assert isinstance(response_call, ObjectsQueryResponse) + assert isinstance(response_call, RowsQueryResponse) assert response_call.error is not None - assert response_call.error.type == "objects-query-error" + assert response_call.error.type == "rows-query-error" assert "No schema available" in response_call.error.message assert response_call.data is None -class TestCQLQueryGeneration: - """Test CQL query generation logic in isolation""" - - def test_partition_key_inclusion(self): - """Test that collection is always included in queries""" +class TestUnifiedTableQueries: + """Test queries against the unified rows table""" + + @pytest.mark.asyncio + async def test_query_with_index_match(self): + """Test query execution with matching index""" processor = MagicMock() + processor.session = MagicMock() + processor.connect_cassandra = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - - # Mock the query building (simplified version) - keyspace = processor.sanitize_name("test_user") - table = processor.sanitize_table("test_table") - - query = f"SELECT * FROM {keyspace}.{table}" - where_clauses = ["collection = %s"] - - assert "collection = %s" in where_clauses - assert keyspace == "test_user" - assert table == "o_test_table" - + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) + + # Mock session execute to return test data + mock_row = MagicMock() + mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"} + processor.session.execute.return_value = [mock_row] + + schema = RowSchema( + name="products", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string") + ] + ) + + # Query with filter on indexed field + results = await processor.query_cassandra( + user="test_user", + collection="test_collection", + schema_name="products", + row_schema=schema, + filters={"category": "electronics"}, + limit=10 + ) + + # Verify Cassandra was connected and queried + processor.connect_cassandra.assert_called_once() + processor.session.execute.assert_called_once() + + # Verify query structure - should query unified rows table + call_args = processor.session.execute.call_args + query = call_args[0][0] + params = call_args[0][1] + + assert "SELECT data, source FROM test_user.rows" in query + assert "collection = %s" in query + assert "schema_name = %s" in query + assert "index_name = %s" in query + assert "index_value = %s" in query + + assert params[0] == "test_collection" + assert params[1] == "products" + assert params[2] == "category" + assert params[3] == ["electronics"] + + # Verify results + assert len(results) == 1 + assert results[0]["id"] == "123" + assert results[0]["category"] == "electronics" + + @pytest.mark.asyncio + async def test_query_without_index_match(self): + """Test query execution without matching index (scan mode)""" + processor = MagicMock() + processor.session = MagicMock() + processor.connect_cassandra = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, 
Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) + + # Mock session execute to return test data + mock_row1 = MagicMock() + mock_row1.data = {"id": "1", "name": "Product A", "price": "100"} + mock_row2 = MagicMock() + mock_row2.data = {"id": "2", "name": "Product B", "price": "200"} + processor.session.execute.return_value = [mock_row1, mock_row2] + + schema = RowSchema( + name="products", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string"), # Not indexed + Field(name="price", type="string") # Not indexed + ] + ) + + # Query with filter on non-indexed field + results = await processor.query_cassandra( + user="test_user", + collection="test_collection", + schema_name="products", + row_schema=schema, + filters={"name": "Product A"}, + limit=10 + ) + + # Query should use ALLOW FILTERING for scan + call_args = processor.session.execute.call_args + query = call_args[0][0] + + assert "ALLOW FILTERING" in query + + # Should post-filter results + assert len(results) == 1 + assert results[0]["name"] == "Product A" + + +class TestFilterMatching: + """Test filter matching logic""" + + def test_matches_filters_exact_match(self): + """Test exact match filter""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="status", type="string")]) + + row = {"status": "active", "name": "test"} + assert processor._matches_filters(row, {"status": "active"}, schema) is True + assert processor._matches_filters(row, {"status": "inactive"}, schema) is False + + def test_matches_filters_comparison_operators(self): + """Test comparison operators in filters""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="price", type="float")]) + + row = {"price": "100.0"} + + # Greater than + assert processor._matches_filters(row, {"price_gt": 50}, schema) is True + assert processor._matches_filters(row, {"price_gt": 150}, schema) is False + + # Less than + assert processor._matches_filters(row, {"price_lt": 150}, schema) is True + assert processor._matches_filters(row, {"price_lt": 50}, schema) is False + + # Greater than or equal + assert processor._matches_filters(row, {"price_gte": 100}, schema) is True + assert processor._matches_filters(row, {"price_gte": 101}, schema) is False + + # Less than or equal + assert processor._matches_filters(row, {"price_lte": 100}, schema) is True + assert processor._matches_filters(row, {"price_lte": 99}, schema) is False + + def test_matches_filters_contains(self): + """Test contains filter""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="description", type="string")]) + + row = {"description": "A great product for everyone"} + + assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True + assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False + + def test_matches_filters_in_list(self): + """Test in-list filter""" + 
processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="status", type="string")]) + + row = {"status": "active"} + + assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True + assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False + + +class TestIndexedFieldFiltering: + """Test that only indexed or primary key fields can be directly filtered""" + def test_indexed_field_filtering(self): """Test that only indexed or primary key fields can be filtered""" - # Create schema with mixed field types schema = RowSchema( name="test", fields=[ Field(name="id", type="string", primary=True), - Field(name="indexed_field", type="string", indexed=True), + Field(name="indexed_field", type="string", indexed=True), Field(name="normal_field", type="string", indexed=False), Field(name="another_field", type="string") ] ) - + filters = { "id": "test123", # Primary key - should be included "indexed_field": "value", # Indexed - should be included "normal_field": "ignored", # Not indexed - should be ignored "another_field": "also_ignored" # Not indexed - should be ignored } - + # Simulate the filtering logic from the processor valid_filters = [] for field_name, value in filters.items(): @@ -492,7 +531,7 @@ class TestCQLQueryGeneration: schema_field = next((f for f in schema.fields if f.name == field_name), None) if schema_field and (schema_field.indexed or schema_field.primary): valid_filters.append((field_name, value)) - + # Only id and indexed_field should be included assert len(valid_filters) == 2 field_names = [f[0] for f in valid_filters] @@ -500,52 +539,3 @@ class TestCQLQueryGeneration: assert "indexed_field" in field_names assert "normal_field" not in field_names assert "another_field" not in field_names - - -class TestGraphQLSchemaGeneration: - """Test GraphQL schema generation in detail""" - - def test_field_type_annotations(self): - """Test that GraphQL types have correct field annotations""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create schema with various field types - schema = RowSchema( - name="test", - fields=[ - Field(name="id", type="string", required=True, primary=True), - Field(name="count", type="integer", required=True), - Field(name="price", type="float", required=False), - Field(name="active", type="boolean", required=False), - Field(name="optional_text", type="string", required=False) - ] - ) - - # Create GraphQL type - graphql_type = processor.create_graphql_type("test", schema) - - # Verify type was created successfully - assert graphql_type is not None - - def test_basic_type_creation(self): - """Test that GraphQL types are created correctly""" - processor = MagicMock() - processor.schemas = { - "customer": RowSchema( - name="customer", - fields=[Field(name="id", type="string", primary=True)] - ) - } - processor.graphql_types = {} - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create GraphQL type directly - graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"]) - processor.graphql_types["customer"] = graphql_type - - # Verify customer type was created - assert 
"customer" in processor.graphql_types - assert processor.graphql_types["customer"] is not None \ No newline at end of file diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index f5be4961..480f2ee1 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -5,8 +5,8 @@ Tests for Cassandra triples query service import pytest from unittest.mock import MagicMock, patch -from trustgraph.query.triples.cassandra.service import Processor -from trustgraph.schema import Value +from trustgraph.query.triples.cassandra.service import Processor, create_term +from trustgraph.schema import Term, IRI, LITERAL class TestCassandraQueryProcessor: @@ -21,94 +21,101 @@ class TestCassandraQueryProcessor: graph_host='localhost' ) - def test_create_value_with_http_uri(self, processor): - """Test create_value with HTTP URI""" - result = processor.create_value("http://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert result.is_uri is True + def test_create_term_with_http_uri(self, processor): + """Test create_term with HTTP URI""" + result = create_term("http://example.com/resource") - def test_create_value_with_https_uri(self, processor): - """Test create_value with HTTPS URI""" - result = processor.create_value("https://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "https://example.com/resource" - assert result.is_uri is True + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI - def test_create_value_with_literal(self, processor): - """Test create_value with literal value""" - result = processor.create_value("just a literal string") - - assert isinstance(result, Value) + def test_create_term_with_https_uri(self, processor): + """Test create_term with HTTPS URI""" + result = create_term("https://example.com/resource") + + assert isinstance(result, Term) + assert result.iri == "https://example.com/resource" + assert result.type == IRI + + def test_create_term_with_literal(self, processor): + """Test create_term with literal value""" + result = create_term("just a literal string") + + assert isinstance(result, Term) assert result.value == "just a literal string" - assert result.is_uri is False + assert result.type == LITERAL - def test_create_value_with_empty_string(self, processor): - """Test create_value with empty string""" - result = processor.create_value("") - - assert isinstance(result, Value) + def test_create_term_with_empty_string(self, processor): + """Test create_term with empty string""" + result = create_term("") + + assert isinstance(result, Term) assert result.value == "" - assert result.is_uri is False + assert result.type == LITERAL - def test_create_value_with_partial_uri(self, processor): - """Test create_value with string that looks like URI but isn't complete""" - result = processor.create_value("http") - - assert isinstance(result, Value) + def test_create_term_with_partial_uri(self, processor): + """Test create_term with string that looks like URI but isn't complete""" + result = create_term("http") + + assert isinstance(result, Term) assert result.value == "http" - assert result.is_uri is False + assert result.type == LITERAL - def test_create_value_with_ftp_uri(self, processor): - """Test create_value with FTP URI (should not be detected as URI)""" - result = 
processor.create_value("ftp://example.com/file") - - assert isinstance(result, Value) + def test_create_term_with_ftp_uri(self, processor): + """Test create_term with FTP URI (should not be detected as URI)""" + result = create_term("ftp://example.com/file") + + assert isinstance(result, Term) assert result.value == "ftp://example.com/file" - assert result.is_uri is False + assert result.type == LITERAL @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_spo_query(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_spo_query(self, mock_kg_class): """Test querying triples with subject, predicate, and object specified""" - from trustgraph.schema import TriplesQueryRequest, Value - - # Setup mock TrustGraph + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + + # Setup mock TrustGraph via factory function mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - mock_tg_instance.get_spo.return_value = None # SPO query returns None if found - + mock_kg_class.return_value = mock_tg_instance + # SPO query returns a list of results (with mock graph attribute) + mock_result = MagicMock() + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None + mock_result.o = 'test_object' + mock_tg_instance.get_spo.return_value = [mock_result] + processor = Processor( taskgroup=MagicMock(), id='test-cassandra-query', cassandra_host='localhost' ) - + # Create query request with all SPO values query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), - o=Value(value='test_object', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), + o=Term(type=LITERAL, value='test_object'), limit=100 ) - + result = await processor.query_triples(query) - + # Verify KnowledgeGraph was created with correct parameters - mock_trustgraph.assert_called_once_with( + mock_kg_class.assert_called_once_with( hosts=['localhost'], keyspace='test_user' ) - + # Verify get_spo was called with correct parameters mock_tg_instance.get_spo.assert_called_once_with( - 'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100 + 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 ) - + # Verify result contains the queried triple assert len(result) == 1 assert result[0].s.value == 'test_subject' @@ -143,154 +150,174 @@ class TestCassandraQueryProcessor: assert processor.table is None @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_sp_pattern(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_sp_pattern(self, mock_kg_class): """Test SP query pattern (subject and predicate, no object)""" - from trustgraph.schema import TriplesQueryRequest, Value - - # Setup mock TrustGraph and response + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + + # Setup mock TrustGraph via factory function mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + mock_result = MagicMock() mock_result.o = 'result_object' + mock_result.g = '' + mock_result.otype = 
None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_sp.return_value = [mock_result] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), o=None, limit=50 ) - + result = await processor.query_triples(query) - - mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50) + + mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 assert result[0].s.value == 'test_subject' assert result[0].p.value == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_s_pattern(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_s_pattern(self, mock_kg_class): """Test S query pattern (subject only)""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + mock_result = MagicMock() mock_result.p = 'result_predicate' mock_result.o = 'result_object' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_s.return_value = [mock_result] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), p=None, o=None, limit=25 ) - + result = await processor.query_triples(query) - - mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25) + + mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 assert result[0].s.value == 'test_subject' assert result[0].p.value == 'result_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_p_pattern(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_p_pattern(self, mock_kg_class): """Test P query pattern (predicate only)""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + mock_result = MagicMock() mock_result.s = 'result_subject' mock_result.o = 'result_object' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_p.return_value = [mock_result] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', s=None, - p=Value(value='test_predicate', is_uri=False), + p=Term(type=LITERAL, value='test_predicate'), o=None, limit=10 ) - + result = await 
processor.query_triples(query) - - mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10) + + mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 assert result[0].s.value == 'result_subject' assert result[0].p.value == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_o_pattern(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_o_pattern(self, mock_kg_class): """Test O query pattern (object only)""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + mock_result = MagicMock() mock_result.s = 'result_subject' mock_result.p = 'result_predicate' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_o.return_value = [mock_result] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', s=None, p=None, - o=Value(value='test_object', is_uri=False), + o=Term(type=LITERAL, value='test_object'), limit=75 ) - + result = await processor.query_triples(query) - - mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75) + + mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 assert result[0].s.value == 'result_subject' assert result[0].p.value == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_get_all_pattern(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_get_all_pattern(self, mock_kg_class): """Test query pattern with no constraints (get all)""" from trustgraph.schema import TriplesQueryRequest - + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + mock_result = MagicMock() mock_result.s = 'all_subject' mock_result.p = 'all_predicate' mock_result.o = 'all_object' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_all.return_value = [mock_result] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', @@ -299,9 +326,9 @@ class TestCassandraQueryProcessor: o=None, limit=1000 ) - + result = await processor.query_triples(query) - + mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.value == 'all_subject' @@ -372,37 +399,44 @@ class TestCassandraQueryProcessor: run() - mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n') + mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o, g) quad pattern, some values may be\nnull. 
Output is a list of quads.\n') @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_with_authentication(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_with_authentication(self, mock_kg_class): """Test querying with username and password authentication""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - mock_tg_instance.get_spo.return_value = None - + mock_kg_class.return_value = mock_tg_instance + # SPO query returns a list of results + mock_result = MagicMock() + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None + mock_result.o = 'test_object' + mock_tg_instance.get_spo.return_value = [mock_result] + processor = Processor( taskgroup=MagicMock(), cassandra_username='authuser', cassandra_password='authpass' ) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), - o=Value(value='test_object', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), + o=Term(type=LITERAL, value='test_object'), limit=100 ) - + await processor.query_triples(query) - + # Verify KnowledgeGraph was created with authentication - mock_trustgraph.assert_called_once_with( + mock_kg_class.assert_called_once_with( hosts=['cassandra'], # Updated default keyspace='test_user', username='authuser', @@ -410,128 +444,154 @@ class TestCassandraQueryProcessor: ) @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_table_reuse(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_table_reuse(self, mock_kg_class): """Test that TrustGraph is reused for same table""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - mock_tg_instance.get_spo.return_value = None - + mock_kg_class.return_value = mock_tg_instance + # SPO query returns a list of results + mock_result = MagicMock() + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None + mock_result.o = 'test_object' + mock_tg_instance.get_spo.return_value = [mock_result] + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), - o=Value(value='test_object', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), + o=Term(type=LITERAL, value='test_object'), limit=100 ) - + # First query should create TrustGraph await processor.query_triples(query) - assert mock_trustgraph.call_count == 1 - + assert mock_kg_class.call_count == 1 + # Second query with same table should reuse TrustGraph await processor.query_triples(query) - assert mock_trustgraph.call_count == 1 # Should not increase + assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio - 
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_table_switching(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_table_switching(self, mock_kg_class): """Test table switching creates new TrustGraph""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance1 = MagicMock() mock_tg_instance2 = MagicMock() - mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2] - + mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2] + + # Setup mock results for both instances + mock_result = MagicMock() + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None + mock_result.p = 'p' + mock_result.o = 'o' + mock_tg_instance1.get_s.return_value = [mock_result] + mock_tg_instance2.get_s.return_value = [mock_result] + processor = Processor(taskgroup=MagicMock()) - + # First query query1 = TriplesQueryRequest( user='user1', collection='collection1', - s=Value(value='test_subject', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), p=None, o=None, limit=100 ) - + await processor.query_triples(query1) assert processor.table == 'user1' - + # Second query with different table query2 = TriplesQueryRequest( user='user2', collection='collection2', - s=Value(value='test_subject', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), p=None, o=None, limit=100 ) - + await processor.query_triples(query2) assert processor.table == 'user2' - + # Verify TrustGraph was created twice - assert mock_trustgraph.call_count == 2 + assert mock_kg_class.call_count == 2 @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_exception_handling(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_exception_handling(self, mock_kg_class): """Test exception handling during query execution""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance mock_tg_instance.get_spo.side_effect = Exception("Query failed") - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), - o=Value(value='test_object', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), + o=Term(type=LITERAL, value='test_object'), limit=100 ) - + with pytest.raises(Exception, match="Query failed"): await processor.query_triples(query) @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_query_triples_multiple_results(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_query_triples_multiple_results(self, mock_kg_class): """Test query returning multiple results""" - from trustgraph.schema import TriplesQueryRequest, Value - + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL + mock_tg_instance = MagicMock() - mock_trustgraph.return_value = 
mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + # Mock multiple results mock_result1 = MagicMock() mock_result1.o = 'object1' + mock_result1.g = '' + mock_result1.otype = None + mock_result1.dtype = None + mock_result1.lang = None mock_result2 = MagicMock() mock_result2.o = 'object2' + mock_result2.g = '' + mock_result2.otype = None + mock_result2.dtype = None + mock_result2.lang = None mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] - + processor = Processor(taskgroup=MagicMock()) - + query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), - p=Value(value='test_predicate', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), + p=Term(type=LITERAL, value='test_predicate'), o=None, limit=100 ) - + result = await processor.query_triples(query) - + assert len(result) == 2 assert result[0].o.value == 'object1' assert result[1].o.value == 'object2' @@ -541,16 +601,20 @@ class TestCassandraQueryPerformanceOptimizations: """Test cases for multi-table performance optimizations in query service""" @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_get_po_query_optimization(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_get_po_query_optimization(self, mock_kg_class): """Test that get_po queries use optimized table (no ALLOW FILTERING)""" - from trustgraph.schema import TriplesQueryRequest, Value + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance mock_result = MagicMock() mock_result.s = 'result_subject' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_po.return_value = [mock_result] processor = Processor(taskgroup=MagicMock()) @@ -560,8 +624,8 @@ class TestCassandraQueryPerformanceOptimizations: user='test_user', collection='test_collection', s=None, - p=Value(value='test_predicate', is_uri=False), - o=Value(value='test_object', is_uri=False), + p=Term(type=LITERAL, value='test_predicate'), + o=Term(type=LITERAL, value='test_object'), limit=50 ) @@ -569,7 +633,7 @@ class TestCassandraQueryPerformanceOptimizations: # Verify get_po was called (should use optimized po_table) mock_tg_instance.get_po.assert_called_once_with( - 'test_collection', 'test_predicate', 'test_object', limit=50 + 'test_collection', 'test_predicate', 'test_object', g=None, limit=50 ) assert len(result) == 1 @@ -578,16 +642,20 @@ class TestCassandraQueryPerformanceOptimizations: assert result[0].o.value == 'test_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_get_os_query_optimization(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_get_os_query_optimization(self, mock_kg_class): """Test that get_os queries use optimized table (no ALLOW FILTERING)""" - from trustgraph.schema import TriplesQueryRequest, Value + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance mock_result = MagicMock() mock_result.p = 'result_predicate' + mock_result.g = '' + mock_result.otype = None + 
mock_result.dtype = None + mock_result.lang = None mock_tg_instance.get_os.return_value = [mock_result] processor = Processor(taskgroup=MagicMock()) @@ -596,9 +664,9 @@ class TestCassandraQueryPerformanceOptimizations: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value='test_subject', is_uri=False), + s=Term(type=LITERAL, value='test_subject'), p=None, - o=Value(value='test_object', is_uri=False), + o=Term(type=LITERAL, value='test_object'), limit=25 ) @@ -606,7 +674,7 @@ class TestCassandraQueryPerformanceOptimizations: # Verify get_os was called (should use optimized subject_table with clustering) mock_tg_instance.get_os.assert_called_once_with( - 'test_collection', 'test_object', 'test_subject', limit=25 + 'test_collection', 'test_object', 'test_subject', g=None, limit=25 ) assert len(result) == 1 @@ -615,13 +683,13 @@ class TestCassandraQueryPerformanceOptimizations: assert result[0].o.value == 'test_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_all_query_patterns_use_correct_tables(self, mock_kg_class): """Test that all query patterns route to their optimal tables""" - from trustgraph.schema import TriplesQueryRequest, Value + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance # Mock empty results for all queries mock_tg_instance.get_all.return_value = [] @@ -655,9 +723,9 @@ class TestCassandraQueryPerformanceOptimizations: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value=s, is_uri=False) if s else None, - p=Value(value=p, is_uri=False) if p else None, - o=Value(value=o, is_uri=False) if o else None, + s=Term(type=LITERAL, value=s) if s else None, + p=Term(type=LITERAL, value=p) if p else None, + o=Term(type=LITERAL, value=o) if o else None, limit=10 ) @@ -687,19 +755,23 @@ class TestCassandraQueryPerformanceOptimizations: # Mode is determined in KnowledgeGraph initialization @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') - async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_performance_critical_po_query_no_filtering(self, mock_kg_class): """Test the performance-critical PO query that eliminates ALLOW FILTERING""" - from trustgraph.schema import TriplesQueryRequest, Value + from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance # Mock multiple subjects for the same predicate-object pair mock_results = [] for i in range(5): mock_result = MagicMock() mock_result.s = f'subject_{i}' + mock_result.g = '' + mock_result.otype = None + mock_result.dtype = None + mock_result.lang = None mock_results.append(mock_result) mock_tg_instance.get_po.return_value = mock_results @@ -711,8 +783,8 @@ class TestCassandraQueryPerformanceOptimizations: user='large_dataset_user', collection='massive_collection', s=None, - p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True), - 
o=Value(value='http://example.com/Person', is_uri=True), + p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'), + o=Term(type=IRI, iri='http://example.com/Person'), limit=1000 ) @@ -723,14 +795,15 @@ class TestCassandraQueryPerformanceOptimizations: 'massive_collection', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://example.com/Person', + g=None, limit=1000 ) # Verify all results were returned assert len(result) == 5 for i, triple in enumerate(result): - assert triple.s.value == f'subject_{i}' - assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' - assert triple.p.is_uri is True - assert triple.o.value == 'http://example.com/Person' - assert triple.o.is_uri is True \ No newline at end of file + assert triple.s.value == f'subject_{i}' # Mock returns literal values + assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' + assert triple.p.type == IRI + assert triple.o.iri == 'http://example.com/Person' # URIs use .iri + assert triple.o.type == IRI \ No newline at end of file diff --git a/tests/unit/test_query/test_triples_falkordb_query.py b/tests/unit/test_query/test_triples_falkordb_query.py index 3e7d07db..d5c047d7 100644 --- a/tests/unit/test_query/test_triples_falkordb_query.py +++ b/tests/unit/test_query/test_triples_falkordb_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.triples.falkordb.service import Processor -from trustgraph.schema import Value, TriplesQueryRequest +from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL class TestFalkorDBQueryProcessor: @@ -25,50 +25,50 @@ class TestFalkorDBQueryProcessor: def test_create_value_with_http_uri(self, processor): """Test create_value with HTTP URI""" result = processor.create_value("http://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI def test_create_value_with_https_uri(self, processor): """Test create_value with HTTPS URI""" result = processor.create_value("https://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "https://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "https://example.com/resource" + assert result.type == IRI def test_create_value_with_literal(self, processor): """Test create_value with literal value""" result = processor.create_value("just a literal string") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "just a literal string" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_empty_string(self, processor): """Test create_value with empty string""" result = processor.create_value("") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_partial_uri(self, processor): """Test create_value with string that looks like URI but isn't complete""" result = processor.create_value("http") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "http" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_ftp_uri(self, processor): """Test create_value with FTP URI (should 
not be detected as URI)""" result = processor.create_value("ftp://example.com/file") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "ftp://example.com/file" - assert result.is_uri is False + assert result.type == LITERAL @patch('trustgraph.query.triples.falkordb.service.FalkorDB') def test_processor_initialization_with_defaults(self, mock_falkordb): @@ -125,9 +125,9 @@ class TestFalkorDBQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal object", is_uri=False), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -138,8 +138,8 @@ class TestFalkorDBQueryProcessor: # Verify result contains the queried triple (appears twice - once from each query) assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal object" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @@ -166,8 +166,8 @@ class TestFalkorDBQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None, limit=100 ) @@ -179,13 +179,13 @@ class TestFalkorDBQueryProcessor: # Verify results contain different objects assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal result" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/predicate" - assert result[1].o.value == "http://example.com/uri_result" + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/predicate" + assert result[1].o.iri == "http://example.com/uri_result" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @pytest.mark.asyncio @@ -211,9 +211,9 @@ class TestFalkorDBQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, - o=Value(value="literal object", is_uri=False), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -224,12 +224,12 @@ class TestFalkorDBQueryProcessor: # Verify results contain different predicates assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/pred2" + assert result[1].s.iri == "http://example.com/subject" + assert 
result[1].p.iri == "http://example.com/pred2" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @@ -256,7 +256,7 @@ class TestFalkorDBQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, o=None, limit=100 @@ -269,13 +269,13 @@ class TestFalkorDBQueryProcessor: # Verify results contain different predicate-object pairs assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/pred2" - assert result[1].o.value == "http://example.com/uri2" + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/pred2" + assert result[1].o.iri == "http://example.com/uri2" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @pytest.mark.asyncio @@ -302,8 +302,8 @@ class TestFalkorDBQueryProcessor: user='test_user', collection='test_collection', s=None, - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal object", is_uri=False), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -314,12 +314,12 @@ class TestFalkorDBQueryProcessor: # Verify results contain different subjects assert len(result) == 2 - assert result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/predicate" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/predicate" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @@ -347,7 +347,7 @@ class TestFalkorDBQueryProcessor: user='test_user', collection='test_collection', s=None, - p=Value(value="http://example.com/predicate", is_uri=True), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None, limit=100 ) @@ -359,13 +359,13 @@ class TestFalkorDBQueryProcessor: # Verify results contain different subject-object pairs assert len(result) == 2 - assert result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/predicate" - assert result[1].o.value == "http://example.com/uri2" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/predicate" + assert result[1].o.iri == "http://example.com/uri2" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @pytest.mark.asyncio @@ -393,7 +393,7 @@ class TestFalkorDBQueryProcessor: collection='test_collection', s=None, p=None, - o=Value(value="literal object", is_uri=False), + 
o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -404,12 +404,12 @@ class TestFalkorDBQueryProcessor: # Verify results contain different subject-predicate pairs assert len(result) == 2 - assert result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/pred2" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/pred2" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @@ -449,13 +449,13 @@ class TestFalkorDBQueryProcessor: # Verify results contain different triples assert len(result) == 2 - assert result[0].s.value == "http://example.com/s1" - assert result[0].p.value == "http://example.com/p1" + assert result[0].s.iri == "http://example.com/s1" + assert result[0].p.iri == "http://example.com/p1" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/s2" - assert result[1].p.value == "http://example.com/p2" - assert result[1].o.value == "http://example.com/o2" + assert result[1].s.iri == "http://example.com/s2" + assert result[1].p.iri == "http://example.com/p2" + assert result[1].o.iri == "http://example.com/o2" @patch('trustgraph.query.triples.falkordb.service.FalkorDB') @pytest.mark.asyncio @@ -476,7 +476,7 @@ class TestFalkorDBQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, o=None, limit=100 diff --git a/tests/unit/test_query/test_triples_memgraph_query.py b/tests/unit/test_query/test_triples_memgraph_query.py index bd394ae4..f4222af1 100644 --- a/tests/unit/test_query/test_triples_memgraph_query.py +++ b/tests/unit/test_query/test_triples_memgraph_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.triples.memgraph.service import Processor -from trustgraph.schema import Value, TriplesQueryRequest +from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL class TestMemgraphQueryProcessor: @@ -25,50 +25,50 @@ class TestMemgraphQueryProcessor: def test_create_value_with_http_uri(self, processor): """Test create_value with HTTP URI""" result = processor.create_value("http://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI def test_create_value_with_https_uri(self, processor): """Test create_value with HTTPS URI""" result = processor.create_value("https://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "https://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "https://example.com/resource" + assert result.type == IRI def test_create_value_with_literal(self, processor): """Test create_value with literal value""" result = processor.create_value("just a literal string") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "just a literal string" - assert 
result.is_uri is False + assert result.type == LITERAL def test_create_value_with_empty_string(self, processor): """Test create_value with empty string""" result = processor.create_value("") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_partial_uri(self, processor): """Test create_value with string that looks like URI but isn't complete""" result = processor.create_value("http") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "http" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_ftp_uri(self, processor): """Test create_value with FTP URI (should not be detected as URI)""" result = processor.create_value("ftp://example.com/file") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "ftp://example.com/file" - assert result.is_uri is False + assert result.type == LITERAL @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') def test_processor_initialization_with_defaults(self, mock_graph_db): @@ -124,9 +124,9 @@ class TestMemgraphQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal object", is_uri=False), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -137,8 +137,8 @@ class TestMemgraphQueryProcessor: # Verify result contains the queried triple (appears twice - once from each query) assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal object" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @@ -166,8 +166,8 @@ class TestMemgraphQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None, limit=100 ) @@ -179,13 +179,13 @@ class TestMemgraphQueryProcessor: # Verify results contain different objects assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal result" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/predicate" - assert result[1].o.value == "http://example.com/uri_result" + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/predicate" + assert result[1].o.iri == "http://example.com/uri_result" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio @@ -212,9 +212,9 @@ class TestMemgraphQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - 
s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, - o=Value(value="literal object", is_uri=False), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -225,12 +225,12 @@ class TestMemgraphQueryProcessor: # Verify results contain different predicates assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/pred2" + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/pred2" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @@ -258,7 +258,7 @@ class TestMemgraphQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, o=None, limit=100 @@ -271,13 +271,13 @@ class TestMemgraphQueryProcessor: # Verify results contain different predicate-object pairs assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/pred2" - assert result[1].o.value == "http://example.com/uri2" + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/pred2" + assert result[1].o.iri == "http://example.com/uri2" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio @@ -305,8 +305,8 @@ class TestMemgraphQueryProcessor: user='test_user', collection='test_collection', s=None, - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal object", is_uri=False), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -317,12 +317,12 @@ class TestMemgraphQueryProcessor: # Verify results contain different subjects assert len(result) == 2 - assert result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/predicate" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/predicate" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @@ -351,7 +351,7 @@ class TestMemgraphQueryProcessor: user='test_user', collection='test_collection', s=None, - p=Value(value="http://example.com/predicate", is_uri=True), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None, limit=100 ) @@ -363,13 +363,13 @@ class TestMemgraphQueryProcessor: # Verify results contain different subject-object pairs assert len(result) == 2 - assert 
result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/predicate" - assert result[1].o.value == "http://example.com/uri2" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/predicate" + assert result[1].o.iri == "http://example.com/uri2" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio @@ -398,7 +398,7 @@ class TestMemgraphQueryProcessor: collection='test_collection', s=None, p=None, - o=Value(value="literal object", is_uri=False), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -409,12 +409,12 @@ class TestMemgraphQueryProcessor: # Verify results contain different subject-predicate pairs assert len(result) == 2 - assert result[0].s.value == "http://example.com/subj1" - assert result[0].p.value == "http://example.com/pred1" + assert result[0].s.iri == "http://example.com/subj1" + assert result[0].p.iri == "http://example.com/pred1" assert result[0].o.value == "literal object" - assert result[1].s.value == "http://example.com/subj2" - assert result[1].p.value == "http://example.com/pred2" + assert result[1].s.iri == "http://example.com/subj2" + assert result[1].p.iri == "http://example.com/pred2" assert result[1].o.value == "literal object" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @@ -455,13 +455,13 @@ class TestMemgraphQueryProcessor: # Verify results contain different triples assert len(result) == 2 - assert result[0].s.value == "http://example.com/s1" - assert result[0].p.value == "http://example.com/p1" + assert result[0].s.iri == "http://example.com/s1" + assert result[0].p.iri == "http://example.com/p1" assert result[0].o.value == "literal1" - assert result[1].s.value == "http://example.com/s2" - assert result[1].p.value == "http://example.com/p2" - assert result[1].o.value == "http://example.com/o2" + assert result[1].s.iri == "http://example.com/s2" + assert result[1].p.iri == "http://example.com/p2" + assert result[1].o.iri == "http://example.com/o2" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio @@ -480,7 +480,7 @@ class TestMemgraphQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, o=None, limit=100 diff --git a/tests/unit/test_query/test_triples_neo4j_query.py b/tests/unit/test_query/test_triples_neo4j_query.py index 320aed54..e379ed21 100644 --- a/tests/unit/test_query/test_triples_neo4j_query.py +++ b/tests/unit/test_query/test_triples_neo4j_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.triples.neo4j.service import Processor -from trustgraph.schema import Value, TriplesQueryRequest +from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL class TestNeo4jQueryProcessor: @@ -25,50 +25,50 @@ class TestNeo4jQueryProcessor: def test_create_value_with_http_uri(self, processor): """Test create_value with HTTP URI""" result = processor.create_value("http://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "http://example.com/resource" - assert 
result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "http://example.com/resource" + assert result.type == IRI def test_create_value_with_https_uri(self, processor): """Test create_value with HTTPS URI""" result = processor.create_value("https://example.com/resource") - - assert isinstance(result, Value) - assert result.value == "https://example.com/resource" - assert result.is_uri is True + + assert isinstance(result, Term) + assert result.iri == "https://example.com/resource" + assert result.type == IRI def test_create_value_with_literal(self, processor): """Test create_value with literal value""" result = processor.create_value("just a literal string") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "just a literal string" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_empty_string(self, processor): """Test create_value with empty string""" result = processor.create_value("") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_partial_uri(self, processor): """Test create_value with string that looks like URI but isn't complete""" result = processor.create_value("http") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "http" - assert result.is_uri is False + assert result.type == LITERAL def test_create_value_with_ftp_uri(self, processor): """Test create_value with FTP URI (should not be detected as URI)""" result = processor.create_value("ftp://example.com/file") - - assert isinstance(result, Value) + + assert isinstance(result, Term) assert result.value == "ftp://example.com/file" - assert result.is_uri is False + assert result.type == LITERAL @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') def test_processor_initialization_with_defaults(self, mock_graph_db): @@ -124,9 +124,9 @@ class TestNeo4jQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal object", is_uri=False), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal object"), limit=100 ) @@ -137,8 +137,8 @@ class TestNeo4jQueryProcessor: # Verify result contains the queried triple (appears twice - once from each query) assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal object" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @@ -166,8 +166,8 @@ class TestNeo4jQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None, limit=100 ) @@ -179,13 +179,13 @@ class TestNeo4jQueryProcessor: # Verify results contain different objects assert len(result) == 2 - assert result[0].s.value == "http://example.com/subject" - assert 
result[0].p.value == "http://example.com/predicate" + assert result[0].s.iri == "http://example.com/subject" + assert result[0].p.iri == "http://example.com/predicate" assert result[0].o.value == "literal result" - - assert result[1].s.value == "http://example.com/subject" - assert result[1].p.value == "http://example.com/predicate" - assert result[1].o.value == "http://example.com/uri_result" + + assert result[1].s.iri == "http://example.com/subject" + assert result[1].p.iri == "http://example.com/predicate" + assert result[1].o.iri == "http://example.com/uri_result" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio @@ -225,13 +225,13 @@ class TestNeo4jQueryProcessor: # Verify results contain different triples assert len(result) == 2 - assert result[0].s.value == "http://example.com/s1" - assert result[0].p.value == "http://example.com/p1" + assert result[0].s.iri == "http://example.com/s1" + assert result[0].p.iri == "http://example.com/p1" assert result[0].o.value == "literal1" - - assert result[1].s.value == "http://example.com/s2" - assert result[1].p.value == "http://example.com/p2" - assert result[1].o.value == "http://example.com/o2" + + assert result[1].s.iri == "http://example.com/s2" + assert result[1].p.iri == "http://example.com/p2" + assert result[1].o.iri == "http://example.com/o2" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio @@ -250,12 +250,12 @@ class TestNeo4jQueryProcessor: query = TriplesQueryRequest( user='test_user', collection='test_collection', - s=Value(value="http://example.com/subject", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), p=None, o=None, limit=100 ) - + # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): await processor.query_triples(query) diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index 27c09ca4..76bf5b08 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.schema import ( StructuredQueryRequest, StructuredQueryResponse, QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, - ObjectsQueryRequest, ObjectsQueryResponse, + RowsQueryRequest, RowsQueryResponse, Error, GraphQLError ) from trustgraph.retrieval.structured_query.service import Processor @@ -68,7 +68,7 @@ class TestStructuredQueryProcessor: ) # Mock objects query service response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}', errors=None, @@ -86,7 +86,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -108,7 +108,7 @@ class TestStructuredQueryProcessor: # Verify objects query service was called correctly mock_objects_client.request.assert_called_once() objects_call_args = mock_objects_client.request.call_args[0][0] - assert isinstance(objects_call_args, ObjectsQueryRequest) + assert isinstance(objects_call_args, RowsQueryRequest) assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }' 
assert objects_call_args.variables == {"state": "NY"} assert objects_call_args.user == "trustgraph" @@ -224,7 +224,7 @@ class TestStructuredQueryProcessor: assert response.error is not None assert "empty GraphQL query" in response.error.message - async def test_objects_query_service_error(self, processor): + async def test_rows_query_service_error(self, processor): """Test handling of objects query service errors""" # Arrange request = StructuredQueryRequest( @@ -250,7 +250,7 @@ class TestStructuredQueryProcessor: ) # Mock objects query service error - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=Error(type="graphql-execution-error", message="Table 'customers' not found"), data=None, errors=None, @@ -267,7 +267,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -284,7 +284,7 @@ class TestStructuredQueryProcessor: response = response_call[0][0] assert response.error is not None - assert "Objects query service error" in response.error.message + assert "Rows query service error" in response.error.message assert "Table 'customers' not found" in response.error.message async def test_graphql_errors_handling(self, processor): @@ -321,7 +321,7 @@ class TestStructuredQueryProcessor: ) ] - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, errors=graphql_errors, @@ -338,7 +338,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -400,7 +400,7 @@ class TestStructuredQueryProcessor: ) # Mock objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}', errors=None @@ -416,7 +416,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -464,7 +464,7 @@ class TestStructuredQueryProcessor: confidence=0.9 ) - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, # Null data errors=None, @@ -481,7 +481,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response diff --git a/tests/unit/test_storage/test_cassandra_config_integration.py b/tests/unit/test_storage/test_cassandra_config_integration.py index 754a4bb0..0956f4e7 100644 --- a/tests/unit/test_storage/test_cassandra_config_integration.py +++ b/tests/unit/test_storage/test_cassandra_config_integration.py @@ -10,7 +10,7 @@ import pytest from unittest.mock import Mock, patch, MagicMock from trustgraph.storage.triples.cassandra.write import Processor as 
TriplesWriter -from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter +from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery from trustgraph.storage.knowledge.store import Processor as KgStore @@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration: assert processor.cassandra_password is None -class TestObjectsWriterConfiguration: +class TestRowsWriterConfiguration: """Test Cassandra configuration in objects writer processor.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_environment_variable_configuration(self, mock_cluster): """Test processor picks up configuration from environment variables.""" env_vars = { @@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2'] assert processor.cassandra_username == 'obj-env-user' assert processor.cassandra_password == 'obj-env-pass' - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_cassandra_connection_with_hosts_list(self, mock_cluster): """Test that Cassandra connection uses hosts list correctly.""" env_vars = { @@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Verify cluster was called with hosts list @@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration: assert 'contact_points' in call_args.kwargs assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3'] - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_authentication_configuration(self, mock_auth_provider, mock_cluster): """Test authentication is configured when credentials are provided.""" env_vars = { @@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Verify auth provider was created with correct credentials @@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling: def test_objects_writer_add_args(self): """Test that objects writer adds standard Cassandra arguments.""" import argparse - from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter + from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter parser = argparse.ArgumentParser() - ObjectsWriter.add_args(parser) + RowsWriter.add_args(parser) # Parse empty args to check that arguments exist args = parser.parse_args([]) diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index a22173ab..8a8e1090 100644 --- 
a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.storage.graph_embeddings.milvus.write import Processor -from trustgraph.schema import Value, EntityEmbeddings +from trustgraph.schema import Term, EntityEmbeddings, IRI, LITERAL class TestMilvusGraphEmbeddingsStorageProcessor: @@ -22,11 +22,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: # Create test entities with embeddings entity1 = EntityEmbeddings( - entity=Value(value='http://example.com/entity1', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/entity1'), vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] ) entity2 = EntityEmbeddings( - entity=Value(value='literal entity', is_uri=False), + entity=Term(type=LITERAL, value='literal entity'), vectors=[[0.7, 0.8, 0.9]] ) message.entities = [entity1, entity2] @@ -84,7 +84,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' entity = EntityEmbeddings( - entity=Value(value='http://example.com/entity', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/entity'), vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] ) message.entities = [entity] @@ -136,7 +136,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' entity = EntityEmbeddings( - entity=Value(value='', is_uri=False), + entity=Term(type=LITERAL, value=''), vectors=[[0.1, 0.2, 0.3]] ) message.entities = [entity] @@ -155,7 +155,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' entity = EntityEmbeddings( - entity=Value(value=None, is_uri=False), + entity=Term(type=LITERAL, value=None), vectors=[[0.1, 0.2, 0.3]] ) message.entities = [entity] @@ -174,15 +174,15 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' valid_entity = EntityEmbeddings( - entity=Value(value='http://example.com/valid', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/valid'), vectors=[[0.1, 0.2, 0.3]] ) empty_entity = EntityEmbeddings( - entity=Value(value='', is_uri=False), + entity=Term(type=LITERAL, value=''), vectors=[[0.4, 0.5, 0.6]] ) none_entity = EntityEmbeddings( - entity=Value(value=None, is_uri=False), + entity=Term(type=LITERAL, value=None), vectors=[[0.7, 0.8, 0.9]] ) message.entities = [valid_entity, empty_entity, none_entity] @@ -217,7 +217,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' entity = EntityEmbeddings( - entity=Value(value='http://example.com/entity', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/entity'), vectors=[] ) message.entities = [entity] @@ -236,7 +236,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' entity = EntityEmbeddings( - entity=Value(value='http://example.com/entity', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/entity'), vectors=[ [0.1, 0.2], # 2D vector [0.3, 0.4, 0.5, 0.6], # 4D vector @@ -269,11 +269,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' uri_entity = EntityEmbeddings( - entity=Value(value='http://example.com/uri_entity', is_uri=True), + entity=Term(type=IRI, iri='http://example.com/uri_entity'), vectors=[[0.1, 0.2, 0.3]] ) literal_entity = EntityEmbeddings( - entity=Value(value='literal entity text', is_uri=False), + 
entity=Term(type=LITERAL, value='literal entity text'), vectors=[[0.4, 0.5, 0.6]] ) message.entities = [uri_entity, literal_entity] diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index d240b892..8b1a710a 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.storage.graph_embeddings.qdrant.write import Processor +from trustgraph.schema import IRI, LITERAL class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @@ -67,7 +68,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'test_collection' mock_entity = MagicMock() - mock_entity.entity.value = 'test_entity' + mock_entity.entity.type = IRI + mock_entity.entity.iri = 'test_entity' mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions mock_message.entities = [mock_entity] @@ -120,11 +122,13 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'multi_collection' mock_entity1 = MagicMock() - mock_entity1.entity.value = 'entity_one' + mock_entity1.entity.type = IRI + mock_entity1.entity.iri = 'entity_one' mock_entity1.vectors = [[0.1, 0.2]] - + mock_entity2 = MagicMock() - mock_entity2.entity.value = 'entity_two' + mock_entity2.entity.type = IRI + mock_entity2.entity.iri = 'entity_two' mock_entity2.vectors = [[0.3, 0.4]] mock_message.entities = [mock_entity1, mock_entity2] @@ -179,7 +183,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'vector_collection' mock_entity = MagicMock() - mock_entity.entity.value = 'multi_vector_entity' + mock_entity.entity.type = IRI + mock_entity.entity.iri = 'multi_vector_entity' mock_entity.vectors = [ [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], @@ -231,11 +236,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'empty_collection' mock_entity_empty = MagicMock() + mock_entity_empty.entity.type = LITERAL mock_entity_empty.entity.value = "" # Empty string mock_entity_empty.vectors = [[0.1, 0.2]] - + mock_entity_none = MagicMock() - mock_entity_none.entity.value = None # None value + mock_entity_none.entity = None # None entity mock_entity_none.vectors = [[0.3, 0.4]] mock_message.entities = [mock_entity_empty, mock_entity_none] diff --git a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py b/tests/unit/test_storage/test_neo4j_user_collection_isolation.py index bc8bb03f..dce170a7 100644 --- a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py +++ b/tests/unit/test_storage/test_neo4j_user_collection_isolation.py @@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch, call from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor -from trustgraph.schema import Triples, Triple, Value, Metadata +from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL from trustgraph.schema import TriplesQueryRequest @@ -60,9 +60,9 @@ class TestNeo4jUserCollectionIsolation: ) triple = Triple( - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="literal_value", is_uri=False) + 
s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="literal_value") ) message = Triples( @@ -128,9 +128,9 @@ class TestNeo4jUserCollectionIsolation: metadata = Metadata(id="test-id") triple = Triple( - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="http://example.com/object", is_uri=True) + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=IRI, iri="http://example.com/object") ) message = Triples( @@ -170,8 +170,8 @@ class TestNeo4jUserCollectionIsolation: query = TriplesQueryRequest( user="test_user", collection="test_collection", - s=Value(value="http://example.com/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), + s=Term(type=IRI, iri="http://example.com/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), o=None ) @@ -254,9 +254,9 @@ class TestNeo4jUserCollectionIsolation: metadata=Metadata(user="user1", collection="coll1"), triples=[ Triple( - s=Value(value="http://example.com/user1/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="user1_data", is_uri=False) + s=Term(type=IRI, iri="http://example.com/user1/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="user1_data") ) ] ) @@ -265,9 +265,9 @@ class TestNeo4jUserCollectionIsolation: metadata=Metadata(user="user2", collection="coll2"), triples=[ Triple( - s=Value(value="http://example.com/user2/subject", is_uri=True), - p=Value(value="http://example.com/predicate", is_uri=True), - o=Value(value="user2_data", is_uri=False) + s=Term(type=IRI, iri="http://example.com/user2/subject"), + p=Term(type=IRI, iri="http://example.com/predicate"), + o=Term(type=LITERAL, value="user2_data") ) ] ) @@ -429,9 +429,9 @@ class TestNeo4jUserCollectionRegression: metadata=Metadata(user="user1", collection="coll1"), triples=[ Triple( - s=Value(value=shared_uri, is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="user1_value", is_uri=False) + s=Term(type=IRI, iri=shared_uri), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="user1_value") ) ] ) @@ -440,9 +440,9 @@ class TestNeo4jUserCollectionRegression: metadata=Metadata(user="user2", collection="coll2"), triples=[ Triple( - s=Value(value=shared_uri, is_uri=True), - p=Value(value="http://example.com/p", is_uri=True), - o=Value(value="user2_value", is_uri=False) + s=Term(type=IRI, iri=shared_uri), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="user2_value") ) ] ) diff --git a/tests/unit/test_storage/test_objects_cassandra_storage.py b/tests/unit/test_storage/test_objects_cassandra_storage.py deleted file mode 100644 index c7f5ff40..00000000 --- a/tests/unit/test_storage/test_objects_cassandra_storage.py +++ /dev/null @@ -1,533 +0,0 @@ -""" -Unit tests for Cassandra Object Storage Processor - -Tests the business logic of the object storage processor including: -- Schema configuration handling -- Type conversions -- Name sanitization -- Table structure generation -""" - -import pytest -from unittest.mock import MagicMock, AsyncMock, patch -import json - -from trustgraph.storage.objects.cassandra.write import Processor -from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field - - -class 
TestObjectsCassandraStorageLogic: - """Test business logic without FlowProcessor dependencies""" - - def test_sanitize_name(self): - """Test name sanitization for Cassandra compatibility""" - processor = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - - # Test various name patterns (back to original logic) - assert processor.sanitize_name("simple_name") == "simple_name" - assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes" - assert processor.sanitize_name("name.with.dots") == "name_with_dots" - assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number" - assert processor.sanitize_name("name with spaces") == "name_with_spaces" - assert processor.sanitize_name("special!@#$%^chars") == "special______chars" - - def test_get_cassandra_type(self): - """Test field type conversion to Cassandra types""" - processor = MagicMock() - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - - # Basic type mappings - assert processor.get_cassandra_type("string") == "text" - assert processor.get_cassandra_type("boolean") == "boolean" - assert processor.get_cassandra_type("timestamp") == "timestamp" - assert processor.get_cassandra_type("uuid") == "uuid" - - # Integer types with size hints - assert processor.get_cassandra_type("integer", size=2) == "int" - assert processor.get_cassandra_type("integer", size=8) == "bigint" - - # Float types with size hints - assert processor.get_cassandra_type("float", size=2) == "float" - assert processor.get_cassandra_type("float", size=8) == "double" - - # Unknown type defaults to text - assert processor.get_cassandra_type("unknown_type") == "text" - - def test_convert_value(self): - """Test value conversion for different field types""" - processor = MagicMock() - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - - # Integer conversions - assert processor.convert_value("123", "integer") == 123 - assert processor.convert_value(123.5, "integer") == 123 - assert processor.convert_value(None, "integer") is None - - # Float conversions - assert processor.convert_value("123.45", "float") == 123.45 - assert processor.convert_value(123, "float") == 123.0 - - # Boolean conversions - assert processor.convert_value("true", "boolean") is True - assert processor.convert_value("false", "boolean") is False - assert processor.convert_value("1", "boolean") is True - assert processor.convert_value("0", "boolean") is False - assert processor.convert_value("yes", "boolean") is True - assert processor.convert_value("no", "boolean") is False - - # String conversions - assert processor.convert_value(123, "string") == "123" - assert processor.convert_value(True, "string") == "True" - - def test_table_creation_cql_generation(self): - """Test CQL generation for table creation""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - processor.known_keyspaces.add(keyspace) - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # 
Create test schema - schema = RowSchema( - name="customer_records", - description="Test customer schema", - fields=[ - Field( - name="customer_id", - type="string", - size=50, - primary=True, - required=True, - indexed=False - ), - Field( - name="email", - type="string", - size=100, - required=True, - indexed=True - ), - Field( - name="age", - type="integer", - size=4, - required=False, - indexed=False - ) - ] - ) - - # Call ensure_table - processor.ensure_table("test_user", "customer_records", schema) - - # Verify keyspace was ensured (check that it was added to known_keyspaces) - assert "test_user" in processor.known_keyspaces - - # Check the CQL that was executed (first call should be table creation) - all_calls = processor.session.execute.call_args_list - table_creation_cql = all_calls[0][0][0] # First call - - # Verify table structure (keyspace uses sanitize_name, table uses sanitize_table) - assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql - assert "collection text" in table_creation_cql - assert "customer_id text" in table_creation_cql - assert "email text" in table_creation_cql - assert "age int" in table_creation_cql - assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql - - def test_table_creation_without_primary_key(self): - """Test table creation when no primary key is defined""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - processor.known_keyspaces.add(keyspace) - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # Create schema without primary key - schema = RowSchema( - name="events", - description="Event log", - fields=[ - Field(name="event_type", type="string", size=50), - Field(name="timestamp", type="timestamp", size=0) - ] - ) - - # Call ensure_table - processor.ensure_table("test_user", "events", schema) - - # Check the CQL includes synthetic_id (field names don't get o_ prefix) - executed_cql = processor.session.execute.call_args[0][0] - assert "synthetic_id uuid" in executed_cql - assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql - - @pytest.mark.asyncio - async def test_schema_config_parsing(self): - """Test parsing of schema configurations""" - processor = MagicMock() - processor.schemas = {} - processor.config_key = "schema" - processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - - # Create test configuration - config = { - "schema": { - "customer_records": json.dumps({ - "name": "customer_records", - "description": "Customer data", - "fields": [ - { - "name": "id", - "type": "string", - "primary_key": True, - "required": True - }, - { - "name": "name", - "type": "string", - "required": True - }, - { - "name": "balance", - "type": "float", - "size": 8 - } - ] - }) - } - } - - # Process configuration - await processor.on_schema_config(config, version=1) - - # Verify schema was loaded - assert "customer_records" in processor.schemas - schema = processor.schemas["customer_records"] - assert schema.name == "customer_records" - assert 
len(schema.fields) == 3 - - # Check field properties - id_field = schema.fields[0] - assert id_field.name == "id" - assert id_field.type == "string" - assert id_field.primary is True - # Note: Field.required always returns False due to Pulsar schema limitations - # The actual required value is tracked during schema parsing - - @pytest.mark.asyncio - async def test_object_processing_logic(self): - """Test the logic for processing ExtractedObject""" - processor = MagicMock() - processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="value", type="integer", size=4) - ] - ) - } - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create test object - test_obj = ExtractedObject( - metadata=Metadata( - id="test-001", - user="test_user", - collection="test_collection", - metadata=[] - ), - schema_name="test_schema", - values=[{"id": "123", "value": "456"}], - confidence=0.9, - source_span="test source" - ) - - # Create mock message - msg = MagicMock() - msg.value.return_value = test_obj - - # Process object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"]) - - # Verify insert was executed (keyspace normal, table with o_ prefix) - processor.session.execute.assert_called_once() - insert_cql = processor.session.execute.call_args[0][0] - values = processor.session.execute.call_args[0][1] - - assert "INSERT INTO test_user.o_test_schema" in insert_cql - assert "collection" in insert_cql - assert values[0] == "test_collection" # collection value - assert values[1] == "123" # id value (from values[0]) - assert values[2] == 456 # converted integer value (from values[0]) - - def test_secondary_index_creation(self): - """Test that secondary indexes are created for indexed fields""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - processor.known_keyspaces.add(keyspace) - if keyspace not in processor.known_tables: - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # Create schema with indexed field - schema = RowSchema( - name="products", - description="Product catalog", - fields=[ - Field(name="product_id", type="string", size=50, primary=True), - Field(name="category", type="string", size=30, indexed=True), - Field(name="price", type="float", size=8, indexed=True) - ] - ) 
- - # Call ensure_table - processor.ensure_table("test_user", "products", schema) - - # Should have 3 calls: create table + 2 indexes - assert processor.session.execute.call_count == 3 - - # Check index creation calls (table has o_ prefix, fields don't) - calls = processor.session.execute.call_args_list - index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]] - assert len(index_calls) == 2 - assert any("o_products_category_idx" in call for call in index_calls) - assert any("o_products_price_idx" in call for call in index_calls) - - -class TestObjectsCassandraStorageBatchLogic: - """Test batch processing logic in Cassandra storage""" - - @pytest.mark.asyncio - async def test_batch_object_processing_logic(self): - """Test processing of batch ExtractedObjects""" - processor = MagicMock() - processor.schemas = { - "batch_schema": RowSchema( - name="batch_schema", - description="Test batch schema", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="name", type="string", size=100), - Field(name="value", type="integer", size=4) - ] - ) - } - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - - # Create batch object with multiple values - batch_obj = ExtractedObject( - metadata=Metadata( - id="batch-001", - user="test_user", - collection="batch_collection", - metadata=[] - ), - schema_name="batch_schema", - values=[ - {"id": "001", "name": "First", "value": "100"}, - {"id": "002", "name": "Second", "value": "200"}, - {"id": "003", "name": "Third", "value": "300"} - ], - confidence=0.95, - source_span="batch source" - ) - - # Create mock message - msg = MagicMock() - msg.value.return_value = batch_obj - - # Process batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured once - processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"]) - - # Verify 3 separate insert calls (one per batch item) - assert processor.session.execute.call_count == 3 - - # Check each insert call - calls = processor.session.execute.call_args_list - for i, call in enumerate(calls): - insert_cql = call[0][0] - values = call[0][1] - - assert "INSERT INTO test_user.o_batch_schema" in insert_cql - assert "collection" in insert_cql - - # Check values for each batch item - assert values[0] == "batch_collection" # collection - assert values[1] == f"00{i+1}" # id from batch item i - assert values[2] == f"First" if i == 0 else f"Second" if i == 1 else f"Third" # name - assert values[3] == (i+1) * 100 # converted integer value - - @pytest.mark.asyncio - async def test_empty_batch_processing_logic(self): - """Test processing of empty batch ExtractedObjects""" - processor = MagicMock() - processor.schemas = { - "empty_schema": RowSchema( - name="empty_schema", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - } - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = 
Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create empty batch object - empty_batch_obj = ExtractedObject( - metadata=Metadata( - id="empty-001", - user="test_user", - collection="empty_collection", - metadata=[] - ), - schema_name="empty_schema", - values=[], # Empty batch - confidence=1.0, - source_span="empty source" - ) - - msg = MagicMock() - msg.value.return_value = empty_batch_obj - - # Process empty batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once() - - # Verify no insert calls for empty batch - processor.session.execute.assert_not_called() - - @pytest.mark.asyncio - async def test_single_item_batch_processing_logic(self): - """Test processing of single-item batch (backward compatibility)""" - processor = MagicMock() - processor.schemas = { - "single_schema": RowSchema( - name="single_schema", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="data", type="string", size=100) - ] - ) - } - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create single-item batch object (backward compatibility case) - single_batch_obj = ExtractedObject( - metadata=Metadata( - id="single-001", - user="test_user", - collection="single_collection", - metadata=[] - ), - schema_name="single_schema", - values=[{"id": "single-1", "data": "single data"}], # Array with one item - confidence=0.8, - source_span="single source" - ) - - msg = MagicMock() - msg.value.return_value = single_batch_obj - - # Process single-item batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once() - - # Verify exactly one insert call - processor.session.execute.assert_called_once() - - insert_cql = processor.session.execute.call_args[0][0] - values = processor.session.execute.call_args[0][1] - - assert "INSERT INTO test_user.o_single_schema" in insert_cql - assert values[0] == "single_collection" # collection - assert values[1] == "single-1" # id value - assert values[2] == "single data" # data value - - def test_batch_value_conversion_logic(self): - """Test value conversion works correctly for batch items""" - processor = MagicMock() - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - - # Test various conversion scenarios that would occur in batch processing - test_cases = [ - # Integer conversions for batch items - ("123", "integer", 123), - ("456", "integer", 456), - ("789", "integer", 789), - # Float conversions for batch items - ("12.5", "float", 12.5), - ("34.7", "float", 34.7), - # Boolean conversions for batch items - ("true", "boolean", True), - ("false", "boolean", False), - ("1", "boolean", True), - ("0", "boolean", False), - # String conversions 
for batch items - (123, "string", "123"), - (45.6, "string", "45.6"), - ] - - for input_val, field_type, expected_output in test_cases: - result = processor.convert_value(input_val, field_type) - assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}" \ No newline at end of file diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py new file mode 100644 index 00000000..b4c5a5b4 --- /dev/null +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -0,0 +1,435 @@ +""" +Unit tests for trustgraph.storage.row_embeddings.qdrant.write +Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + + +class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): + """Test Qdrant row embeddings storage functionality""" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_processor_initialization_basic(self, mock_qdrant_client): + """Test basic Qdrant processor initialization""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'store_uri': 'http://localhost:6333', + 'api_key': 'test-api-key', + 'taskgroup': AsyncMock(), + 'id': 'test-qdrant-processor' + } + + processor = Processor(**config) + + mock_qdrant_client.assert_called_once_with( + url='http://localhost:6333', api_key='test-api-key' + ) + assert hasattr(processor, 'qdrant') + assert processor.qdrant == mock_qdrant_instance + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_processor_initialization_with_defaults(self, mock_qdrant_client): + """Test processor initialization with default values""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-qdrant-processor' + } + + processor = Processor(**config) + + mock_qdrant_client.assert_called_once_with( + url='http://localhost:6333', api_key=None + ) + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_sanitize_name(self, mock_qdrant_client): + """Test name sanitization for Qdrant collections""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_client.return_value = MagicMock() + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Test basic sanitization + assert processor.sanitize_name("simple") == "simple" + assert processor.sanitize_name("with-dash") == "with_dash" + assert processor.sanitize_name("with.dot") == "with_dot" + assert processor.sanitize_name("UPPERCASE") == "uppercase" + + # Test numeric prefix handling + assert processor.sanitize_name("123start") == "r_123start" + assert processor.sanitize_name("_underscore") == "r__underscore" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_get_collection_name(self, mock_qdrant_client): + """Test Qdrant collection name generation""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_client.return_value = MagicMock() + + config = { + 'taskgroup': AsyncMock(), 
+ 'id': 'test-processor' + } + + processor = Processor(**config) + + collection_name = processor.get_collection_name( + user="test_user", + collection="test_collection", + schema_name="customer_data", + dimension=384 + ) + + assert collection_name == "rows_test_user_test_collection_customer_data_384" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_creates_new(self, mock_qdrant_client): + """Test that ensure_collection creates a new collection when needed""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = False + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + processor.ensure_collection("test_collection", 384) + + mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") + mock_qdrant_instance.create_collection.assert_called_once() + + # Verify the collection is cached + assert "test_collection" in processor.created_collections + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_skips_existing(self, mock_qdrant_client): + """Test that ensure_collection skips creation when collection exists""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + processor.ensure_collection("existing_collection", 384) + + mock_qdrant_instance.collection_exists.assert_called_once() + mock_qdrant_instance.create_collection.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_uses_cache(self, mock_qdrant_client): + """Test that ensure_collection uses cache for previously created collections""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.created_collections.add("cached_collection") + + processor.ensure_collection("cached_collection", 384) + + # Should not check or create - just return + mock_qdrant_instance.collection_exists.assert_not_called() + mock_qdrant_instance.create_collection.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') + async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client): + """Test processing basic row embeddings message""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + mock_uuid.uuid4.return_value = 'test-uuid-123' + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + # Create embeddings message + metadata = MagicMock() + 
metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + embedding = RowIndexEmbedding( + index_name='customer_id', + index_value=['CUST001'], + text='CUST001', + vectors=[[0.1, 0.2, 0.3]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='customers', + embeddings=[embedding] + ) + + # Mock message wrapper + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Verify upsert was called + mock_qdrant_instance.upsert.assert_called_once() + + # Verify upsert parameters + upsert_call_args = mock_qdrant_instance.upsert.call_args + assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3' + + point = upsert_call_args[1]['points'][0] + assert point.vector == [0.1, 0.2, 0.3] + assert point.payload['index_name'] == 'customer_id' + assert point.payload['index_value'] == ['CUST001'] + assert point.payload['text'] == 'CUST001' + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') + async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client): + """Test processing embeddings with multiple vectors""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + mock_uuid.uuid4.return_value = 'test-uuid' + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + # Embedding with multiple vectors + embedding = RowIndexEmbedding( + index_name='name', + index_value=['John Doe'], + text='John Doe', + vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='people', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should be called 3 times (once per vector) + assert mock_qdrant_instance.upsert.call_count == 3 + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client): + """Test that embeddings with no vectors are skipped""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + # Embedding with no vectors + embedding = RowIndexEmbedding( + index_name='id', + index_value=['123'], + text='123', + vectors=[] # Empty vectors + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='items', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = 
embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should not call upsert for empty vectors + mock_qdrant_instance.upsert.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client): + """Test that messages for unknown collections are dropped""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + # No collections registered + + metadata = MagicMock() + metadata.user = 'unknown_user' + metadata.collection = 'unknown_collection' + metadata.id = 'doc-123' + + embedding = RowIndexEmbedding( + index_name='id', + index_value=['123'], + text='123', + vectors=[[0.1, 0.2]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='items', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should not call upsert for unknown collection + mock_qdrant_instance.upsert.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_delete_collection(self, mock_qdrant_client): + """Test deleting all collections for a user/collection""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + + # Mock collections list + mock_coll1 = MagicMock() + mock_coll1.name = 'rows_test_user_test_collection_schema1_384' + mock_coll2 = MagicMock() + mock_coll2.name = 'rows_test_user_test_collection_schema2_384' + mock_coll3 = MagicMock() + mock_coll3.name = 'rows_other_user_other_collection_schema_384' + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3] + mock_qdrant_instance.get_collections.return_value = mock_collections + + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.created_collections.add('rows_test_user_test_collection_schema1_384') + + await processor.delete_collection('test_user', 'test_collection') + + # Should delete only the matching collections + assert mock_qdrant_instance.delete_collection.call_count == 2 + + # Verify the cached collection was removed + assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_delete_collection_schema(self, mock_qdrant_client): + """Test deleting collections for a specific schema""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + + mock_coll1 = MagicMock() + mock_coll1.name = 'rows_test_user_test_collection_customers_384' + mock_coll2 = MagicMock() + mock_coll2.name = 'rows_test_user_test_collection_orders_384' + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll1, mock_coll2] + mock_qdrant_instance.get_collections.return_value = mock_collections + + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = 
Processor(**config) + + await processor.delete_collection_schema( + 'test_user', 'test_collection', 'customers' + ) + + # Should only delete the customers schema collection + mock_qdrant_instance.delete_collection.assert_called_once() + call_args = mock_qdrant_instance.delete_collection.call_args[0] + assert call_args[0] == 'rows_test_user_test_collection_customers_384' + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py new file mode 100644 index 00000000..c8b81447 --- /dev/null +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -0,0 +1,474 @@ +""" +Unit tests for Cassandra Row Storage Processor (Unified Table Implementation) + +Tests the business logic of the row storage processor including: +- Schema configuration handling +- Name sanitization +- Unified table structure +- Index management +- Row storage with multi-index support +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +import json + +from trustgraph.storage.rows.cassandra.write import Processor +from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class TestRowsCassandraStorageLogic: + """Test business logic for unified table implementation""" + + def test_sanitize_name(self): + """Test name sanitization for Cassandra compatibility""" + processor = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + + # Test various name patterns + assert processor.sanitize_name("simple_name") == "simple_name" + assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes" + assert processor.sanitize_name("name.with.dots") == "name_with_dots" + assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number" + assert processor.sanitize_name("name with spaces") == "name_with_spaces" + assert processor.sanitize_name("special!@#$%^chars") == "special______chars" + assert processor.sanitize_name("UPPERCASE") == "uppercase" + assert processor.sanitize_name("CamelCase") == "camelcase" + assert processor.sanitize_name("_underscore_start") == "r__underscore_start" + + def test_get_index_names(self): + """Test extraction of index names from schema""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + # Schema with primary and indexed fields + schema = RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string"), # Not indexed + Field(name="status", type="string", indexed=True) + ] + ) + + index_names = processor.get_index_names(schema) + + # Should include primary key and indexed fields + assert "id" in index_names + assert "category" in index_names + assert "status" in index_names + assert "name" not in index_names # Not indexed + assert len(index_names) == 3 + + def test_get_index_names_no_indexes(self): + """Test schema with no indexed fields""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + schema = RowSchema( + name="no_index_schema", + fields=[ + Field(name="data1", type="string"), + Field(name="data2", type="string") + ] + ) + + index_names = processor.get_index_names(schema) + assert len(index_names) == 0 + + def test_build_index_value(self): + """Test building index values from row data""" + processor 
= MagicMock() + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + + value_map = {"id": "123", "category": "electronics", "name": "Widget"} + + # Single field index + result = processor.build_index_value(value_map, "id") + assert result == ["123"] + + result = processor.build_index_value(value_map, "category") + assert result == ["electronics"] + + # Missing field returns empty string + result = processor.build_index_value(value_map, "missing") + assert result == [""] + + def test_build_index_value_composite(self): + """Test building composite index values""" + processor = MagicMock() + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + + value_map = {"region": "us-west", "category": "electronics", "id": "123"} + + # Composite index (comma-separated field names) + result = processor.build_index_value(value_map, "region,category") + assert result == ["us-west", "electronics"] + + @pytest.mark.asyncio + async def test_schema_config_parsing(self): + """Test parsing of schema configurations""" + processor = MagicMock() + processor.schemas = {} + processor.config_key = "schema" + processor.registered_partitions = set() + processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) + + # Create test configuration + config = { + "schema": { + "customer_records": json.dumps({ + "name": "customer_records", + "description": "Customer data", + "fields": [ + { + "name": "id", + "type": "string", + "primary_key": True, + "required": True + }, + { + "name": "name", + "type": "string", + "required": True + }, + { + "name": "category", + "type": "string", + "indexed": True + } + ] + }) + } + } + + # Process configuration + await processor.on_schema_config(config, version=1) + + # Verify schema was loaded + assert "customer_records" in processor.schemas + schema = processor.schemas["customer_records"] + assert schema.name == "customer_records" + assert len(schema.fields) == 3 + + # Check field properties + id_field = schema.fields[0] + assert id_field.name == "id" + assert id_field.type == "string" + assert id_field.primary is True + + @pytest.mark.asyncio + async def test_object_processing_stores_data_map(self): + """Test that row processing stores data as map""" + processor = MagicMock() + processor.schemas = { + "test_schema": RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="value", type="string", size=100) + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create test object + test_obj = ExtractedObject( + metadata=Metadata( + id="test-001", + user="test_user", + collection="test_collection", + metadata=[] + ), + schema_name="test_schema", + values=[{"id": "123", "value": "test_data"}], + confidence=0.9, + source_span="test source" + ) + + # Create mock message + msg = MagicMock() + msg.value.return_value = test_obj + + # Process object + await 
processor.on_object(msg, None, None) + + # Verify insert was executed + processor.session.execute.assert_called() + insert_call = processor.session.execute.call_args + insert_cql = insert_call[0][0] + values = insert_call[0][1] + + # Verify using unified rows table + assert "INSERT INTO test_user.rows" in insert_cql + + # Values should be: (collection, schema_name, index_name, index_value, data, source) + assert values[0] == "test_collection" # collection + assert values[1] == "test_schema" # schema_name + assert values[2] == "id" # index_name (primary key field) + assert values[3] == ["123"] # index_value as list + assert values[4] == {"id": "123", "value": "test_data"} # data map + assert values[5] == "" # source + + @pytest.mark.asyncio + async def test_object_processing_multiple_indexes(self): + """Test that row is written once per indexed field""" + processor = MagicMock() + processor.schemas = { + "multi_index_schema": RowSchema( + name="multi_index_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="status", type="string", indexed=True) + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + test_obj = ExtractedObject( + metadata=Metadata( + id="test-001", + user="test_user", + collection="test_collection", + metadata=[] + ), + schema_name="multi_index_schema", + values=[{"id": "123", "category": "electronics", "status": "active"}], + confidence=0.9, + source_span="" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Should have 3 inserts (one per indexed field: id, category, status) + assert processor.session.execute.call_count == 3 + + # Check that different index_names were used + index_names_used = set() + for call in processor.session.execute.call_args_list: + values = call[0][1] + index_names_used.add(values[2]) # index_name is 3rd value + + assert index_names_used == {"id", "category", "status"} + + +class TestRowsCassandraStorageBatchLogic: + """Test batch processing logic for unified table implementation""" + + @pytest.mark.asyncio + async def test_batch_object_processing(self): + """Test processing of batch ExtractedObjects""" + processor = MagicMock() + processor.schemas = { + "batch_schema": RowSchema( + name="batch_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string") + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = 
Processor.on_object.__get__(processor, Processor)
+
+        # Create a batch object with multiple values
+        batch_obj = ExtractedObject(
+            metadata=Metadata(
+                id="batch-001",
+                user="test_user",
+                collection="batch_collection",
+                metadata=[]
+            ),
+            schema_name="batch_schema",
+            values=[
+                {"id": "001", "name": "First"},
+                {"id": "002", "name": "Second"},
+                {"id": "003", "name": "Third"}
+            ],
+            confidence=0.95,
+            source_span=""
+        )
+
+        msg = MagicMock()
+        msg.value.return_value = batch_obj
+
+        await processor.on_object(msg, None, None)
+
+        # Should have 3 inserts (one per row; each row has a single
+        # index, the primary key)
+        assert processor.session.execute.call_count == 3
+
+        # Check that each insert has a different id
+        ids_inserted = set()
+        for call in processor.session.execute.call_args_list:
+            values = call[0][1]
+            ids_inserted.add(tuple(values[3]))  # index_value is the 4th value
+
+        assert ids_inserted == {("001",), ("002",), ("003",)}
+
+    @pytest.mark.asyncio
+    async def test_empty_batch_processing(self):
+        """Test processing of empty batch ExtractedObjects"""
+        processor = MagicMock()
+        processor.schemas = {
+            "empty_schema": RowSchema(
+                name="empty_schema",
+                fields=[Field(name="id", type="string", primary=True)]
+            )
+        }
+        processor.tables_initialized = {"test_user"}
+        processor.registered_partitions = set()
+        processor.session = MagicMock()
+        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
+        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
+        processor.ensure_tables = MagicMock()
+        processor.register_partitions = MagicMock()
+        processor.collection_exists = MagicMock(return_value=True)
+        processor.on_object = Processor.on_object.__get__(processor, Processor)
+
+        # Create an empty batch object
+        empty_batch_obj = ExtractedObject(
+            metadata=Metadata(
+                id="empty-001",
+                user="test_user",
+                collection="empty_collection",
+                metadata=[]
+            ),
+            schema_name="empty_schema",
+            values=[],  # Empty batch
+            confidence=1.0,
+            source_span=""
+        )
+
+        msg = MagicMock()
+        msg.value.return_value = empty_batch_obj
+
+        await processor.on_object(msg, None, None)
+
+        # Verify no insert calls for empty batch
+        processor.session.execute.assert_not_called()
+
+
+class TestUnifiedTableStructure:
+    """Test the unified rows table structure"""
+
+    def test_ensure_tables_creates_unified_structure(self):
+        """Test that ensure_tables creates the unified rows table"""
+        processor = MagicMock()
+        processor.known_keyspaces = {"test_user"}
+        processor.tables_initialized = set()
+        processor.session = MagicMock()
+        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
+        processor.ensure_keyspace = MagicMock()
+        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
+
+        processor.ensure_tables("test_user")
+
+        # Should have 2 calls: create rows table + create row_partitions table
+        assert processor.session.execute.call_count == 2
+
+        # Check rows table creation
+        rows_cql = processor.session.execute.call_args_list[0][0][0]
+        assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
+        assert "collection text" in rows_cql
+        assert "schema_name text" in rows_cql
+        assert "index_name text" in rows_cql
+        assert "index_value frozen<list<text>>" in rows_cql
+        assert "data map<text, text>" in rows_cql
+        assert "source text" in rows_cql
+        assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql
+
+        # Check row_partitions table creation
+        
partitions_cql = processor.session.execute.call_args_list[1][0][0] + assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql + assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql + + # Verify keyspace added to initialized set + assert "test_user" in processor.tables_initialized + + def test_ensure_tables_idempotent(self): + """Test that ensure_tables is idempotent""" + processor = MagicMock() + processor.tables_initialized = {"test_user"} # Already initialized + processor.session = MagicMock() + processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) + + processor.ensure_tables("test_user") + + # Should not execute any CQL since already initialized + processor.session.execute.assert_not_called() + + +class TestPartitionRegistration: + """Test partition registration for tracking what's stored""" + + def test_register_partitions(self): + """Test registering partitions for a collection/schema pair""" + processor = MagicMock() + processor.registered_partitions = set() + processor.session = MagicMock() + processor.schemas = { + "test_schema": RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True) + ] + ) + } + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + + processor.register_partitions("test_user", "test_collection", "test_schema") + + # Should have 2 inserts (one per index: id, category) + assert processor.session.execute.call_count == 2 + + # Verify cache was updated + assert ("test_collection", "test_schema") in processor.registered_partitions + + def test_register_partitions_idempotent(self): + """Test that partition registration is idempotent""" + processor = MagicMock() + processor.registered_partitions = {("test_collection", "test_schema")} # Already registered + processor.session = MagicMock() + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + + processor.register_partitions("test_user", "test_collection", "test_schema") + + # Should not execute any CQL since already registered + processor.session.execute.assert_not_called() diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 54ea1a95..73272942 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -6,7 +6,8 @@ import pytest from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.storage.triples.cassandra.write import Processor -from trustgraph.schema import Value, Triple +from trustgraph.schema import Triple, LITERAL, IRI +from trustgraph.direct.cassandra_kg import DEFAULT_GRAPH class TestCassandraStorageProcessor: @@ -86,29 +87,29 @@ class TestCassandraStorageProcessor: assert processor.cassandra_username == 'new-user' # Only cassandra_* params work @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_table_switching_with_auth(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_table_switching_with_auth(self, mock_kg_class): """Test table switching logic when authentication is provided""" taskgroup_mock = 
MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor( taskgroup=taskgroup_mock, cassandra_username='testuser', cassandra_password='testpass' ) - + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - + await processor.store_triples(mock_message) - + # Verify KnowledgeGraph was called with auth parameters - mock_trustgraph.assert_called_once_with( + mock_kg_class.assert_called_once_with( hosts=['cassandra'], # Updated default keyspace='user1', username='testuser', @@ -117,128 +118,150 @@ class TestCassandraStorageProcessor: assert processor.table == 'user1' @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_table_switching_without_auth(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_table_switching_without_auth(self, mock_kg_class): """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor(taskgroup=taskgroup_mock) - + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'user2' mock_message.metadata.collection = 'collection2' mock_message.triples = [] - + await processor.store_triples(mock_message) - + # Verify KnowledgeGraph was called without auth parameters - mock_trustgraph.assert_called_once_with( + mock_kg_class.assert_called_once_with( hosts=['cassandra'], # Updated default keyspace='user2' ) assert processor.table == 'user2' @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_table_reuse_when_same(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_table_reuse_when_same(self, mock_kg_class): """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor(taskgroup=taskgroup_mock) - + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - + # First call should create TrustGraph await processor.store_triples(mock_message) - assert mock_trustgraph.call_count == 1 - + assert mock_kg_class.call_count == 1 + # Second call with same table should reuse TrustGraph await processor.store_triples(mock_message) - assert mock_trustgraph.call_count == 1 # Should not increase + assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_triple_insertion(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_triple_insertion(self, mock_kg_class): """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor(taskgroup=taskgroup_mock) - - # Create mock triples + + # 
Create mock triples with proper Term structure triple1 = MagicMock() + triple1.s.type = LITERAL triple1.s.value = 'subject1' + triple1.s.datatype = '' + triple1.s.language = '' + triple1.p.type = LITERAL triple1.p.value = 'predicate1' + triple1.o.type = LITERAL triple1.o.value = 'object1' - + triple1.o.datatype = '' + triple1.o.language = '' + triple1.g = None + triple2 = MagicMock() + triple2.s.type = LITERAL triple2.s.value = 'subject2' + triple2.s.datatype = '' + triple2.s.language = '' + triple2.p.type = LITERAL triple2.p.value = 'predicate2' + triple2.o.type = LITERAL triple2.o.value = 'object2' - + triple2.o.datatype = '' + triple2.o.language = '' + triple2.g = None + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [triple1, triple2] - + await processor.store_triples(mock_message) - - # Verify both triples were inserted + + # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1') - mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2') + mock_tg_instance.insert.assert_any_call( + 'collection1', 'subject1', 'predicate1', 'object1', + g=DEFAULT_GRAPH, otype='l', dtype='', lang='' + ) + mock_tg_instance.insert.assert_any_call( + 'collection1', 'subject2', 'predicate2', 'object2', + g=DEFAULT_GRAPH, otype='l', dtype='', lang='' + ) @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_triple_insertion_with_empty_list(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_triple_insertion_with_empty_list(self, mock_kg_class): """Test behavior when message has no triples""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor(taskgroup=taskgroup_mock) - + # Create mock message with empty triples mock_message = MagicMock() mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - + await processor.store_triples(mock_message) - + # Verify no triples were inserted mock_tg_instance.insert.assert_not_called() @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.time.sleep') - async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph): + async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class): """Test exception handling during TrustGraph creation""" taskgroup_mock = MagicMock() - mock_trustgraph.side_effect = Exception("Connection failed") - + mock_kg_class.side_effect = Exception("Connection failed") + processor = Processor(taskgroup=taskgroup_mock) - + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - + with pytest.raises(Exception, match="Connection failed"): await processor.store_triples(mock_message) - + # Verify sleep was called before re-raising mock_sleep.assert_called_once_with(1) @@ -326,92 +349,104 @@ class TestCassandraStorageProcessor: 
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n') @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_store_triples_table_switching_between_different_tables(self, mock_kg_class): """Test table switching when different tables are used in sequence""" taskgroup_mock = MagicMock() mock_tg_instance1 = MagicMock() mock_tg_instance2 = MagicMock() - mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2] - + mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2] + processor = Processor(taskgroup=taskgroup_mock) - + # First message with table1 mock_message1 = MagicMock() mock_message1.metadata.user = 'user1' mock_message1.metadata.collection = 'collection1' mock_message1.triples = [] - + await processor.store_triples(mock_message1) assert processor.table == 'user1' assert processor.tg == mock_tg_instance1 - + # Second message with different table mock_message2 = MagicMock() mock_message2.metadata.user = 'user2' mock_message2.metadata.collection = 'collection2' mock_message2.triples = [] - + await processor.store_triples(mock_message2) assert processor.table == 'user2' assert processor.tg == mock_tg_instance2 - + # Verify TrustGraph was created twice for different tables - assert mock_trustgraph.call_count == 2 + assert mock_kg_class.call_count == 2 @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_store_triples_with_special_characters_in_values(self, mock_kg_class): """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance - + mock_kg_class.return_value = mock_tg_instance + processor = Processor(taskgroup=taskgroup_mock) - - # Create triple with special characters + + # Create triple with special characters and proper Term structure triple = MagicMock() + triple.s.type = LITERAL triple.s.value = 'subject with spaces & symbols' + triple.s.datatype = '' + triple.s.language = '' + triple.p.type = LITERAL triple.p.value = 'predicate:with/colons' + triple.o.type = LITERAL triple.o.value = 'object with "quotes" and unicode: ñáéíóú' - + triple.o.datatype = '' + triple.o.language = '' + triple.g = None + mock_message = MagicMock() mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [triple] - + await processor.store_triples(mock_message) - + # Verify the triple was inserted with special characters preserved mock_tg_instance.insert.assert_called_once_with( 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', - 'object with "quotes" and unicode: ñáéíóú' + 'object with "quotes" and unicode: ñáéíóú', + g=DEFAULT_GRAPH, + otype='l', + dtype='', + lang='' ) @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def 
test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): """Test that table remains unchanged when TrustGraph creation fails""" taskgroup_mock = MagicMock() - + processor = Processor(taskgroup=taskgroup_mock) - + # Set an initial table processor.table = ('old_user', 'old_collection') - + # Mock TrustGraph to raise exception - mock_trustgraph.side_effect = Exception("Connection failed") - + mock_kg_class.side_effect = Exception("Connection failed") + mock_message = MagicMock() mock_message.metadata.user = 'new_user' mock_message.metadata.collection = 'new_collection' mock_message.triples = [] - + with pytest.raises(Exception, match="Connection failed"): await processor.store_triples(mock_message) - + # Table should remain unchanged since self.table = table happens after try/except assert processor.table == ('old_user', 'old_collection') # TrustGraph should be set to None though @@ -422,12 +457,12 @@ class TestCassandraPerformanceOptimizations: """Test cases for multi-table performance optimizations""" @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_legacy_mode_uses_single_table(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_legacy_mode_uses_single_table(self, mock_kg_class): """Test that legacy mode still works with single table""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): processor = Processor(taskgroup=taskgroup_mock) @@ -440,16 +475,15 @@ class TestCassandraPerformanceOptimizations: await processor.store_triples(mock_message) # Verify KnowledgeGraph instance uses legacy mode - kg_instance = mock_trustgraph.return_value - assert kg_instance is not None + assert mock_tg_instance is not None @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_optimized_mode_uses_multi_table(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_optimized_mode_uses_multi_table(self, mock_kg_class): """Test that optimized mode uses multi-table schema""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): processor = Processor(taskgroup=taskgroup_mock) @@ -462,24 +496,31 @@ class TestCassandraPerformanceOptimizations: await processor.store_triples(mock_message) # Verify KnowledgeGraph instance is in optimized mode - kg_instance = mock_trustgraph.return_value - assert kg_instance is not None + assert mock_tg_instance is not None @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') - async def test_batch_write_consistency(self, mock_trustgraph): + @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') + async def test_batch_write_consistency(self, mock_kg_class): """Test that all tables stay consistent during batch writes""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() - mock_trustgraph.return_value = mock_tg_instance + mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) - # Create test triple + # Create test triple with proper Term structure triple = MagicMock() + 
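+        # The mock mirrors the Term shape these tests assume: a `type`
+        # discriminator (IRI or LITERAL), a `value` (or `iri`) payload,
+        # `datatype`/`language` for literals, and an optional graph term
+        # `g` on the triple. Written out as a real literal term, that is
+        # roughly (a sketch using field names taken from these tests):
+        #
+        #   Term(type=LITERAL, value='test_subject', datatype='', language='')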
triple.s.type = LITERAL triple.s.value = 'test_subject' + triple.s.datatype = '' + triple.s.language = '' + triple.p.type = LITERAL triple.p.value = 'test_predicate' + triple.o.type = LITERAL triple.o.value = 'test_object' + triple.o.datatype = '' + triple.o.language = '' + triple.g = None mock_message = MagicMock() mock_message.metadata.user = 'user1' @@ -490,7 +531,8 @@ class TestCassandraPerformanceOptimizations: # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) mock_tg_instance.insert.assert_called_once_with( - 'collection1', 'test_subject', 'test_predicate', 'test_object' + 'collection1', 'test_subject', 'test_predicate', 'test_object', + g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) def test_environment_variable_controls_mode(self): diff --git a/tests/unit/test_storage/test_triples_falkordb_storage.py b/tests/unit/test_storage/test_triples_falkordb_storage.py index 02d9cdd0..05dcb2e5 100644 --- a/tests/unit/test_storage/test_triples_falkordb_storage.py +++ b/tests/unit/test_storage/test_triples_falkordb_storage.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.storage.triples.falkordb.write import Processor -from trustgraph.schema import Value, Triple +from trustgraph.schema import Term, Triple, IRI, LITERAL class TestFalkorDBStorageProcessor: @@ -22,9 +22,9 @@ class TestFalkorDBStorageProcessor: # Create a test triple triple = Triple( - s=Value(value='http://example.com/subject', is_uri=True), - p=Value(value='http://example.com/predicate', is_uri=True), - o=Value(value='literal object', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject'), + p=Term(type=IRI, iri='http://example.com/predicate'), + o=Term(type=LITERAL, value='literal object') ) message.triples = [triple] @@ -183,9 +183,9 @@ class TestFalkorDBStorageProcessor: message.metadata.collection = 'test_collection' triple = Triple( - s=Value(value='http://example.com/subject', is_uri=True), - p=Value(value='http://example.com/predicate', is_uri=True), - o=Value(value='http://example.com/object', is_uri=True) + s=Term(type=IRI, iri='http://example.com/subject'), + p=Term(type=IRI, iri='http://example.com/predicate'), + o=Term(type=IRI, iri='http://example.com/object') ) message.triples = [triple] @@ -269,14 +269,14 @@ class TestFalkorDBStorageProcessor: message.metadata.collection = 'test_collection' triple1 = Triple( - s=Value(value='http://example.com/subject1', is_uri=True), - p=Value(value='http://example.com/predicate1', is_uri=True), - o=Value(value='literal object1', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject1'), + p=Term(type=IRI, iri='http://example.com/predicate1'), + o=Term(type=LITERAL, value='literal object1') ) triple2 = Triple( - s=Value(value='http://example.com/subject2', is_uri=True), - p=Value(value='http://example.com/predicate2', is_uri=True), - o=Value(value='http://example.com/object2', is_uri=True) + s=Term(type=IRI, iri='http://example.com/subject2'), + p=Term(type=IRI, iri='http://example.com/predicate2'), + o=Term(type=IRI, iri='http://example.com/object2') ) message.triples = [triple1, triple2] @@ -337,14 +337,14 @@ class TestFalkorDBStorageProcessor: message.metadata.collection = 'test_collection' triple1 = Triple( - s=Value(value='http://example.com/subject1', is_uri=True), - p=Value(value='http://example.com/predicate1', is_uri=True), - o=Value(value='literal object', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject1'), + p=Term(type=IRI, 
iri='http://example.com/predicate1'), + o=Term(type=LITERAL, value='literal object') ) triple2 = Triple( - s=Value(value='http://example.com/subject2', is_uri=True), - p=Value(value='http://example.com/predicate2', is_uri=True), - o=Value(value='http://example.com/object2', is_uri=True) + s=Term(type=IRI, iri='http://example.com/subject2'), + p=Term(type=IRI, iri='http://example.com/predicate2'), + o=Term(type=IRI, iri='http://example.com/object2') ) message.triples = [triple1, triple2] diff --git a/tests/unit/test_storage/test_triples_memgraph_storage.py b/tests/unit/test_storage/test_triples_memgraph_storage.py index b38f0759..162586d5 100644 --- a/tests/unit/test_storage/test_triples_memgraph_storage.py +++ b/tests/unit/test_storage/test_triples_memgraph_storage.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.storage.triples.memgraph.write import Processor -from trustgraph.schema import Value, Triple +from trustgraph.schema import Term, Triple, IRI, LITERAL class TestMemgraphStorageProcessor: @@ -22,9 +22,9 @@ class TestMemgraphStorageProcessor: # Create a test triple triple = Triple( - s=Value(value='http://example.com/subject', is_uri=True), - p=Value(value='http://example.com/predicate', is_uri=True), - o=Value(value='literal object', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject'), + p=Term(type=IRI, iri='http://example.com/predicate'), + o=Term(type=LITERAL, value='literal object') ) message.triples = [triple] @@ -231,9 +231,9 @@ class TestMemgraphStorageProcessor: mock_tx = MagicMock() triple = Triple( - s=Value(value='http://example.com/subject', is_uri=True), - p=Value(value='http://example.com/predicate', is_uri=True), - o=Value(value='http://example.com/object', is_uri=True) + s=Term(type=IRI, iri='http://example.com/subject'), + p=Term(type=IRI, iri='http://example.com/predicate'), + o=Term(type=IRI, iri='http://example.com/object') ) processor.create_triple(mock_tx, triple, "test_user", "test_collection") @@ -265,9 +265,9 @@ class TestMemgraphStorageProcessor: mock_tx = MagicMock() triple = Triple( - s=Value(value='http://example.com/subject', is_uri=True), - p=Value(value='http://example.com/predicate', is_uri=True), - o=Value(value='literal object', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject'), + p=Term(type=IRI, iri='http://example.com/predicate'), + o=Term(type=LITERAL, value='literal object') ) processor.create_triple(mock_tx, triple, "test_user", "test_collection") @@ -347,14 +347,14 @@ class TestMemgraphStorageProcessor: message.metadata.collection = 'test_collection' triple1 = Triple( - s=Value(value='http://example.com/subject1', is_uri=True), - p=Value(value='http://example.com/predicate1', is_uri=True), - o=Value(value='literal object1', is_uri=False) + s=Term(type=IRI, iri='http://example.com/subject1'), + p=Term(type=IRI, iri='http://example.com/predicate1'), + o=Term(type=LITERAL, value='literal object1') ) triple2 = Triple( - s=Value(value='http://example.com/subject2', is_uri=True), - p=Value(value='http://example.com/predicate2', is_uri=True), - o=Value(value='http://example.com/object2', is_uri=True) + s=Term(type=IRI, iri='http://example.com/subject2'), + p=Term(type=IRI, iri='http://example.com/predicate2'), + o=Term(type=IRI, iri='http://example.com/object2') ) message.triples = [triple1, triple2] diff --git a/tests/unit/test_storage/test_triples_neo4j_storage.py b/tests/unit/test_storage/test_triples_neo4j_storage.py index 2e307102..a5181ed9 100644 --- 
a/tests/unit/test_storage/test_triples_neo4j_storage.py +++ b/tests/unit/test_storage/test_triples_neo4j_storage.py @@ -6,6 +6,7 @@ import pytest from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.storage.triples.neo4j.write import Processor +from trustgraph.schema import IRI, LITERAL class TestNeo4jStorageProcessor: @@ -257,10 +258,12 @@ class TestNeo4jStorageProcessor: # Create mock triple with URI object triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" - triple.o.value = "http://example.com/object" - triple.o.is_uri = True + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = IRI + triple.o.iri = "http://example.com/object" # Create mock message with metadata mock_message = MagicMock() @@ -327,10 +330,12 @@ class TestNeo4jStorageProcessor: # Create mock triple with literal object triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = LITERAL triple.o.value = "literal value" - triple.o.is_uri = False # Create mock message with metadata mock_message = MagicMock() @@ -398,16 +403,20 @@ class TestNeo4jStorageProcessor: # Create mock triples triple1 = MagicMock() - triple1.s.value = "http://example.com/subject1" - triple1.p.value = "http://example.com/predicate1" - triple1.o.value = "http://example.com/object1" - triple1.o.is_uri = True - + triple1.s.type = IRI + triple1.s.iri = "http://example.com/subject1" + triple1.p.type = IRI + triple1.p.iri = "http://example.com/predicate1" + triple1.o.type = IRI + triple1.o.iri = "http://example.com/object1" + triple2 = MagicMock() - triple2.s.value = "http://example.com/subject2" - triple2.p.value = "http://example.com/predicate2" + triple2.s.type = IRI + triple2.s.iri = "http://example.com/subject2" + triple2.p.type = IRI + triple2.p.iri = "http://example.com/predicate2" + triple2.o.type = LITERAL triple2.o.value = "literal value" - triple2.o.is_uri = False # Create mock message with metadata mock_message = MagicMock() @@ -550,10 +559,12 @@ class TestNeo4jStorageProcessor: # Create triple with special characters triple = MagicMock() - triple.s.value = "http://example.com/subject with spaces" - triple.p.value = "http://example.com/predicate:with/symbols" + triple.s.type = IRI + triple.s.iri = "http://example.com/subject with spaces" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate:with/symbols" + triple.o.type = LITERAL triple.o.value = 'literal with "quotes" and unicode: ñáéíóú' - triple.o.is_uri = False mock_message = MagicMock() mock_message.triples = [triple] diff --git a/tests/unit/test_text_completion/test_googleaistudio_processor.py b/tests/unit/test_text_completion/test_googleaistudio_processor.py index c54b3928..aa04d2a3 100644 --- a/tests/unit/test_text_completion/test_googleaistudio_processor.py +++ b/tests/unit/test_text_completion/test_googleaistudio_processor.py @@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert hasattr(processor, 'client') assert hasattr(processor, 'safety_settings') assert len(processor.safety_settings) == 4 # 4 safety categories - mock_genai_class.assert_called_once_with(api_key='test-api-key') + 
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert processor.default_model == 'gemini-1.5-pro' assert processor.temperature == 0.7 assert processor.max_output == 4096 - mock_genai_class.assert_called_once_with(api_key='custom-api-key') + mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert processor.default_model == 'gemini-2.0-flash-001' # default_model assert processor.temperature == 0.0 # default_temperature assert processor.max_output == 8192 # default_max_output - mock_genai_class.assert_called_once_with(api_key='test-api-key') + mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): # Assert # Verify Google AI Studio client was called with correct API key - mock_genai_class.assert_called_once_with(api_key='gai-test-key') + mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False) # Verify processor has the client assert processor.client == mock_genai_client diff --git a/tests/unit/test_text_completion/test_vertexai_processor.py b/tests/unit/test_text_completion/test_vertexai_processor.py index 60d61acd..cbc91cb5 100644 --- a/tests/unit/test_text_completion/test_vertexai_processor.py +++ b/tests/unit/test_text_completion/test_vertexai_processor.py @@ -1,6 +1,6 @@ """ Unit tests for trustgraph.model.text_completion.vertexai -Starting simple with one test to get the basics working +Updated for google-genai SDK """ import pytest @@ -15,19 +15,20 @@ from trustgraph.base import LlmResult class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): """Simple test for processor initialization""" + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test basic processor initialization with mocked dependencies""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + # Mock the parent class initialization to avoid taskgroup requirement mock_async_init.return_value = None mock_llm_init.return_value = None @@ -47,32 
+48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name' - assert hasattr(processor, 'generation_configs') # Now a cache dictionary + assert processor.default_model == 'gemini-2.0-flash-001' + assert hasattr(processor, 'generation_configs') # Cache dictionary assert hasattr(processor, 'safety_settings') - assert hasattr(processor, 'model_clients') # LLM clients are now cached - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json') - mock_vertexai.init.assert_called_once() + assert hasattr(processor, 'client') # genai.Client + mock_service_account.Credentials.from_service_account_file.assert_called_once() + mock_genai.Client.assert_called_once_with( + vertexai=True, + project="test-project-123", + location="us-central1", + credentials=mock_credentials + ) + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test successful content generation""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = "Generated response from Gemini" mock_response.usage_metadata.prompt_token_count = 15 mock_response.usage_metadata.candidates_token_count = 8 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.in_token == 15 assert result.out_token == 8 assert result.model == 'gemini-2.0-flash-001' - # Check that the method was called (actual prompt format may vary) - mock_model.generate_content.assert_called_once() - # Verify the call was made with the expected parameters - call_args = mock_model.generate_content.call_args - # Generation config is now created dynamically per model - assert 'generation_config' in call_args[1] - assert call_args[1]['safety_settings'] == processor.safety_settings + mock_client.models.generate_content.assert_called_once() + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, 
mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test rate limit error handling""" # Arrange from google.api_core.exceptions import ResourceExhausted from trustgraph.exceptions import TooManyRequests - + mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() - mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded") - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded") + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(TooManyRequests): await processor.generate_content("System prompt", "User prompt") + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test handling of blocked content (safety filters)""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = None # Blocked content returns None mock_response.usage_metadata.prompt_token_count = 10 mock_response.usage_metadata.candidates_token_count = 0 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.model == 'gemini-2.0-flash-001' @patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default') + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default): + async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default): """Test processor initialization without 
private key (uses default credentials)""" # Arrange mock_async_init.return_value = None mock_llm_init.return_value = None - + # Mock google.auth.default() to return credentials and project ID mock_credentials = MagicMock() mock_auth_default.return_value = (mock_credentials, "test-project-123") - - # Mock GenerativeModel - mock_model = MagicMock() - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client config = { 'region': 'us-central1', @@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Act processor = Processor(**config) - + # Assert assert processor.default_model == 'gemini-2.0-flash-001' mock_auth_default.assert_called_once() - mock_vertexai.init.assert_called_once_with( - location='us-central1', - project='test-project-123' + mock_genai.Client.assert_called_once_with( + vertexai=True, + project="test-project-123", + location="us-central1", + credentials=mock_credentials ) + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test handling of generic exceptions""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() - mock_model.generate_content.side_effect = Exception("Network error") - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = Exception("Network error") + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(Exception, match="Network error"): await processor.generate_content("System prompt", "User prompt") + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test processor initialization with custom parameters""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + 
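+        # The assertion further down pins the expected construction: the
+        # processor presumably builds its client with the google-genai SDK
+        # roughly like this (a sketch, not the verbatim implementation):
+        #
+        #   from google import genai
+        #   client = genai.Client(vertexai=True, project=project_id,
+        #                         location=region, credentials=credentials)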
mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert processor.default_model == 'gemini-1.5-pro' - - # Verify that generation_config object exists (can't easily check internal values) - assert hasattr(processor, 'generation_configs') # Now a cache dictionary + + # Verify that generation_config cache exists + assert hasattr(processor, 'generation_configs') assert processor.generation_configs == {} # Empty cache initially - + # Verify that safety settings are configured assert len(processor.safety_settings) == 4 - + # Verify service account was called with custom key - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json') - - # Verify that api_params dict has the correct values (this is accessible) + mock_service_account.Credentials.from_service_account_file.assert_called_once() + + # Verify that api_params dict has the correct values assert processor.api_params["temperature"] == 0.7 assert processor.api_params["max_output_tokens"] == 4096 assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test that VertexAI is initialized correctly with credentials""" # Arrange mock_credentials = MagicMock() mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - # Verify VertexAI init was called with correct parameters - mock_vertexai.init.assert_called_once_with( + # Verify genai.Client was called with correct parameters + mock_genai.Client.assert_called_once_with( + vertexai=True, + project='test-project-123', location='europe-west1', - credentials=mock_credentials, - project='test-project-123' + credentials=mock_credentials ) - - # GenerativeModel is now created lazily on first use, not at initialization - mock_generative_model.assert_not_called() + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, 
mock_vertexai, mock_service_account): + async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test content generation with empty prompts""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = "Default response" mock_response.usage_metadata.prompt_token_count = 2 mock_response.usage_metadata.candidates_token_count = 3 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.in_token == 2 assert result.out_token == 3 assert result.model == 'gemini-2.0-flash-001' - - # Verify the model was called with the combined empty prompts - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - # The prompt should be "" + "\n\n" + "" = "\n\n" - assert call_args[0][0] == "\n\n" + + # Verify the client was called + mock_client.models.generate_content.assert_called_once() @patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex') + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex): + async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex): """Test Anthropic processor initialization with private key credentials""" # Arrange mock_async_init.return_value = None mock_llm_init.return_value = None - + mock_credentials = MagicMock() mock_credentials.project_id = "test-project-456" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + # Mock AnthropicVertex mock_anthropic_client = MagicMock() mock_anthropic_vertex.return_value = mock_anthropic_client @@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Act processor = Processor(**config) - + # Assert assert processor.default_model == 'claude-3-sonnet@20240229' - # is_anthropic logic is now determined dynamically per request - + # Verify service account was called with private key - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json') - - # Verify AnthropicVertex was initialized with credentials + mock_service_account.Credentials.from_service_account_file.assert_called_once() + + # Verify AnthropicVertex was initialized with credentials (because model contains 'claude') mock_anthropic_vertex.assert_called_once_with( region='us-west1', project_id='test-project-456', credentials=mock_credentials ) - + # Verify api_params are set correctly assert processor.api_params["temperature"] == 0.5 assert 
processor.api_params["max_output_tokens"] == 2048 assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test temperature parameter override functionality""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() mock_response = MagicMock() mock_response.text = "Response with custom temperature" mock_response.usage_metadata.prompt_token_count = 20 mock_response.usage_metadata.candidates_token_count = 12 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert isinstance(result, LlmResult) assert result.text == "Response with custom temperature" + mock_client.models.generate_content.assert_called_once() - # Verify Gemini API was called with overridden temperature - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - - # Check that generation_config was created (we can't directly access temperature from mock) - generation_config = call_args.kwargs['generation_config'] - assert generation_config is not None # Should use overridden temperature configuration - + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test model parameter override functionality""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - # Mock different models - mock_model_default = MagicMock() - mock_model_override = MagicMock() mock_response = MagicMock() mock_response.text = "Response with custom model" mock_response.usage_metadata.prompt_token_count = 18 mock_response.usage_metadata.candidates_token_count = 14 - mock_model_override.generate_content.return_value = mock_response - # GenerativeModel should return 
different models based on input - def model_factory(model_name): - if model_name == 'gemini-1.5-pro': - return mock_model_override - return mock_model_default - - mock_generative_model.side_effect = model_factory + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): config = { 'region': 'us-central1', 'model': 'gemini-2.0-flash-001', # Default model - 'temperature': 0.2, # Default temperature + 'temperature': 0.2, 'max_output': 8192, 'private_key': 'private.json', 'concurrency': 1, @@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert isinstance(result, LlmResult) assert result.text == "Response with custom model" - # Verify the overridden model was used - mock_model_override.generate_content.assert_called_once() - # Verify GenerativeModel was called with the override model - mock_generative_model.assert_called_with('gemini-1.5-pro') + # Verify the call was made with the override model + call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs['model'] == "gemini-1.5-pro" + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test overriding both model and temperature parameters simultaneously""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() mock_response = MagicMock() mock_response.text = "Response with both overrides" mock_response.usage_metadata.prompt_token_count = 22 mock_response.usage_metadata.candidates_token_count = 16 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert isinstance(result, LlmResult) assert result.text == "Response with both overrides" + mock_client.models.generate_content.assert_called_once() - # Verify both overrides were used - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - - # Verify model override - mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override - - # Verify temperature override (we can't directly access temperature from mock) - generation_config = call_args.kwargs['generation_config'] - assert generation_config is not None # Should use overridden temperature configuration + # Verify the model override was used + 
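+        # (With the google-genai SDK the call is presumably of the form
+        #  client.models.generate_content(model=..., contents=..., config=...),
+        #  so a runtime model override surfaces as the `model` kwarg checked
+        #  below.)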
call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs['model'] == "gemini-1.5-flash-001" if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index 93466cd2..daa2cc5c 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -73,6 +73,8 @@ from .async_metrics import AsyncMetrics # Types from .types import ( Triple, + Uri, + Literal, ConfigKey, ConfigValue, DocumentMetadata, @@ -99,7 +101,7 @@ from .exceptions import ( LoadError, LookupError, NLPQueryError, - ObjectsQueryError, + RowsQueryError, RequestError, StructuredQueryError, UnexpectedError, @@ -133,6 +135,8 @@ __all__ = [ # Types "Triple", + "Uri", + "Literal", "ConfigKey", "ConfigValue", "DocumentMetadata", @@ -157,7 +161,7 @@ __all__ = [ "LoadError", "LookupError", "NLPQueryError", - "ObjectsQueryError", + "RowsQueryError", "RequestError", "StructuredQueryError", "UnexpectedError", diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 76cb9f56..9a6a49c3 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -115,15 +115,15 @@ class AsyncBulkClient: async for raw_message in websocket: yield json.loads(raw_message) - async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: - """Bulk import objects via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import rows via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" if self.token: ws_url = f"{ws_url}?token={self.token}" async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - async for obj in objects: - await websocket.send(json.dumps(obj)) + async for row in rows: + await websocket.send(json.dumps(row)) async def aclose(self) -> None: """Close connections""" diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 6b28886b..440cebae 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -612,8 +612,12 @@ class AsyncFlowInstance: print(f"{entity['name']}: {entity['score']}") ``` """ + # First convert text to embeddings vectors + emb_result = await self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + request_data = { - "text": text, + "vectors": vectors, "user": user, "collection": collection, "limit": limit @@ -704,18 +708,18 @@ class AsyncFlowInstance: return await self.request("triples", request_data) - async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, - operation_name: Optional[str] = None, **kwargs: Any): + async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs: Any): """ - Execute a GraphQL query on stored objects. + Execute a GraphQL query on stored rows. - Queries structured data objects using GraphQL syntax. Supports complex + Queries structured data rows using GraphQL syntax. Supports complex queries with variables and named operations. 
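        As the request construction below shows, the call reduces to a
        plain dict ({"query": ..., "user": ..., "collection": ...}, plus
        optional "variables" and "operationName") dispatched to the
        "rows" service.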
Args: query: GraphQL query string user: User identifier - collection: Collection identifier containing objects + collection: Collection identifier containing rows variables: Optional GraphQL query variables operation_name: Optional operation name for multi-operation queries **kwargs: Additional service-specific parameters @@ -739,7 +743,7 @@ class AsyncFlowInstance: } ''' - result = await flow.objects_query( + result = await flow.rows_query( query=query, user="trustgraph", collection="users", @@ -761,4 +765,64 @@ class AsyncFlowInstance: request_data["operationName"] = operation_name request_data.update(kwargs) - return await self.request("objects", request_data) + return await self.request("rows", request_data) + + async def row_embeddings_query( + self, text: str, schema_name: str, user: str = "trustgraph", + collection: str = "default", index_name: Optional[str] = None, + limit: int = 10, **kwargs: Any + ): + """ + Query row embeddings for semantic search on structured data. + + Performs semantic search over row index embeddings to find rows whose + indexed field values are most similar to the input text. Enables + fuzzy/semantic matching on structured data. + + Args: + text: Query text for semantic search + schema_name: Schema name to search within + user: User identifier (default: "trustgraph") + collection: Collection identifier (default: "default") + index_name: Optional index name to filter search to specific index + limit: Maximum number of results to return (default: 10) + **kwargs: Additional service-specific parameters + + Returns: + dict: Response containing matches with index_name, index_value, + text, and score + + Example: + ```python + async_flow = await api.async_flow() + flow = async_flow.id("default") + + # Search for customers by name similarity + results = await flow.row_embeddings_query( + text="John Smith", + schema_name="customers", + user="trustgraph", + collection="sales", + limit=5 + ) + + for match in results.get("matches", []): + print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") + ``` + """ + # First convert text to embeddings vectors + emb_result = await self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + request_data = { + "vectors": vectors, + "schema_name": schema_name, + "user": user, + "collection": collection, + "limit": limit + } + if index_name: + request_data["index_name"] = index_name + request_data.update(kwargs) + + return await self.request("row-embeddings", request_data) diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index cb6c8605..3241e0f7 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -282,8 +282,12 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" + # First convert text to embeddings vectors + emb_result = await self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + request = { - "text": text, + "vectors": vectors, "user": user, "collection": collection, "limit": limit @@ -316,9 +320,9 @@ class AsyncSocketFlowInstance: return await self.client._send_request("triples", self.flow_id, request) - async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, - operation_name: Optional[str] = None, **kwargs): - """GraphQL query""" + async def 
rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs): + """GraphQL query against structured rows""" request = { "query": query, "user": user, @@ -330,7 +334,7 @@ class AsyncSocketFlowInstance: request["operationName"] = operation_name request.update(kwargs) - return await self.client._send_request("objects", self.flow_id, request) + return await self.client._send_request("rows", self.flow_id, request) async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs): """Execute MCP tool""" @@ -341,3 +345,26 @@ class AsyncSocketFlowInstance: request.update(kwargs) return await self.client._send_request("mcp-tool", self.flow_id, request) + + async def row_embeddings_query( + self, text: str, schema_name: str, user: str = "trustgraph", + collection: str = "default", index_name: Optional[str] = None, + limit: int = 10, **kwargs + ): + """Query row embeddings for semantic search on structured data""" + # First convert text to embeddings vectors + emb_result = await self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + request = { + "vectors": vectors, + "schema_name": schema_name, + "user": user, + "collection": collection, + "limit": limit + } + if index_name: + request["index_name"] = index_name + request.update(kwargs) + + return await self.client._send_request("row-embeddings", self.flow_id, request) diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index a2796332..3dfb0fba 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -15,6 +15,15 @@ from . types import Triple from . exceptions import ProtocolException +def _string_to_term(value: str) -> Dict[str, Any]: + """Convert a string value to Term format for the gateway.""" + # Treat URIs as IRI type, otherwise as literal + if value.startswith("http://") or value.startswith("https://") or "://" in value: + return {"t": "i", "i": value} + else: + return {"t": "l", "v": value} + + class BulkClient: """ Synchronous bulk operations client for import/export. @@ -62,7 +71,12 @@ class BulkClient: return loop.run_until_complete(coro) - def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None: + def import_triples( + self, flow: str, triples: Iterator[Triple], + metadata: Optional[Dict[str, Any]] = None, + batch_size: int = 100, + **kwargs: Any + ) -> None: """ Bulk import RDF triples into a flow. @@ -71,6 +85,8 @@ class BulkClient: Args: flow: Flow identifier triples: Iterator yielding Triple objects + metadata: Metadata dict with id, metadata, user, collection + batch_size: Number of triples per batch (default 100) **kwargs: Additional parameters (reserved for future use) Example: @@ -86,23 +102,47 @@ class BulkClient: # ... 
more triples # Import triples - bulk.import_triples(flow="default", triples=triple_generator()) + bulk.import_triples( + flow="default", + triples=triple_generator(), + metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} + ) ``` """ - self._run_async(self._import_triples_async(flow, triples)) + self._run_async(self._import_triples_async(flow, triples, metadata, batch_size)) - async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None: + async def _import_triples_async( + self, flow: str, triples: Iterator[Triple], + metadata: Optional[Dict[str, Any]], batch_size: int + ) -> None: """Async implementation of triple import""" ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" if self.token: ws_url = f"{ws_url}?token={self.token}" + if metadata is None: + metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + batch = [] for triple in triples: + batch.append({ + "s": _string_to_term(triple.s), + "p": _string_to_term(triple.p), + "o": _string_to_term(triple.o) + }) + if len(batch) >= batch_size: + message = { + "metadata": metadata, + "triples": batch + } + await websocket.send(json.dumps(message)) + batch = [] + # Send remaining items + if batch: message = { - "s": triple.s, - "p": triple.p, - "o": triple.o + "metadata": metadata, + "triples": batch } await websocket.send(json.dumps(message)) @@ -362,7 +402,12 @@ class BulkClient: async for raw_message in websocket: yield json.loads(raw_message) - def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + def import_entity_contexts( + self, flow: str, contexts: Iterator[Dict[str, Any]], + metadata: Optional[Dict[str, Any]] = None, + batch_size: int = 100, + **kwargs: Any + ) -> None: """ Bulk import entity contexts into a flow. @@ -373,6 +418,8 @@ class BulkClient: Args: flow: Flow identifier contexts: Iterator yielding context dictionaries + metadata: Metadata dict with id, metadata, user, collection + batch_size: Number of contexts per batch (default 100) **kwargs: Additional parameters (reserved for future use) Example: @@ -381,27 +428,49 @@ class BulkClient: # Generate entity contexts to import def context_generator(): - yield {"entity": "entity1", "context": "Description of entity1..."} - yield {"entity": "entity2", "context": "Description of entity2..."} + yield {"entity": {"v": "entity1", "e": True}, "context": "Description..."} + yield {"entity": {"v": "entity2", "e": True}, "context": "Description..."} # ... 
more contexts bulk.import_entity_contexts( flow="default", - contexts=context_generator() + contexts=context_generator(), + metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} ) ``` """ - self._run_async(self._import_entity_contexts_async(flow, contexts)) + self._run_async(self._import_entity_contexts_async(flow, contexts, metadata, batch_size)) - async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None: + async def _import_entity_contexts_async( + self, flow: str, contexts: Iterator[Dict[str, Any]], + metadata: Optional[Dict[str, Any]], batch_size: int + ) -> None: """Async implementation of entity contexts import""" ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" if self.token: ws_url = f"{ws_url}?token={self.token}" + if metadata is None: + metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + batch = [] for context in contexts: - await websocket.send(json.dumps(context)) + batch.append(context) + if len(batch) >= batch_size: + message = { + "metadata": metadata, + "entities": batch + } + await websocket.send(json.dumps(message)) + batch = [] + # Send remaining items + if batch: + message = { + "metadata": metadata, + "entities": batch + } + await websocket.send(json.dumps(message)) def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]: """ @@ -461,45 +530,45 @@ class BulkClient: async for raw_message in websocket: yield json.loads(raw_message) - def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None: """ - Bulk import structured objects into a flow. + Bulk import structured rows into a flow. - Efficiently uploads structured data objects via WebSocket streaming + Efficiently uploads structured data rows via WebSocket streaming for use in GraphQL queries. Args: flow: Flow identifier - objects: Iterator yielding object dictionaries + rows: Iterator yielding row dictionaries **kwargs: Additional parameters (reserved for future use) Example: ```python bulk = api.bulk() - # Generate objects to import - def object_generator(): - yield {"id": "obj1", "name": "Object 1", "value": 100} - yield {"id": "obj2", "name": "Object 2", "value": 200} - # ... more objects + # Generate rows to import + def row_generator(): + yield {"id": "row1", "name": "Row 1", "value": 100} + yield {"id": "row2", "name": "Row 2", "value": 200} + # ... 
more rows - bulk.import_objects( + bulk.import_rows( flow="default", - objects=object_generator() + rows=row_generator() ) ``` """ - self._run_async(self._import_objects_async(flow, objects)) + self._run_async(self._import_rows_async(flow, rows)) - async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None: - """Async implementation of objects import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None: + """Async implementation of rows import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" if self.token: ws_url = f"{ws_url}?token={self.token}" async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - for obj in objects: - await websocket.send(json.dumps(obj)) + for row in rows: + await websocket.send(json.dumps(row)) def close(self) -> None: """Close connections""" diff --git a/trustgraph-base/trustgraph/api/exceptions.py b/trustgraph-base/trustgraph/api/exceptions.py index 311d2651..b60e41e1 100644 --- a/trustgraph-base/trustgraph/api/exceptions.py +++ b/trustgraph-base/trustgraph/api/exceptions.py @@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException): pass -class ObjectsQueryError(TrustGraphException): - """Objects query service error""" +class RowsQueryError(TrustGraphException): + """Rows query service error""" pass @@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = { "load-error": LoadError, "lookup-error": LookupError, "nlp-query-error": NLPQueryError, - "objects-query-error": ObjectsQueryError, + "rows-query-error": RowsQueryError, "request-error": RequestError, "structured-query-error": StructuredQueryError, "unexpected-error": UnexpectedError, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index d06a6327..cc07f794 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -10,12 +10,27 @@ import json import base64 from .. knowledge import hash, Uri, Literal +from .. schema import IRI, LITERAL from . types import Triple from . exceptions import ProtocolException + def to_value(x): - if x["e"]: return Uri(x["v"]) - return Literal(x["v"]) + """Convert wire format to Uri or Literal.""" + if x.get("t") == IRI: + return Uri(x.get("i", "")) + elif x.get("t") == LITERAL: + return Literal(x.get("v", "")) + # Fallback for any other type + return Literal(x.get("v", x.get("i", ""))) + + +def from_value(v): + """Convert Uri or Literal to wire format.""" + if isinstance(v, Uri): + return {"t": IRI, "i": str(v)} + else: + return {"t": LITERAL, "v": str(v)} class Flow: """ @@ -569,9 +584,13 @@ class FlowInstance: ``` """ + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + # Query graph embeddings for semantic search input = { - "text": text, + "vectors": vectors, "user": user, "collection": collection, "limit": limit @@ -582,6 +601,51 @@ class FlowInstance: input ) + def document_embeddings_query(self, text, user, collection, limit=10): + """ + Query document chunks using semantic similarity. + + Finds document chunks whose content is semantically similar to the + input text, using vector embeddings. 
+ + Args: + text: Query text for semantic search + user: User/keyspace identifier + collection: Collection identifier + limit: Maximum number of results (default: 10) + + Returns: + dict: Query results with similar document chunks + + Example: + ```python + flow = api.flow().id("default") + results = flow.document_embeddings_query( + text="machine learning algorithms", + user="trustgraph", + collection="research-papers", + limit=5 + ) + ``` + """ + + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + # Query document embeddings for semantic search + input = { + "vectors": vectors, + "user": user, + "collection": collection, + "limit": limit + } + + return self.request( + "service/document-embeddings", + input + ) + def prompt(self, id, variables): """ Execute a prompt template with variable substitution. @@ -751,17 +815,17 @@ class FlowInstance: if s: if not isinstance(s, Uri): raise RuntimeError("s must be Uri") - input["s"] = { "v": str(s), "e": isinstance(s, Uri), } - + input["s"] = from_value(s) + if p: if not isinstance(p, Uri): raise RuntimeError("p must be Uri") - input["p"] = { "v": str(p), "e": isinstance(p, Uri), } + input["p"] = from_value(p) if o: if not isinstance(o, Uri) and not isinstance(o, Literal): raise RuntimeError("o must be Uri or Literal") - input["o"] = { "v": str(o), "e": isinstance(o, Uri), } + input["o"] = from_value(o) object = self.request( "service/triples", @@ -834,9 +898,9 @@ class FlowInstance: if metadata: metadata.emit( lambda t: triples.append({ - "s": { "v": t["s"], "e": isinstance(t["s"], Uri) }, - "p": { "v": t["p"], "e": isinstance(t["p"], Uri) }, - "o": { "v": t["o"], "e": isinstance(t["o"], Uri) } + "s": from_value(t["s"]), + "p": from_value(t["p"]), + "o": from_value(t["o"]), }) ) @@ -913,9 +977,9 @@ class FlowInstance: if metadata: metadata.emit( lambda t: triples.append({ - "s": { "v": t["s"], "e": isinstance(t["s"], Uri) }, - "p": { "v": t["p"], "e": isinstance(t["p"], Uri) }, - "o": { "v": t["o"], "e": isinstance(t["o"], Uri) } + "s": from_value(t["s"]), + "p": from_value(t["p"]), + "o": from_value(t["o"]), }) ) @@ -937,12 +1001,12 @@ class FlowInstance: input ) - def objects_query( + def rows_query( self, query, user="trustgraph", collection="default", variables=None, operation_name=None ): """ - Execute a GraphQL query against structured objects in the knowledge graph. + Execute a GraphQL query against structured rows in the knowledge graph. Queries structured data using GraphQL syntax, allowing complex queries with filtering, aggregation, and relationship traversal. @@ -974,7 +1038,7 @@ class FlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, user="trustgraph", collection="scientists" @@ -989,7 +1053,7 @@ class FlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, variables={"name": "Marie Curie"} ) @@ -1010,7 +1074,7 @@ class FlowInstance: input["operation_name"] = operation_name response = self.request( - "service/objects", + "service/rows", input ) @@ -1233,3 +1297,78 @@ class FlowInstance: return response["schema-matches"] + def row_embeddings_query( + self, text, schema_name, user="trustgraph", collection="default", + index_name=None, limit=10 + ): + """ + Query row data using semantic similarity on indexed fields. + + Finds rows whose indexed field values are semantically similar to the + input text, using vector embeddings. 
This enables fuzzy/semantic matching + on structured data. + + Args: + text: Query text for semantic search + schema_name: Schema name to search within + user: User/keyspace identifier (default: "trustgraph") + collection: Collection identifier (default: "default") + index_name: Optional index name to filter search to specific index + limit: Maximum number of results (default: 10) + + Returns: + dict: Query results with matches containing index_name, index_value, + text, and score + + Example: + ```python + flow = api.flow().id("default") + + # Search for customers by name similarity + results = flow.row_embeddings_query( + text="John Smith", + schema_name="customers", + user="trustgraph", + collection="sales", + limit=5 + ) + + # Filter to specific index + results = flow.row_embeddings_query( + text="machine learning engineer", + schema_name="employees", + index_name="job_title", + limit=10 + ) + ``` + """ + + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + # Query row embeddings for semantic search + input = { + "vectors": vectors, + "schema_name": schema_name, + "user": user, + "collection": collection, + "limit": limit + } + + if index_name: + input["index_name"] = index_name + + response = self.request( + "service/row-embeddings", + input + ) + + # Check for system-level error + if "error" in response and response["error"]: + error_type = response["error"].get("type", "unknown") + error_message = response["error"].get("message", "Unknown error") + raise ProtocolException(f"{error_type}: {error_message}") + + return response + diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index 23f6c9f2..1fae350c 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -10,11 +10,18 @@ import json import base64 from .. knowledge import hash, Uri, Literal +from .. schema import IRI, LITERAL from . types import Triple + def to_value(x): - if x["e"]: return Uri(x["v"]) - return Literal(x["v"]) + """Convert wire format to Uri or Literal.""" + if x.get("t") == IRI: + return Uri(x.get("i", "")) + elif x.get("t") == LITERAL: + return Literal(x.get("v", "")) + # Fallback for any other type + return Literal(x.get("v", x.get("i", ""))) class Knowledge: """ diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index b068f627..e50dc0aa 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -12,13 +12,28 @@ import logging from . types import DocumentMetadata, ProcessingMetadata, Triple from .. knowledge import hash, Uri, Literal +from .. schema import IRI, LITERAL from . 
exceptions import * logger = logging.getLogger(__name__) + def to_value(x): - if x["e"]: return Uri(x["v"]) - return Literal(x["v"]) + """Convert wire format to Uri or Literal.""" + if x.get("t") == IRI: + return Uri(x.get("i", "")) + elif x.get("t") == LITERAL: + return Literal(x.get("v", "")) + # Fallback for any other type + return Literal(x.get("v", x.get("i", ""))) + + +def from_value(v): + """Convert Uri or Literal to wire format.""" + if isinstance(v, Uri): + return {"t": IRI, "i": str(v)} + else: + return {"t": LITERAL, "v": str(v)} class Library: """ @@ -118,18 +133,18 @@ class Library: if isinstance(metadata, list): triples = [ { - "s": { "v": t.s, "e": isinstance(t.s, Uri) }, - "p": { "v": t.p, "e": isinstance(t.p, Uri) }, - "o": { "v": t.o, "e": isinstance(t.o, Uri) } + "s": from_value(t.s), + "p": from_value(t.p), + "o": from_value(t.o), } for t in metadata ] elif hasattr(metadata, "emit"): metadata.emit( lambda t: triples.append({ - "s": { "v": t["s"], "e": isinstance(t["s"], Uri) }, - "p": { "v": t["p"], "e": isinstance(t["p"], Uri) }, - "o": { "v": t["o"], "e": isinstance(t["o"], Uri) } + "s": from_value(t["s"]), + "p": from_value(t["p"]), + "o": from_value(t["o"]), }) ) else: @@ -315,9 +330,9 @@ class Library: "comments": metadata.comments, "metadata": [ { - "s": { "v": t["s"], "e": isinstance(t["s"], Uri) }, - "p": { "v": t["p"], "e": isinstance(t["p"], Uri) }, - "o": { "v": t["o"], "e": isinstance(t["o"], Uri) } + "s": from_value(t["s"]), + "p": from_value(t["p"]), + "o": from_value(t["o"]), } for t in metadata.metadata ], diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index c712f808..e8de442a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -649,8 +649,12 @@ class SocketFlowInstance: ) ``` """ + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + request = { - "text": text, + "vectors": vectors, "user": user, "collection": collection, "limit": limit @@ -659,6 +663,54 @@ class SocketFlowInstance: return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False) + def document_embeddings_query( + self, + text: str, + user: str, + collection: str, + limit: int = 10, + **kwargs: Any + ) -> Dict[str, Any]: + """ + Query document chunks using semantic similarity. + + Args: + text: Query text for semantic search + user: User/keyspace identifier + collection: Collection identifier + limit: Maximum number of results (default: 10) + **kwargs: Additional parameters passed to the service + + Returns: + dict: Query results with similar document chunks + + Example: + ```python + socket = api.socket() + flow = socket.flow("default") + + results = flow.document_embeddings_query( + text="machine learning algorithms", + user="trustgraph", + collection="research-papers", + limit=5 + ) + ``` + """ + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + request = { + "vectors": vectors, + "user": user, + "collection": collection, + "limit": limit + } + request.update(kwargs) + + return self.client._send_request_sync("document-embeddings", self.flow_id, request, False) + def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]: """ Generate vector embeddings for text. 
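The `from_value`/`to_value` helpers introduced above (in `flow.py` and `library.py`) convert between `Uri`/`Literal` objects and the compact Term wire format that replaces the old `{"v": ..., "e": ...}` encoding. A minimal round-trip sketch, assuming the module-level helpers are importable from `trustgraph.api.library` (URIs and values are placeholders):

```python
from trustgraph.knowledge import Uri, Literal
from trustgraph.schema import IRI, LITERAL   # the compact type tags "i" and "l"
from trustgraph.api.library import from_value, to_value  # assumed import path

# Uri values are tagged as IRI terms, everything else as literals
print(from_value(Uri("http://example.org/alice")))  # {'t': 'i', 'i': 'http://example.org/alice'}
print(from_value(Literal("Alice")))                 # {'t': 'l', 'v': 'Alice'}

# And back: IRI terms become Uri, literal terms become Literal
print(to_value({"t": IRI, "i": "http://example.org/alice"}))  # a Uri
print(to_value({"t": LITERAL, "v": "Alice"}))                 # a Literal
```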
@@ -737,7 +789,7 @@ class SocketFlowInstance: return self.client._send_request_sync("triples", self.flow_id, request, False) - def objects_query( + def rows_query( self, query: str, user: str, @@ -747,7 +799,7 @@ class SocketFlowInstance: **kwargs: Any ) -> Dict[str, Any]: """ - Execute a GraphQL query against structured objects. + Execute a GraphQL query against structured rows. Args: query: GraphQL query string @@ -774,7 +826,7 @@ class SocketFlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, user="trustgraph", collection="scientists" @@ -792,7 +844,7 @@ class SocketFlowInstance: request["operationName"] = operation_name request.update(kwargs) - return self.client._send_request_sync("objects", self.flow_id, request, False) + return self.client._send_request_sync("rows", self.flow_id, request, False) def mcp_tool( self, @@ -829,3 +881,73 @@ class SocketFlowInstance: request.update(kwargs) return self.client._send_request_sync("mcp-tool", self.flow_id, request, False) + + def row_embeddings_query( + self, + text: str, + schema_name: str, + user: str = "trustgraph", + collection: str = "default", + index_name: Optional[str] = None, + limit: int = 10, + **kwargs: Any + ) -> Dict[str, Any]: + """ + Query row data using semantic similarity on indexed fields. + + Finds rows whose indexed field values are semantically similar to the + input text, using vector embeddings. This enables fuzzy/semantic matching + on structured data. + + Args: + text: Query text for semantic search + schema_name: Schema name to search within + user: User/keyspace identifier (default: "trustgraph") + collection: Collection identifier (default: "default") + index_name: Optional index name to filter search to specific index + limit: Maximum number of results (default: 10) + **kwargs: Additional parameters passed to the service + + Returns: + dict: Query results with matches containing index_name, index_value, + text, and score + + Example: + ```python + socket = api.socket() + flow = socket.flow("default") + + # Search for customers by name similarity + results = flow.row_embeddings_query( + text="John Smith", + schema_name="customers", + user="trustgraph", + collection="sales", + limit=5 + ) + + # Filter to specific index + results = flow.row_embeddings_query( + text="machine learning engineer", + schema_name="employees", + index_name="job_title", + limit=10 + ) + ``` + """ + # First convert text to embeddings vectors + emb_result = self.embeddings(text=text) + vectors = emb_result.get("vectors", []) + + request = { + "vectors": vectors, + "schema_name": schema_name, + "user": user, + "collection": collection, + "limit": limit + } + if index_name: + request["index_name"] = index_name + request.update(kwargs) + + return self.client._send_request_sync("row-embeddings", self.flow_id, request, False) diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index e8530f6c..557109a2 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -34,5 +34,6 @@ from . tool_service import ToolService from . tool_client import ToolClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec +from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec from . 
collection_config_handler import CollectionConfigHandler diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index bca915e0..f04f2c60 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -7,7 +7,7 @@ embeddings. import logging from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse -from .. schema import Error, Value +from .. schema import Error, Term from . flow_processor import FlowProcessor from . consumer_spec import ConsumerSpec @@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-query" +default_ident = "doc-embeddings-query" class DocumentEmbeddingsQueryService(FlowProcessor): diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index e25d76c7..07eb2bc7 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -2,15 +2,21 @@ import logging from . request_response_spec import RequestResponse, RequestResponseSpec -from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, IRI, LITERAL from .. knowledge import Uri, Literal # Module logger logger = logging.getLogger(__name__) + def to_value(x): - if x.is_uri: return Uri(x.value) - return Literal(x.value) + """Convert schema Term to Uri or Literal.""" + if x.type == IRI: + return Uri(x.iri) + elif x.type == LITERAL: + return Literal(x.value) + # Fallback + return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): async def query(self, vectors, limit=20, user="trustgraph", diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index f3afdba2..d429b3a5 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -7,7 +7,7 @@ embeddings. import logging from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse -from .. schema import Error, Value +from .. schema import Error, Term from . flow_processor import FlowProcessor from . consumer_spec import ConsumerSpec @@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-query" +default_ident = "graph-embeddings-query" class GraphEmbeddingsQueryService(FlowProcessor): diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py new file mode 100644 index 00000000..0141da31 --- /dev/null +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -0,0 +1,45 @@ +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. 
schema import RowEmbeddingsRequest, RowEmbeddingsResponse + +class RowEmbeddingsQueryClient(RequestResponse): + async def row_embeddings_query( + self, vectors, schema_name, user="trustgraph", collection="default", + index_name=None, limit=10, timeout=600 + ): + request = RowEmbeddingsRequest( + vectors=vectors, + schema_name=schema_name, + user=user, + collection=collection, + limit=limit + ) + if index_name: + request.index_name = index_name + + resp = await self.request(request, timeout=timeout) + + if resp.error: + raise RuntimeError(resp.error.message) + + # Return matches as list of dicts + return [ + { + "index_name": match.index_name, + "index_value": match.index_value, + "text": match.text, + "score": match.score + } + for match in (resp.matches or []) + ] + +class RowEmbeddingsQueryClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(RowEmbeddingsQueryClientSpec, self).__init__( + request_name = request_name, + request_schema = RowEmbeddingsRequest, + response_name = response_name, + response_schema = RowEmbeddingsResponse, + impl = RowEmbeddingsQueryClient, + ) diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index d59bcab3..b0d90507 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -222,35 +222,50 @@ class Subscriber: # Store message for later acknowledgment msg_id = str(uuid.uuid4()) self.pending_acks[msg_id] = msg - + try: id = msg.properties()["id"] except: id = None - + value = msg.value() delivery_success = False - + has_matching_waiter = False + async with self.lock: # Deliver to specific subscribers if id in self.q: + has_matching_waiter = True delivery_success = await self._deliver_to_queue( self.q[id], value ) - + # Deliver to all subscribers for q in self.full.values(): + has_matching_waiter = True if await self._deliver_to_queue(q, value): delivery_success = True - - # Acknowledge only on successful delivery - if delivery_success: - self.consumer.acknowledge(msg) - del self.pending_acks[msg_id] - else: - # Negative acknowledge for retry - self.consumer.negative_acknowledge(msg) - del self.pending_acks[msg_id] + + # Always acknowledge the message to prevent redelivery storms + # on shared topics. Negative acknowledging orphaned messages + # (no matching waiter) causes immediate redelivery to all + # subscribers, none of whom can handle it either. + self.consumer.acknowledge(msg) + del self.pending_acks[msg_id] + + if not delivery_success: + if not has_matching_waiter: + # Message arrived for a waiter that no longer exists + # (likely due to client disconnect or timeout) + logger.debug( + f"Discarding orphaned message with id={id} - " + "no matching waiter" + ) + else: + # Delivery failed (e.g., queue full with drop_new strategy) + logger.debug( + f"Message with id={id} dropped due to backpressure" + ) async def _deliver_to_queue(self, queue, value): """Deliver message to queue with backpressure handling""" diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index c9f747b5..7258d3ca 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,24 +1,34 @@ from . request_response_spec import RequestResponse, RequestResponseSpec -from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value +from .. 
schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL from .. knowledge import Uri, Literal + class Triple: def __init__(self, s, p, o): self.s = s self.p = p self.o = o + def to_value(x): - if x.is_uri: return Uri(x.value) - return Literal(x.value) + """Convert schema Term to Uri or Literal.""" + if x.type == IRI: + return Uri(x.iri) + elif x.type == LITERAL: + return Literal(x.value) + # Fallback + return Literal(x.value or x.iri) + def from_value(x): - if x is None: return None + """Convert Uri or Literal to schema Term.""" + if x is None: + return None if isinstance(x, Uri): - return Value(value=str(x), is_uri=True) + return Term(type=IRI, iri=str(x)) else: - return Value(value=str(x), is_uri=False) + return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): async def query(self, s=None, p=None, o=None, limit=20, diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index 0d8affcb..b156ef55 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -7,7 +7,7 @@ null. Output is a list of triples. import logging from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .. schema import Value, Triple +from .. schema import Term, Triple from . flow_processor import FlowProcessor from . consumer_spec import ConsumerSpec diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py new file mode 100644 index 00000000..4f911e3c --- /dev/null +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -0,0 +1,60 @@ + +import _pulsar + +from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse +from .. schema import row_embeddings_request_queue +from .. schema import row_embeddings_response_queue +from . base import BaseClient + +# Ugly +ERROR=_pulsar.LoggerLevel.Error +WARN=_pulsar.LoggerLevel.Warn +INFO=_pulsar.LoggerLevel.Info +DEBUG=_pulsar.LoggerLevel.Debug + +class RowEmbeddingsClient(BaseClient): + + def __init__( + self, log_level=ERROR, + subscriber=None, + input_queue=None, + output_queue=None, + pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, + ): + + if input_queue == None: + input_queue = row_embeddings_request_queue + + if output_queue == None: + output_queue = row_embeddings_response_queue + + super(RowEmbeddingsClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, + input_schema=RowEmbeddingsRequest, + output_schema=RowEmbeddingsResponse, + ) + + def request( + self, vectors, schema_name, user="trustgraph", collection="default", + index_name=None, limit=10, timeout=300 + ): + kwargs = dict( + user=user, collection=collection, + vectors=vectors, schema_name=schema_name, + limit=limit, timeout=timeout + ) + if index_name: + kwargs["index_name"] = index_name + + response = self.call(**kwargs) + + if response.error: + raise RuntimeError(f"{response.error.type}: {response.error.message}") + + return response.matches diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 8ed2ebb7..401aaf0b 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -2,7 +2,7 @@ import _pulsar -from .. 
schema import TriplesQueryRequest, TriplesQueryResponse, Value +from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL from .. schema import triples_request_queue from .. schema import triples_response_queue from . base import BaseClient @@ -46,9 +46,9 @@ class TriplesQueryClient(BaseClient): if ent == None: return None if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) def request( self, diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 80c5438b..9fbcbf16 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -19,9 +19,10 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato from .translators.tool import ToolRequestTranslator, ToolResponseTranslator from .translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, - GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator + GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator, + RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator ) -from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator +from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator @@ -107,15 +108,21 @@ TranslatorRegistry.register_service( ) TranslatorRegistry.register_service( - "graph-embeddings-query", - GraphEmbeddingsRequestTranslator(), + "graph-embeddings-query", + GraphEmbeddingsRequestTranslator(), GraphEmbeddingsResponseTranslator() ) TranslatorRegistry.register_service( - "objects-query", - ObjectsQueryRequestTranslator(), - ObjectsQueryResponseTranslator() + "row-embeddings-query", + RowEmbeddingsRequestTranslator(), + RowEmbeddingsResponseTranslator() +) + +TranslatorRegistry.register_service( + "rows-query", + RowsQueryRequestTranslator(), + RowsQueryResponseTranslator() ) TranslatorRegistry.register_service( diff --git a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py index 9ce2730e..5b5820fa 100644 --- a/trustgraph-base/trustgraph/messaging/translators/__init__.py +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -1,5 +1,5 @@ from .base import Translator, MessageTranslator -from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator +from .primitives import TermTranslator, ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator from .agent import AgentRequestTranslator, AgentResponseTranslator from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator @@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator from 
.prompt import PromptRequestTranslator, PromptResponseTranslator from .embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, - GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator + GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator, + RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator ) -from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator +from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index a08f9b6c..141a7330 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -1,7 +1,8 @@ from typing import Dict, Any, Tuple from ...schema import ( DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, - GraphEmbeddingsRequest, GraphEmbeddingsResponse + GraphEmbeddingsRequest, GraphEmbeddingsResponse, + RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch ) from .base import MessageTranslator from .primitives import ValueTranslator @@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True + + +class RowEmbeddingsRequestTranslator(MessageTranslator): + """Translator for RowEmbeddingsRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: + return RowEmbeddingsRequest( + vectors=data["vectors"], + limit=int(data.get("limit", 10)), + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + schema_name=data.get("schema_name", ""), + index_name=data.get("index_name") + ) + + def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: + result = { + "vectors": obj.vectors, + "limit": obj.limit, + "user": obj.user, + "collection": obj.collection, + "schema_name": obj.schema_name, + } + if obj.index_name: + result["index_name"] = obj.index_name + return result + + +class RowEmbeddingsResponseTranslator(MessageTranslator): + """Translator for RowEmbeddingsResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]: + result = {} + + if obj.error is not None: + result["error"] = { + "type": obj.error.type, + "message": obj.error.message + } + + if obj.matches is not None: + result["matches"] = [ + { + "index_name": match.index_name, + "index_value": match.index_value, + "text": match.text, + "score": match.score + } + for match in obj.matches + ] + + return result + + def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/primitives.py b/trustgraph-base/trustgraph/messaging/translators/primitives.py index 42db4151..790ae8f7 100644 --- a/trustgraph-base/trustgraph/messaging/translators/primitives.py +++ 
b/trustgraph-base/trustgraph/messaging/translators/primitives.py @@ -1,37 +1,133 @@ from typing import Dict, Any, List -from ...schema import Value, Triple, RowSchema, Field +from ...schema import Term, Triple, RowSchema, Field, IRI, BLANK, LITERAL, TRIPLE from .base import Translator -class ValueTranslator(Translator): - """Translator for Value schema objects""" - - def to_pulsar(self, data: Dict[str, Any]) -> Value: - return Value(value=data["v"], is_uri=data["e"]) - - def from_pulsar(self, obj: Value) -> Dict[str, Any]: - return {"v": obj.value, "e": obj.is_uri} +class TermTranslator(Translator): + """ + Translator for Term schema objects. + + Wire format (compact keys): + - "t": type (i/b/l/t) + - "i": iri (for IRI type) + - "d": id (for BLANK type) + - "v": value (for LITERAL type) + - "dt": datatype (for LITERAL type) + - "ln": language (for LITERAL type) + - "tr": triple (for TRIPLE type, nested) + """ + + def to_pulsar(self, data: Dict[str, Any]) -> Term: + term_type = data.get("t", "") + + if term_type == IRI: + return Term(type=IRI, iri=data.get("i", "")) + + elif term_type == BLANK: + return Term(type=BLANK, id=data.get("d", "")) + + elif term_type == LITERAL: + return Term( + type=LITERAL, + value=data.get("v", ""), + datatype=data.get("dt", ""), + language=data.get("ln", ""), + ) + + elif term_type == TRIPLE: + # Nested triple - use TripleTranslator + triple_data = data.get("tr") + if triple_data: + triple = _triple_translator_to_pulsar(triple_data) + else: + triple = None + return Term(type=TRIPLE, triple=triple) + + else: + # Unknown or empty type + return Term(type=term_type) + + def from_pulsar(self, obj: Term) -> Dict[str, Any]: + result: Dict[str, Any] = {"t": obj.type} + + if obj.type == IRI: + result["i"] = obj.iri + + elif obj.type == BLANK: + result["d"] = obj.id + + elif obj.type == LITERAL: + result["v"] = obj.value + if obj.datatype: + result["dt"] = obj.datatype + if obj.language: + result["ln"] = obj.language + + elif obj.type == TRIPLE: + if obj.triple: + result["tr"] = _triple_translator_from_pulsar(obj.triple) + + return result + + +# Module-level helper functions to avoid circular instantiation +def _triple_translator_to_pulsar(data: Dict[str, Any]) -> Triple: + term_translator = TermTranslator() + return Triple( + s=term_translator.to_pulsar(data["s"]) if data.get("s") else None, + p=term_translator.to_pulsar(data["p"]) if data.get("p") else None, + o=term_translator.to_pulsar(data["o"]) if data.get("o") else None, + g=data.get("g"), + ) + + +def _triple_translator_from_pulsar(obj: Triple) -> Dict[str, Any]: + term_translator = TermTranslator() + result: Dict[str, Any] = {} + + if obj.s: + result["s"] = term_translator.from_pulsar(obj.s) + if obj.p: + result["p"] = term_translator.from_pulsar(obj.p) + if obj.o: + result["o"] = term_translator.from_pulsar(obj.o) + if obj.g: + result["g"] = obj.g + + return result class TripleTranslator(Translator): - """Translator for Triple schema objects""" - + """Translator for Triple schema objects (quads with optional graph)""" + def __init__(self): - self.value_translator = ValueTranslator() - + self.term_translator = TermTranslator() + def to_pulsar(self, data: Dict[str, Any]) -> Triple: return Triple( - s=self.value_translator.to_pulsar(data["s"]), - p=self.value_translator.to_pulsar(data["p"]), - o=self.value_translator.to_pulsar(data["o"]) + s=self.term_translator.to_pulsar(data["s"]) if data.get("s") else None, + p=self.term_translator.to_pulsar(data["p"]) if data.get("p") else None, + 
o=self.term_translator.to_pulsar(data["o"]) if data.get("o") else None, + g=data.get("g"), ) - + def from_pulsar(self, obj: Triple) -> Dict[str, Any]: - return { - "s": self.value_translator.from_pulsar(obj.s), - "p": self.value_translator.from_pulsar(obj.p), - "o": self.value_translator.from_pulsar(obj.o) - } + result: Dict[str, Any] = {} + + if obj.s: + result["s"] = self.term_translator.from_pulsar(obj.s) + if obj.p: + result["p"] = self.term_translator.from_pulsar(obj.p) + if obj.o: + result["o"] = self.term_translator.from_pulsar(obj.o) + if obj.g: + result["g"] = obj.g + + return result + + +# Backward compatibility alias +ValueTranslator = TermTranslator class SubgraphTranslator(Translator): diff --git a/trustgraph-base/trustgraph/messaging/translators/objects_query.py b/trustgraph-base/trustgraph/messaging/translators/rows_query.py similarity index 68% rename from trustgraph-base/trustgraph/messaging/translators/objects_query.py rename to trustgraph-base/trustgraph/messaging/translators/rows_query.py index a746e0c7..6feb75a3 100644 --- a/trustgraph-base/trustgraph/messaging/translators/objects_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/rows_query.py @@ -1,44 +1,44 @@ from typing import Dict, Any, Tuple, Optional -from ...schema import ObjectsQueryRequest, ObjectsQueryResponse +from ...schema import RowsQueryRequest, RowsQueryResponse from .base import MessageTranslator import json -class ObjectsQueryRequestTranslator(MessageTranslator): - """Translator for ObjectsQueryRequest schema objects""" - - def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryRequest: - return ObjectsQueryRequest( +class RowsQueryRequestTranslator(MessageTranslator): + """Translator for RowsQueryRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryRequest: + return RowsQueryRequest( user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), variables=data.get("variables", {}), operation_name=data.get("operation_name", None) ) - - def from_pulsar(self, obj: ObjectsQueryRequest) -> Dict[str, Any]: + + def from_pulsar(self, obj: RowsQueryRequest) -> Dict[str, Any]: result = { "user": obj.user, "collection": obj.collection, "query": obj.query, "variables": dict(obj.variables) if obj.variables else {} } - + if obj.operation_name: result["operation_name"] = obj.operation_name - + return result -class ObjectsQueryResponseTranslator(MessageTranslator): - """Translator for ObjectsQueryResponse schema objects""" - - def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryResponse: +class RowsQueryResponseTranslator(MessageTranslator): + """Translator for RowsQueryResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - - def from_pulsar(self, obj: ObjectsQueryResponse) -> Dict[str, Any]: + + def from_pulsar(self, obj: RowsQueryResponse) -> Dict[str, Any]: result = {} - + # Handle GraphQL response data if obj.data: try: @@ -47,7 +47,7 @@ class ObjectsQueryResponseTranslator(MessageTranslator): result["data"] = obj.data else: result["data"] = None - + # Handle GraphQL errors if obj.errors: result["errors"] = [] @@ -60,20 +60,20 @@ class ObjectsQueryResponseTranslator(MessageTranslator): if error.extensions: error_dict["extensions"] = dict(error.extensions) result["errors"].append(error_dict) - + # Handle extensions if obj.extensions: result["extensions"] = dict(obj.extensions) - + # Handle 
system-level error if obj.error: result["error"] = { "type": obj.error.type, "message": obj.error.message } - + return result - - def from_response_with_completion(self, obj: ObjectsQueryResponse) -> Tuple[Dict[str, Any], bool]: + + def from_response_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.from_pulsar(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 1c08625b..2b01b1bc 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -14,11 +14,13 @@ class TriplesQueryRequestTranslator(MessageTranslator): s = self.value_translator.to_pulsar(data["s"]) if "s" in data else None p = self.value_translator.to_pulsar(data["p"]) if "p" in data else None o = self.value_translator.to_pulsar(data["o"]) if "o" in data else None - + g = data.get("g") # None=default graph, "*"=all graphs + return TriplesQueryRequest( s=s, p=p, o=o, + g=g, limit=int(data.get("limit", 10000)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") @@ -30,14 +32,16 @@ class TriplesQueryRequestTranslator(MessageTranslator): "user": obj.user, "collection": obj.collection } - + if obj.s: result["s"] = self.value_translator.from_pulsar(obj.s) if obj.p: result["p"] = self.value_translator.from_pulsar(obj.p) if obj.o: result["o"] = self.value_translator.from_pulsar(obj.o) - + if obj.g is not None: + result["g"] = obj.g + return result diff --git a/trustgraph-base/trustgraph/schema/core/primitives.py b/trustgraph-base/trustgraph/schema/core/primitives.py index 02517614..78676eb0 100644 --- a/trustgraph-base/trustgraph/schema/core/primitives.py +++ b/trustgraph-base/trustgraph/schema/core/primitives.py @@ -1,22 +1,57 @@ - from dataclasses import dataclass, field +# Term type constants +IRI = "i" # IRI/URI node +BLANK = "b" # Blank node +LITERAL = "l" # Literal value +TRIPLE = "t" # Quoted triple (RDF-star) + + @dataclass class Error: type: str = "" message: str = "" + @dataclass -class Value: +class Term: + """ + RDF Term - can represent an IRI, blank node, literal, or quoted triple. + + The 'type' field determines which other fields are relevant: + - IRI: use 'iri' field + - BLANK: use 'id' field + - LITERAL: use 'value', 'datatype', 'language' fields + - TRIPLE: use 'triple' field + """ + type: str = "" # One of: IRI, BLANK, LITERAL, TRIPLE + + # For IRI terms (type == IRI) + iri: str = "" + + # For blank nodes (type == BLANK) + id: str = "" + + # For literals (type == LITERAL) value: str = "" - is_uri: bool = False - type: str = "" + datatype: str = "" # XSD datatype URI (mutually exclusive with language) + language: str = "" # Language tag (mutually exclusive with datatype) + + # For quoted triples (type == TRIPLE) + triple: "Triple | None" = None + @dataclass class Triple: - s: Value | None = None - p: Value | None = None - o: Value | None = None + """ + RDF Triple / Quad. + + The optional 'g' field specifies the named graph (None = default graph). 
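+
+    Illustrative quad, built from the fields below (URIs are placeholders):
+
+        Triple(
+            s=Term(type=IRI, iri="http://example.org/alice"),
+            p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/name"),
+            o=Term(type=LITERAL, value="Alice", language="en"),
+            g="http://example.org/graphs/people",
+        )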
+ """ + s: Term | None = None # Subject + p: Term | None = None # Predicate + o: Term | None = None # Object + g: str | None = None # Graph name (IRI), None = default graph @dataclass class Field: diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index a3e5b394..93559056 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from ..core.metadata import Metadata -from ..core.primitives import Value, RowSchema +from ..core.primitives import Term, RowSchema from ..core.topic import topic ############################################################################ @@ -10,7 +10,7 @@ from ..core.topic import topic @dataclass class EntityEmbeddings: - entity: Value | None = None + entity: Term | None = None vectors: list[list[float]] = field(default_factory=list) # This is a 'batching' mechanism for the above data @@ -60,3 +60,23 @@ class StructuredObjectEmbedding: field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings ############################################################################ + +# Row embeddings are embeddings associated with indexed field values +# in structured row data. Each index gets embedded separately. + +@dataclass +class RowIndexEmbedding: + """Single row's embedding for one index""" + index_name: str = "" # The indexed field name(s) + index_value: list[str] = field(default_factory=list) # The field value(s) + text: str = "" # Text that was embedded + vectors: list[list[float]] = field(default_factory=list) + +@dataclass +class RowEmbeddings: + """Batched row embeddings for a schema""" + metadata: Metadata | None = None + schema_name: str = "" + embeddings: list[RowIndexEmbedding] = field(default_factory=list) + +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/graph.py b/trustgraph-base/trustgraph/schema/knowledge/graph.py index 9040c25e..4ee8d2c0 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/graph.py +++ b/trustgraph-base/trustgraph/schema/knowledge/graph.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from ..core.primitives import Value, Triple +from ..core.primitives import Term, Triple from ..core.metadata import Metadata from ..core.topic import topic @@ -10,7 +10,7 @@ from ..core.topic import topic @dataclass class EntityContext: - entity: Value | None = None + entity: Term | None = None context: str = "" # This is a 'batching' mechanism for the above data diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index aaeb739f..7b40ca0a 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -9,7 +9,7 @@ from .library import * from .lookup import * from .nlp_query import * from .structured_query import * -from .objects_query import * +from .rows_query import * from .diagnosis import * from .collection import * from .storage import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/lookup.py b/trustgraph-base/trustgraph/schema/services/lookup.py index bdeac636..d944fb89 100644 --- a/trustgraph-base/trustgraph/schema/services/lookup.py +++ b/trustgraph-base/trustgraph/schema/services/lookup.py @@ -1,6 +1,6 @@ from dataclasses 
import dataclass -from ..core.primitives import Error, Value, Triple +from ..core.primitives import Error, Term, Triple from ..core.topic import topic from ..core.metadata import Metadata diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 31d0852d..50ec416a 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from ..core.primitives import Error, Value, Triple +from ..core.primitives import Error, Term, Triple from ..core.topic import topic ############################################################################ @@ -17,7 +17,7 @@ class GraphEmbeddingsRequest: @dataclass class GraphEmbeddingsResponse: error: Error | None = None - entities: list[Value] = field(default_factory=list) + entities: list[Term] = field(default_factory=list) ############################################################################ @@ -27,9 +27,10 @@ class GraphEmbeddingsResponse: class TriplesQueryRequest: user: str = "" collection: str = "" - s: Value | None = None - p: Value | None = None - o: Value | None = None + s: Term | None = None + p: Term | None = None + o: Term | None = None + g: str | None = None # Graph IRI. None=default graph, "*"=all graphs limit: int = 0 @dataclass @@ -58,4 +59,39 @@ document_embeddings_request_queue = topic( ) document_embeddings_response_queue = topic( "document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' +) + +############################################################################ + +# Row embeddings query - for semantic/fuzzy matching on row index values + +@dataclass +class RowIndexMatch: + """A single matching row index from a semantic search""" + index_name: str = "" # The indexed field(s) + index_value: list[str] = field(default_factory=list) # The index values + text: str = "" # The text that was embedded + score: float = 0.0 # Similarity score + +@dataclass +class RowEmbeddingsRequest: + """Request for row embeddings semantic search""" + vectors: list[list[float]] = field(default_factory=list) # Query vectors + limit: int = 10 # Max results to return + user: str = "" # User/keyspace + collection: str = "" # Collection name + schema_name: str = "" # Schema name to search within + index_name: str | None = None # Optional: filter to specific index + +@dataclass +class RowEmbeddingsResponse: + """Response from row embeddings semantic search""" + error: Error | None = None + matches: list[RowIndexMatch] = field(default_factory=list) + +row_embeddings_request_queue = topic( + "row-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' +) +row_embeddings_response_queue = topic( + "row-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' ) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 72085ae8..4337cb9b 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from ..core.topic import topic -from ..core.primitives import Error, Value +from ..core.primitives import Error, Term ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/objects_query.py b/trustgraph-base/trustgraph/schema/services/rows_query.py 
similarity index 91% rename from trustgraph-base/trustgraph/schema/services/objects_query.py rename to trustgraph-base/trustgraph/schema/services/rows_query.py index e24daef3..a4818329 100644 --- a/trustgraph-base/trustgraph/schema/services/objects_query.py +++ b/trustgraph-base/trustgraph/schema/services/rows_query.py @@ -6,7 +6,7 @@ from ..core.topic import topic ############################################################################ -# Objects Query Service - executes GraphQL queries against structured data +# Rows Query Service - executes GraphQL queries against structured data @dataclass class GraphQLError: @@ -15,7 +15,7 @@ class GraphQLError: extensions: dict[str, str] = field(default_factory=dict) # Additional error metadata @dataclass -class ObjectsQueryRequest: +class RowsQueryRequest: user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest) collection: str = "" # Data collection identifier (required for partition key) query: str = "" # GraphQL query string @@ -23,7 +23,7 @@ class ObjectsQueryRequest: operation_name: Optional[str] = None # Operation to execute for multi-operation documents @dataclass -class ObjectsQueryResponse: +class RowsQueryResponse: error: Error | None = None # System-level error (connection, timeout, etc.) data: str = "" # JSON-encoded GraphQL response data errors: list[GraphQLError] = field(default_factory=list) # GraphQL field-level errors diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index c9192794..4e093953 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 5568bf91..66df74f1 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", "requests", "pulsar-client", "aiohttp", @@ -43,9 +43,13 @@ tg-invoke-agent = "trustgraph.cli.invoke_agent:main" tg-invoke-document-rag = "trustgraph.cli.invoke_document_rag:main" tg-invoke-graph-rag = "trustgraph.cli.invoke_graph_rag:main" tg-invoke-llm = "trustgraph.cli.invoke_llm:main" +tg-invoke-embeddings = "trustgraph.cli.invoke_embeddings:main" +tg-invoke-graph-embeddings = "trustgraph.cli.invoke_graph_embeddings:main" +tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main" tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main" tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main" -tg-invoke-objects-query = "trustgraph.cli.invoke_objects_query:main" +tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" +tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py new file mode 100644 index 00000000..b14397cb --- /dev/null +++ 
b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -0,0 +1,121 @@ +""" +Queries document chunks by text similarity using vector embeddings. +Returns a list of matching document chunks, truncated to the specified length. +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + +def truncate_chunk(chunk, max_length): + """Truncate a chunk to max_length characters, adding ellipsis if needed.""" + if len(chunk) <= max_length: + return chunk + return chunk[:max_length] + "..." + +def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None): + + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + # Call document embeddings query service + result = flow.document_embeddings_query( + text=query_text, + user=user, + collection=collection, + limit=limit + ) + + chunks = result.get("chunks", []) + for i, chunk in enumerate(chunks, 1): + truncated = truncate_chunk(chunk, max_chunk_length) + print(f"{i}. {truncated}") + + finally: + # Clean up socket connection + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-document-embeddings', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-U', '--user', + default="trustgraph", + help='User/keyspace (default: trustgraph)', + ) + + parser.add_argument( + '-c', '--collection', + default="default", + help='Collection (default: default)', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + '--max-chunk-length', + type=int, + default=200, + help='Truncate chunks to N characters (default: 200)', + ) + + parser.add_argument( + 'query', + nargs=1, + help='Query text to search for similar document chunks', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + query_text=args.query[0], + user=args.user, + collection=args.collection, + limit=args.limit, + max_chunk_length=args.max_chunk_length, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py new file mode 100644 index 00000000..71a88bd7 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -0,0 +1,77 @@ +""" +Invokes the embeddings service to convert text to a vector embedding. +Returns the embedding vector as a list of floats. 
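+
+Example (illustrative, assuming a TrustGraph API at the default URL):
+
+    tg-invoke-embeddings "What is a knowledge graph?"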
+""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + +def query(url, flow_id, text, token=None): + + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + # Call embeddings service + result = flow.embeddings(text=text) + vectors = result.get("vectors", []) + print(vectors) + + finally: + # Clean up socket connection + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-embeddings', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + 'text', + nargs=1, + help='Text to convert to embedding vector', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + text=args.text[0], + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py new file mode 100644 index 00000000..ae195007 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py @@ -0,0 +1,106 @@ +""" +Queries graph entities by text similarity using vector embeddings. +Returns a list of matching graph entities. +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + +def query(url, flow_id, query_text, user, collection, limit, token=None): + + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + # Call graph embeddings query service + result = flow.graph_embeddings_query( + text=query_text, + user=user, + collection=collection, + limit=limit + ) + + entities = result.get("entities", []) + for entity in entities: + print(entity) + + finally: + # Clean up socket connection + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-graph-embeddings', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-U', '--user', + default="trustgraph", + help='User/keyspace (default: trustgraph)', + ) + + parser.add_argument( + '-c', '--collection', + default="default", + help='Collection (default: default)', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + 'query', + nargs=1, + help='Query text to search for similar graph entities', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + query_text=args.query[0], + user=args.user, + collection=args.collection, + limit=args.limit, + 
token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py new file mode 100644 index 00000000..7393b4c3 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py @@ -0,0 +1,126 @@ +""" +Queries row data by text similarity using vector embeddings on indexed fields. +Returns matching rows with their index values and similarity scores. +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + +def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None): + + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + # Call row embeddings query service + result = flow.row_embeddings_query( + text=query_text, + schema_name=schema_name, + user=user, + collection=collection, + index_name=index_name, + limit=limit + ) + + matches = result.get("matches", []) + for match in matches: + print(f"Index: {match['index_name']}") + print(f" Values: {match['index_value']}") + print(f" Text: {match['text']}") + print(f" Score: {match['score']:.4f}") + print() + + finally: + # Clean up socket connection + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-row-embeddings', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-U', '--user', + default="trustgraph", + help='User/keyspace (default: trustgraph)', + ) + + parser.add_argument( + '-c', '--collection', + default="default", + help='Collection (default: default)', + ) + + parser.add_argument( + '-s', '--schema-name', + required=True, + help='Schema name to search within (required)', + ) + + parser.add_argument( + '-i', '--index-name', + default=None, + help='Index name to filter search (optional)', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + 'query', + nargs=1, + help='Query text to search for similar row index values', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + query_text=args.query[0], + schema_name=args.schema_name, + user=args.user, + collection=args.collection, + index_name=args.index_name, + limit=args.limit, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_objects_query.py b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py similarity index 96% rename from trustgraph-cli/trustgraph/cli/invoke_objects_query.py rename to trustgraph-cli/trustgraph/cli/invoke_rows_query.py index 50c4e8c2..962f353c 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_objects_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py @@ -1,5 +1,5 @@ """ -Uses the ObjectsQuery service to execute GraphQL queries against structured data +Uses the RowsQuery 
service to execute GraphQL queries against structured data """ import argparse @@ -81,7 +81,7 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def objects_query( +def rows_query( url, flow_id, query, user, collection, variables, operation_name, output_format='table' ): @@ -96,7 +96,7 @@ def objects_query( print(f"Error parsing variables JSON: {e}", file=sys.stderr) sys.exit(1) - resp = api.objects_query( + resp = api.rows_query( query=query, user=user, collection=collection, @@ -126,7 +126,7 @@ def objects_query( def main(): parser = argparse.ArgumentParser( - prog='tg-invoke-objects-query', + prog='tg-invoke-rows-query', description=__doc__, ) @@ -181,7 +181,7 @@ def main(): try: - objects_query( + rows_query( url=args.url, flow_id=args.flow_id, query=args.query, diff --git a/trustgraph-cli/trustgraph/cli/load_knowledge.py b/trustgraph-cli/trustgraph/cli/load_knowledge.py index ff6ca980..5e96850f 100644 --- a/trustgraph-cli/trustgraph/cli/load_knowledge.py +++ b/trustgraph-cli/trustgraph/cli/load_knowledge.py @@ -87,13 +87,20 @@ class KnowledgeLoader: # Load triples from all files print("Loading triples...") + total_triples = 0 for file in self.files: print(f" Processing {file}...") - triples = self.load_triples_from_file(file) + count = 0 + + def counting_triples(): + nonlocal count + for triple in self.load_triples_from_file(file): + count += 1 + yield triple bulk.import_triples( flow=self.flow, - triples=triples, + triples=counting_triples(), metadata={ "id": self.document_id, "metadata": [], @@ -101,25 +108,33 @@ class KnowledgeLoader: "collection": self.collection } ) + print(f" Loaded {count} triples") + total_triples += count - print("Triples loaded.") + print(f"Triples loaded. Total: {total_triples}") # Load entity contexts from all files print("Loading entity contexts...") + total_contexts = 0 for file in self.files: print(f" Processing {file}...") + count = 0 # Convert tuples to the format expected by import_entity_contexts + # Entity must be in Term format: {"t": "i", "i": uri} for IRI def entity_context_generator(): + nonlocal count for entity, context in self.load_entity_contexts_from_file(file): + count += 1 + # Entities from RDF are URIs, use IRI term format yield { - "entity": {"v": entity, "e": True}, + "entity": {"t": "i", "i": entity}, "context": context } bulk.import_entity_contexts( flow=self.flow, - entities=entity_context_generator(), + contexts=entity_context_generator(), metadata={ "id": self.document_id, "metadata": [], @@ -127,8 +142,10 @@ class KnowledgeLoader: "collection": self.collection } ) + print(f" Loaded {count} entity contexts") + total_contexts += count - print("Entity contexts loaded.") + print(f"Entity contexts loaded. 
Total: {total_contexts}") except Exception as e: print(f"Error: {e}", flush=True) diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index bf112417..fa167917 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -573,19 +573,19 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample return output_records, descriptor -def _send_to_trustgraph(objects, api_url, flow, batch_size=1000, token=None): +def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): """Send ExtractedObject records to TrustGraph using Python API""" from trustgraph.api import Api try: - total_records = len(objects) + total_records = len(rows) logger.info(f"Importing {total_records} records to TrustGraph...") # Use Python API bulk import api = Api(api_url, token=token) bulk = api.bulk() - bulk.import_objects(flow=flow, objects=iter(objects)) + bulk.import_rows(flow=flow, rows=iter(rows)) logger.info(f"Successfully imported {total_records} records to TrustGraph") diff --git a/trustgraph-cli/trustgraph/cli/set_tool.py b/trustgraph-cli/trustgraph/cli/set_tool.py index 36701a8e..c6412e48 100644 --- a/trustgraph-cli/trustgraph/cli/set_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_tool.py @@ -2,8 +2,9 @@ Configures and registers tools in the TrustGraph system. This script allows you to define agent tools with various types including: -- knowledge-query: Query knowledge bases +- knowledge-query: Query knowledge bases - structured-query: Query structured data using natural language +- row-embeddings-query: Semantic search on structured data indexes - text-completion: Text generation - mcp-tool: Reference to MCP (Model Context Protocol) tools - prompt: Prompt template execution @@ -64,6 +65,9 @@ def set_tool( mcp_tool : str, collection : str, template : str, + schema_name : str, + index_name : str, + limit : int, arguments : List[Argument], group : List[str], state : str, @@ -89,6 +93,12 @@ def set_tool( if template: object["template"] = template + if schema_name: object["schema-name"] = schema_name + + if index_name: object["index-name"] = index_name + + if limit: object["limit"] = limit + if arguments: object["arguments"] = [ { @@ -120,30 +130,37 @@ def main(): description=__doc__, epilog=textwrap.dedent(''' Valid tool types: - knowledge-query - Query knowledge bases (fixed args) - structured-query - Query structured data using natural language (fixed args) - text-completion - Text completion/generation (fixed args) - mcp-tool - Model Control Protocol tool (configurable args) - prompt - Prompt template query (configurable args) - - Note: Tools marked "(fixed args)" have predefined arguments and don't need + knowledge-query - Query knowledge bases (fixed args) + structured-query - Query structured data using natural language (fixed args) + row-embeddings-query - Semantic search on structured data indexes (fixed args) + text-completion - Text completion/generation (fixed args) + mcp-tool - Model Context Protocol tool (configurable args) + prompt - Prompt template query (configurable args) + + Note: Tools marked "(fixed args)" have predefined arguments and don't need + --argument specified. Tools marked "(configurable args)" require --argument. 
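For reference, the new schema_name/index_name/limit parameters above land in the stored tool entry alongside the existing keys. A sketch of the resulting configuration object for the row-embeddings-query example shown further down in this epilog (identity fields such as id/name/description are elided, and the exact envelope is assumed, not taken from the patch):

```python
# Sketch: tool entry written by set_tool() for a row-embeddings-query
# tool. Key names match the assignments in set_tool() above; the values
# come from the customer_search example in the epilog.
tool_config = {
    "type": "row-embeddings-query",
    "collection": "sales",
    "schema-name": "customers",   # required for this tool type
    "index-name": "full_name",    # optional index filter
    "limit": 20,                  # optional result cap
}
```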
- + Valid argument types: - string - String/text parameter + string - String/text parameter number - Numeric parameter - + Examples: %(prog)s --id weather_tool --name get_weather \\ --type knowledge-query \\ --description "Get weather information for a location" \\ --collection weather_data - + %(prog)s --id data_query_tool --name query_data \\ --type structured-query \\ --description "Query structured data using natural language" \\ --collection sales_data - + + %(prog)s --id customer_search --name find_customer \\ + --type row-embeddings-query \\ + --description "Find customers by name using semantic search" \\ + --schema-name customers --collection sales \\ + --index-name full_name --limit 20 + %(prog)s --id calc_tool --name calculate --type mcp-tool \\ --description "Perform mathematical calculations" \\ --mcp-tool calculator \\ @@ -181,7 +198,7 @@ def main(): parser.add_argument( '--type', - help=f'Tool type, one of: knowledge-query, structured-query, text-completion, mcp-tool, prompt', + help=f'Tool type, one of: knowledge-query, structured-query, row-embeddings-query, text-completion, mcp-tool, prompt', ) parser.add_argument( @@ -191,7 +208,23 @@ def main(): parser.add_argument( '--collection', - help=f'For knowledge-query and structured-query types: collection to query', + help=f'For knowledge-query, structured-query, and row-embeddings-query types: collection to query', + ) + + parser.add_argument( + '--schema-name', + help=f'For row-embeddings-query type: schema name to search within (required)', + ) + + parser.add_argument( + '--index-name', + help=f'For row-embeddings-query type: specific index to filter search (optional)', + ) + + parser.add_argument( + '--limit', + type=int, + help=f'For row-embeddings-query type: maximum results to return (default: 10)', ) parser.add_argument( @@ -227,7 +260,8 @@ def main(): try: valid_types = [ - "knowledge-query", "structured-query", "text-completion", "mcp-tool", "prompt" + "knowledge-query", "structured-query", "row-embeddings-query", + "text-completion", "mcp-tool", "prompt" ] if args.id is None: @@ -261,6 +295,9 @@ def main(): mcp_tool=mcp_tool, collection=args.collection, template=args.template, + schema_name=args.schema_name, + index_name=args.index_name, + limit=args.limit, arguments=arguments, group=args.group, state=args.state, diff --git a/trustgraph-cli/trustgraph/cli/show_tools.py b/trustgraph-cli/trustgraph/cli/show_tools.py index b8c9a012..d77f1fae 100644 --- a/trustgraph-cli/trustgraph/cli/show_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_tools.py @@ -4,8 +4,9 @@ Displays the current agent tool configurations Shows all configured tools including their types: - knowledge-query: Tools that query knowledge bases - structured-query: Tools that query structured data using natural language +- row-embeddings-query: Tools for semantic search on structured data indexes - text-completion: Tools for text generation -- mcp-tool: References to MCP (Model Context Protocol) tools +- mcp-tool: References to MCP (Model Context Protocol) tools - prompt: Tools that execute prompt templates """ @@ -41,11 +42,19 @@ def show_config(url, token=None): if tp == "mcp-tool": table.append(("mcp-tool", data["mcp-tool"])) - - if tp == "knowledge-query" or tp == "structured-query": + + if tp in ("knowledge-query", "structured-query", "row-embeddings-query"): if "collection" in data: table.append(("collection", data["collection"])) + if tp == "row-embeddings-query": + if "schema-name" in data: + table.append(("schema-name", data["schema-name"])) + if 
"index-name" in data: + table.append(("index-name", data["index-name"])) + if "limit" in data: + table.append(("limit", data["limit"])) + if tp == "prompt": table.append(("template", data["template"])) for n, arg in enumerate(data["arguments"]): diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 3d4fa65c..79e14540 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", - "trustgraph-flow>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", + "trustgraph-flow>=2.0,<2.1", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 70140147..31a22a2f 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", "aiohttp", "anthropic", "scylla-driver", @@ -19,7 +19,6 @@ dependencies = [ "faiss-cpu", "falkordb", "fastembed", - "google-genai", "ibis", "jsonschema", "langchain", @@ -61,27 +60,27 @@ api-gateway = "trustgraph.gateway:run" chunker-recursive = "trustgraph.chunking.recursive:run" chunker-token = "trustgraph.chunking.token:run" config-svc = "trustgraph.config.service:run" -de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" -de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" -de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" -de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run" -de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run" -de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run" +doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" +doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" +doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" +doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run" +doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run" +doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run" document-embeddings = "trustgraph.embeddings.document_embeddings:run" document-rag = "trustgraph.retrieval.document_rag:run" embeddings-fastembed = "trustgraph.embeddings.fastembed:run" embeddings-ollama = "trustgraph.embeddings.ollama:run" -ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run" -ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run" -ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run" -ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run" -ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run" -ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" +graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run" +graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run" +graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run" +graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run" +graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run" +graph-embeddings-write-qdrant = 
"trustgraph.storage.graph_embeddings.qdrant:run" graph-embeddings = "trustgraph.embeddings.graph_embeddings:run" graph-rag = "trustgraph.retrieval.graph_rag:run" kg-extract-agent = "trustgraph.extract.kg.agent:run" kg-extract-definitions = "trustgraph.extract.kg.definitions:run" -kg-extract-objects = "trustgraph.extract.kg.objects:run" +kg-extract-rows = "trustgraph.extract.kg.rows:run" kg-extract-relationships = "trustgraph.extract.kg.relationships:run" kg-extract-topics = "trustgraph.extract.kg.topics:run" kg-extract-ontology = "trustgraph.extract.kg.ontology:run" @@ -91,8 +90,11 @@ librarian = "trustgraph.librarian:run" mcp-tool = "trustgraph.agent.mcp_tool:run" metering = "trustgraph.metering:run" nlp-query = "trustgraph.retrieval.nlp_query:run" -objects-write-cassandra = "trustgraph.storage.objects.cassandra:run" -objects-query-cassandra = "trustgraph.query.objects.cassandra:run" +rows-write-cassandra = "trustgraph.storage.rows.cassandra:run" +rows-query-cassandra = "trustgraph.query.rows.cassandra:run" +row-embeddings = "trustgraph.embeddings.row_embeddings:run" +row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run" +row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run" pdf-decoder = "trustgraph.decoding.pdf:run" pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run" prompt-template = "trustgraph.prompt.template:run" @@ -104,7 +106,6 @@ text-completion-azure = "trustgraph.model.text_completion.azure:run" text-completion-azure-openai = "trustgraph.model.text_completion.azure_openai:run" text-completion-claude = "trustgraph.model.text_completion.claude:run" text-completion-cohere = "trustgraph.model.text_completion.cohere:run" -text-completion-googleaistudio = "trustgraph.model.text_completion.googleaistudio:run" text-completion-llamafile = "trustgraph.model.text_completion.llamafile:run" text-completion-lmstudio = "trustgraph.model.text_completion.lmstudio:run" text-completion-mistral = "trustgraph.model.text_completion.mistral:run" diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 3af851d2..1a44ef9e 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -13,10 +13,11 @@ logger = logging.getLogger(__name__) from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec +from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec from ... schema import AgentRequest, AgentResponse, AgentStep, Error -from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl +from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl from . 
agent_manager import AgentManager from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state @@ -87,6 +88,20 @@ class Processor(AgentService): ) ) + self.register_specification( + EmbeddingsClientSpec( + request_name = "embeddings-request", + response_name = "embeddings-response", + ) + ) + + self.register_specification( + RowEmbeddingsQueryClientSpec( + request_name = "row-embeddings-query-request", + response_name = "row-embeddings-query-response", + ) + ) + async def on_tools_config(self, config, version): logger.info(f"Loading configuration version {version}") @@ -147,11 +162,21 @@ class Processor(AgentService): ) elif impl_id == "structured-query": impl = functools.partial( - StructuredQueryImpl, + StructuredQueryImpl, collection=data.get("collection"), user=None # User will be provided dynamically via context ) arguments = StructuredQueryImpl.get_arguments() + elif impl_id == "row-embeddings-query": + impl = functools.partial( + RowEmbeddingsQueryImpl, + schema_name=data.get("schema-name"), + collection=data.get("collection"), + user=None, # User will be provided dynamically via context + index_name=data.get("index-name"), # Optional filter + limit=int(data.get("limit", 10)) # Max results + ) + arguments = RowEmbeddingsQueryImpl.get_arguments() else: raise RuntimeError( f"Tool type {impl_id} not known" @@ -327,11 +352,11 @@ class Processor(AgentService): def __init__(self, flow, user): self._flow = flow self._user = user - + def __call__(self, service_name): client = self._flow(service_name) - # For structured query clients, store user context - if service_name == "structured-query-request": + # For query clients that need user context, store it + if service_name in ("structured-query-request", "row-embeddings-query-request"): client._current_user = self._user return client diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index e32dc2d8..2b442a0d 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -128,6 +128,62 @@ class StructuredQueryImpl: return str(result) +# This tool implementation knows how to query row embeddings for semantic search +class RowEmbeddingsQueryImpl: + def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10): + self.context = context + self.schema_name = schema_name + self.collection = collection + self.user = user + self.index_name = index_name # Optional: filter to specific index + self.limit = limit # Max results to return + + @staticmethod + def get_arguments(): + return [ + Argument( + name="query", + type="string", + description="Text to search for semantically similar values in the structured data index" + ) + ] + + async def invoke(self, **arguments): + # First get embeddings for the query text + embeddings_client = self.context("embeddings-request") + logger.debug("Getting embeddings for row query...") + + query_text = arguments.get("query") + vectors = await embeddings_client.embed(query_text) + + # Now query row embeddings + client = self.context("row-embeddings-query-request") + logger.debug("Row embeddings query...") + + # Get user from client context if available + user = getattr(client, '_current_user', self.user or "trustgraph") + + matches = await client.row_embeddings_query( + vectors=vectors, + schema_name=self.schema_name, + user=user, + collection=self.collection or "default", + index_name=self.index_name, + limit=self.limit + ) + + # Format 
results for agent consumption + if not matches: + return "No matching records found" + + results = [] + for match in matches: + result = f"- {match['index_name']}: {', '.join(match['index_value'])} (score: {match['score']:.3f})" + results.append(result) + + return "Matching records:\n" + "\n".join(results) + + # This tool implementation knows how to execute prompt templates class PromptImpl: def __init__(self, context, template_id, arguments=None): diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 18154fc5..b8cc5f9e 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -124,12 +124,13 @@ class Processor(AsyncProcessor): logger.info(f"Configuration version: {version}") - if "flows" in config: - + if "flow" in config: self.flows = { k: json.loads(v) - for k, v in config["flows"].items() + for k, v in config["flow"].items() } + else: + self.flows = {} logger.debug(f"Flows: {self.flows}") diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 116abe02..61639096 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -11,7 +11,29 @@ _active_clusters = [] logger = logging.getLogger(__name__) +# Sentinel value for wildcard graph queries +GRAPH_WILDCARD = "*" + +# Default graph stored as empty string +DEFAULT_GRAPH = "" + + class KnowledgeGraph: + """ + REDUNDANT: This 7-table implementation has been superseded by + EntityCentricKnowledgeGraph which uses a more efficient 2-table model. + This class is retained temporarily for reference but should not be used + for new deployments. + + Cassandra-backed knowledge graph supporting quads (s, p, o, g). + + Uses 7 tables to support all 16 query patterns efficiently: + - Family A (g-wildcard): SPOG, POSG, OSPG + - Family B (g-specified): GSPO, GPOS, GOSP + - Collection table: COLL (for iteration/deletion) + + Plus a metadata table for tracking collections. 
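+
+    Pattern routing (illustrative): a lookup binds some of (g, s, p, o);
+    use the table whose partition key covers the bound terms, e.g.
+    (?, s, ?, ?) -> SPOG, (?, ?, p, o) -> POSG, (?, ?, ?, o) -> OSPG,
+    (g, s, ?, ?) -> GSPO, (g, ?, p, ?) -> GPOS, (g, ?, ?, o) -> GOSP;
+    '?' in the first position means graph-wildcard, served by Family A.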
+ """ def __init__( self, hosts=None, @@ -24,12 +46,22 @@ class KnowledgeGraph: self.keyspace = keyspace self.username = username - # Optimized multi-table schema with collection deletion support - self.subject_table = "triples_s" - self.po_table = "triples_p" - self.object_table = "triples_o" - self.collection_table = "triples_collection" # For SPO queries and deletion - self.collection_metadata_table = "collection_metadata" # For tracking which collections exist + # 7-table schema for quads with full query pattern support + # Family A: g-wildcard queries (g in clustering columns) + self.spog_table = "quads_spog" # partition (collection, s), cluster (p, o, g) + self.posg_table = "quads_posg" # partition (collection, p), cluster (o, s, g) + self.ospg_table = "quads_ospg" # partition (collection, o), cluster (s, p, g) + + # Family B: g-specified queries (g in partition key) + self.gspo_table = "quads_gspo" # partition (collection, g, s), cluster (p, o) + self.gpos_table = "quads_gpos" # partition (collection, g, p), cluster (o, s) + self.gosp_table = "quads_gosp" # partition (collection, g, o), cluster (s, p) + + # Collection table for iteration and bulk deletion + self.coll_table = "quads_coll" # partition (collection), cluster (g, s, p, o) + + # Collection metadata tracking + self.collection_metadata_table = "collection_metadata" if username and password: ssl_context = SSLContext(PROTOCOL_TLSv1_2) @@ -46,237 +78,376 @@ class KnowledgeGraph: self.prepare_statements() def clear(self): - self.session.execute(f""" drop keyspace if exists {self.keyspace}; - """); - + """) self.init() def init(self): - self.session.execute(f""" create keyspace if not exists {self.keyspace} with replication = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}; - """); + """) self.session.set_keyspace(self.keyspace) - self.init_optimized_schema() + self.init_quad_schema() + def init_quad_schema(self): + """Initialize 7-table schema for quads with full query pattern support""" - def init_optimized_schema(self): - """Initialize optimized multi-table schema for performance""" - # Table 1: Subject-centric queries (get_s, get_sp, get_os) - # Compound partition key for optimal data distribution + # Family A: g-wildcard queries (g in clustering columns) + + # SPOG: partition (collection, s), cluster (p, o, g) + # Supports: (?, s, ?, ?), (?, s, p, ?), (?, s, p, o) self.session.execute(f""" - CREATE TABLE IF NOT EXISTS {self.subject_table} ( + CREATE TABLE IF NOT EXISTS {self.spog_table} ( collection text, s text, p text, o text, - PRIMARY KEY ((collection, s), p, o) + g text, + PRIMARY KEY ((collection, s), p, o, g) ); - """); + """) - # Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING! 
- # Compound partition key for optimal data distribution + # POSG: partition (collection, p), cluster (o, s, g) + # Supports: (?, ?, p, ?), (?, ?, p, o) self.session.execute(f""" - CREATE TABLE IF NOT EXISTS {self.po_table} ( + CREATE TABLE IF NOT EXISTS {self.posg_table} ( collection text, p text, o text, s text, - PRIMARY KEY ((collection, p), o, s) + g text, + PRIMARY KEY ((collection, p), o, s, g) ); - """); + """) - # Table 3: Object-centric queries (get_o) - # Compound partition key for optimal data distribution + # OSPG: partition (collection, o), cluster (s, p, g) + # Supports: (?, ?, ?, o), (?, s, ?, o) self.session.execute(f""" - CREATE TABLE IF NOT EXISTS {self.object_table} ( + CREATE TABLE IF NOT EXISTS {self.ospg_table} ( collection text, o text, s text, p text, - PRIMARY KEY ((collection, o), s, p) + g text, + PRIMARY KEY ((collection, o), s, p, g) ); - """); + """) - # Table 4: Collection management and SPO queries (get_spo) - # Simple partition key enables efficient collection deletion + # Family B: g-specified queries (g in partition key) + + # GSPO: partition (collection, g, s), cluster (p, o) + # Supports: (g, s, ?, ?), (g, s, p, ?), (g, s, p, o) self.session.execute(f""" - CREATE TABLE IF NOT EXISTS {self.collection_table} ( + CREATE TABLE IF NOT EXISTS {self.gspo_table} ( collection text, + g text, s text, p text, o text, - PRIMARY KEY (collection, s, p, o) + PRIMARY KEY ((collection, g, s), p, o) ); - """); + """) - # Table 5: Collection metadata tracking - # Tracks which collections exist without polluting triple data + # GPOS: partition (collection, g, p), cluster (o, s) + # Supports: (g, ?, p, ?), (g, ?, p, o) + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.gpos_table} ( + collection text, + g text, + p text, + o text, + s text, + PRIMARY KEY ((collection, g, p), o, s) + ); + """) + + # GOSP: partition (collection, g, o), cluster (s, p) + # Supports: (g, ?, ?, o), (g, s, ?, o) + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.gosp_table} ( + collection text, + g text, + o text, + s text, + p text, + PRIMARY KEY ((collection, g, o), s, p) + ); + """) + + # Collection table for iteration and bulk deletion + # COLL: partition (collection), cluster (g, s, p, o) + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.coll_table} ( + collection text, + g text, + s text, + p text, + o text, + PRIMARY KEY (collection, g, s, p, o) + ); + """) + + # Collection metadata tracking self.session.execute(f""" CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} ( collection text, created_at timestamp, PRIMARY KEY (collection) ); - """); + """) - logger.info("Optimized multi-table schema initialized (5 tables)") + logger.info("Quad schema initialized (7 tables + metadata)") def prepare_statements(self): - """Prepare statements for optimal performance""" - # Insert statements for batch operations - self.insert_subject_stmt = self.session.prepare( - f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)" + """Prepare statements for all 7 tables""" + + # Insert statements + self.insert_spog_stmt = self.session.prepare( + f"INSERT INTO {self.spog_table} (collection, s, p, o, g) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_posg_stmt = self.session.prepare( + f"INSERT INTO {self.posg_table} (collection, p, o, s, g) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_ospg_stmt = self.session.prepare( + f"INSERT INTO {self.ospg_table} (collection, o, s, p, g) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_gspo_stmt = self.session.prepare( + 
f"INSERT INTO {self.gspo_table} (collection, g, s, p, o) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_gpos_stmt = self.session.prepare( + f"INSERT INTO {self.gpos_table} (collection, g, p, o, s) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_gosp_stmt = self.session.prepare( + f"INSERT INTO {self.gosp_table} (collection, g, o, s, p) VALUES (?, ?, ?, ?, ?)" + ) + self.insert_coll_stmt = self.session.prepare( + f"INSERT INTO {self.coll_table} (collection, g, s, p, o) VALUES (?, ?, ?, ?, ?)" ) - self.insert_po_stmt = self.session.prepare( - f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)" + # Delete statements (for single quad deletion) + self.delete_spog_stmt = self.session.prepare( + f"DELETE FROM {self.spog_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? AND g = ?" + ) + self.delete_posg_stmt = self.session.prepare( + f"DELETE FROM {self.posg_table} WHERE collection = ? AND p = ? AND o = ? AND s = ? AND g = ?" + ) + self.delete_ospg_stmt = self.session.prepare( + f"DELETE FROM {self.ospg_table} WHERE collection = ? AND o = ? AND s = ? AND p = ? AND g = ?" + ) + self.delete_gspo_stmt = self.session.prepare( + f"DELETE FROM {self.gspo_table} WHERE collection = ? AND g = ? AND s = ? AND p = ? AND o = ?" + ) + self.delete_gpos_stmt = self.session.prepare( + f"DELETE FROM {self.gpos_table} WHERE collection = ? AND g = ? AND p = ? AND o = ? AND s = ?" + ) + self.delete_gosp_stmt = self.session.prepare( + f"DELETE FROM {self.gosp_table} WHERE collection = ? AND g = ? AND o = ? AND s = ? AND p = ?" + ) + self.delete_coll_stmt = self.session.prepare( + f"DELETE FROM {self.coll_table} WHERE collection = ? AND g = ? AND s = ? AND p = ? AND o = ?" ) - self.insert_object_stmt = self.session.prepare( - f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)" + # Query statements - Family A (g-wildcard, g in clustering) + + # SPOG table queries + self.get_s_wildcard_stmt = self.session.prepare( + f"SELECT p, o, g FROM {self.spog_table} WHERE collection = ? AND s = ? LIMIT ?" + ) + self.get_sp_wildcard_stmt = self.session.prepare( + f"SELECT o, g FROM {self.spog_table} WHERE collection = ? AND s = ? AND p = ? LIMIT ?" + ) + self.get_spo_wildcard_stmt = self.session.prepare( + f"SELECT g FROM {self.spog_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?" ) - self.insert_collection_stmt = self.session.prepare( - f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)" + # POSG table queries + self.get_p_wildcard_stmt = self.session.prepare( + f"SELECT o, s, g FROM {self.posg_table} WHERE collection = ? AND p = ? LIMIT ?" + ) + self.get_po_wildcard_stmt = self.session.prepare( + f"SELECT s, g FROM {self.posg_table} WHERE collection = ? AND p = ? AND o = ? LIMIT ?" ) - # Query statements for optimized access + # OSPG table queries + self.get_o_wildcard_stmt = self.session.prepare( + f"SELECT s, p, g FROM {self.ospg_table} WHERE collection = ? AND o = ? LIMIT ?" + ) + self.get_os_wildcard_stmt = self.session.prepare( + f"SELECT p, g FROM {self.ospg_table} WHERE collection = ? AND o = ? AND s = ? LIMIT ?" + ) + + # Query statements - Family B (g-specified, g in partition) + + # GSPO table queries + self.get_gs_stmt = self.session.prepare( + f"SELECT p, o FROM {self.gspo_table} WHERE collection = ? AND g = ? AND s = ? LIMIT ?" + ) + self.get_gsp_stmt = self.session.prepare( + f"SELECT o FROM {self.gspo_table} WHERE collection = ? AND g = ? AND s = ? AND p = ? LIMIT ?" 
+ ) + self.get_gspo_stmt = self.session.prepare( + f"SELECT s FROM {self.gspo_table} WHERE collection = ? AND g = ? AND s = ? AND p = ? AND o = ? LIMIT ?" + ) + + # GPOS table queries + self.get_gp_stmt = self.session.prepare( + f"SELECT o, s FROM {self.gpos_table} WHERE collection = ? AND g = ? AND p = ? LIMIT ?" + ) + self.get_gpo_stmt = self.session.prepare( + f"SELECT s FROM {self.gpos_table} WHERE collection = ? AND g = ? AND p = ? AND o = ? LIMIT ?" + ) + + # GOSP table queries + self.get_go_stmt = self.session.prepare( + f"SELECT s, p FROM {self.gosp_table} WHERE collection = ? AND g = ? AND o = ? LIMIT ?" + ) + self.get_gos_stmt = self.session.prepare( + f"SELECT p FROM {self.gosp_table} WHERE collection = ? AND g = ? AND o = ? AND s = ? LIMIT ?" + ) + + # Collection table query (for get_all and iteration) self.get_all_stmt = self.session.prepare( - f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING" + f"SELECT g, s, p, o FROM {self.coll_table} WHERE collection = ? LIMIT ?" + ) + self.get_g_stmt = self.session.prepare( + f"SELECT s, p, o FROM {self.coll_table} WHERE collection = ? AND g = ? LIMIT ?" ) - self.get_s_stmt = self.session.prepare( - f"SELECT p, o FROM {self.subject_table} WHERE collection = ? AND s = ? LIMIT ?" - ) + logger.info("Prepared statements initialized for quad schema (7 tables)") - self.get_p_stmt = self.session.prepare( - f"SELECT s, o FROM {self.po_table} WHERE collection = ? AND p = ? LIMIT ?" - ) + def insert(self, collection, s, p, o, g=None): + """Insert a quad into all 7 tables""" + # Default graph stored as empty string + if g is None: + g = DEFAULT_GRAPH - self.get_o_stmt = self.session.prepare( - f"SELECT s, p FROM {self.object_table} WHERE collection = ? AND o = ? LIMIT ?" - ) - - self.get_sp_stmt = self.session.prepare( - f"SELECT o FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? LIMIT ?" - ) - - # The critical optimization: get_po without ALLOW FILTERING! - self.get_po_stmt = self.session.prepare( - f"SELECT s FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? LIMIT ?" - ) - - self.get_os_stmt = self.session.prepare( - f"SELECT p FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? LIMIT ?" - ) - - self.get_spo_stmt = self.session.prepare( - f"SELECT s as x FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?" - ) - - # Delete statements for collection deletion - self.delete_subject_stmt = self.session.prepare( - f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?" - ) - - self.delete_po_stmt = self.session.prepare( - f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?" - ) - - self.delete_object_stmt = self.session.prepare( - f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?" - ) - - self.delete_collection_stmt = self.session.prepare( - f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?" 
- ) - - logger.info("Prepared statements initialized for optimal performance (4 tables)") - - def insert(self, collection, s, p, o): - # Batch write to all four tables for consistency batch = BatchStatement() - # Insert into subject table - batch.add(self.insert_subject_stmt, (collection, s, p, o)) + # Family A tables + batch.add(self.insert_spog_stmt, (collection, s, p, o, g)) + batch.add(self.insert_posg_stmt, (collection, p, o, s, g)) + batch.add(self.insert_ospg_stmt, (collection, o, s, p, g)) - # Insert into predicate-object table (column order: collection, p, o, s) - batch.add(self.insert_po_stmt, (collection, p, o, s)) + # Family B tables + batch.add(self.insert_gspo_stmt, (collection, g, s, p, o)) + batch.add(self.insert_gpos_stmt, (collection, g, p, o, s)) + batch.add(self.insert_gosp_stmt, (collection, g, o, s, p)) - # Insert into object table (column order: collection, o, s, p) - batch.add(self.insert_object_stmt, (collection, o, s, p)) - - # Insert into collection table for SPO queries and deletion tracking - batch.add(self.insert_collection_stmt, (collection, s, p, o)) + # Collection table + batch.add(self.insert_coll_stmt, (collection, g, s, p, o)) self.session.execute(batch) + def delete_quad(self, collection, s, p, o, g=None): + """Delete a single quad from all 7 tables""" + if g is None: + g = DEFAULT_GRAPH + + batch = BatchStatement() + + batch.add(self.delete_spog_stmt, (collection, s, p, o, g)) + batch.add(self.delete_posg_stmt, (collection, p, o, s, g)) + batch.add(self.delete_ospg_stmt, (collection, o, s, p, g)) + batch.add(self.delete_gspo_stmt, (collection, g, s, p, o)) + batch.add(self.delete_gpos_stmt, (collection, g, p, o, s)) + batch.add(self.delete_gosp_stmt, (collection, g, o, s, p)) + batch.add(self.delete_coll_stmt, (collection, g, s, p, o)) + + self.session.execute(batch) + + # ======================================================================== + # Query methods + # g=None means default graph, g="*" means all graphs + # ======================================================================== + def get_all(self, collection, limit=50): - # Use subject table for get_all queries - return self.session.execute( - self.get_all_stmt, - (collection, limit) - ) + """Get all quads in collection""" + return self.session.execute(self.get_all_stmt, (collection, limit)) - def get_s(self, collection, s, limit=10): - # Optimized: Direct partition access with (collection, s) - return self.session.execute( - self.get_s_stmt, - (collection, s, limit) - ) + def get_s(self, collection, s, g=None, limit=10): + """Query by subject. 
g=None: default graph, g='*': all graphs""" + if g is None or g == DEFAULT_GRAPH: + # Default graph - use GSPO table + return self.session.execute(self.get_gs_stmt, (collection, DEFAULT_GRAPH, s, limit)) + elif g == GRAPH_WILDCARD: + # All graphs - use SPOG table + return self.session.execute(self.get_s_wildcard_stmt, (collection, s, limit)) + else: + # Specific graph - use GSPO table + return self.session.execute(self.get_gs_stmt, (collection, g, s, limit)) - def get_p(self, collection, p, limit=10): - # Optimized: Use po_table for direct partition access - return self.session.execute( - self.get_p_stmt, - (collection, p, limit) - ) + def get_p(self, collection, p, g=None, limit=10): + """Query by predicate""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_gp_stmt, (collection, DEFAULT_GRAPH, p, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_p_wildcard_stmt, (collection, p, limit)) + else: + return self.session.execute(self.get_gp_stmt, (collection, g, p, limit)) - def get_o(self, collection, o, limit=10): - # Optimized: Use object_table for direct partition access - return self.session.execute( - self.get_o_stmt, - (collection, o, limit) - ) + def get_o(self, collection, o, g=None, limit=10): + """Query by object""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_go_stmt, (collection, DEFAULT_GRAPH, o, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_o_wildcard_stmt, (collection, o, limit)) + else: + return self.session.execute(self.get_go_stmt, (collection, g, o, limit)) - def get_sp(self, collection, s, p, limit=10): - # Optimized: Use subject_table with clustering key access - return self.session.execute( - self.get_sp_stmt, - (collection, s, p, limit) - ) + def get_sp(self, collection, s, p, g=None, limit=10): + """Query by subject and predicate""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_gsp_stmt, (collection, DEFAULT_GRAPH, s, p, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_sp_wildcard_stmt, (collection, s, p, limit)) + else: + return self.session.execute(self.get_gsp_stmt, (collection, g, s, p, limit)) - def get_po(self, collection, p, o, limit=10): - # CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING! 
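A short sketch of the three graph-selection behaviours shared by these accessors (the host list and IRIs below are hypothetical):

```python
# g=None -> default graph; g="<iri>" -> that graph; g="*" -> all graphs.
kg = KnowledgeGraph(hosts=["localhost"], keyspace="trustgraph")

kg.insert("docs", "ex:alice", "ex:knows", "ex:bob")               # default graph
kg.insert("docs", "ex:alice", "ex:knows", "ex:carol", g="ex:g1")  # named graph

kg.get_sp("docs", "ex:alice", "ex:knows")             # default graph: ex:bob
kg.get_sp("docs", "ex:alice", "ex:knows", g="ex:g1")  # named graph: ex:carol
kg.get_sp("docs", "ex:alice", "ex:knows", g="*")      # all graphs: both
```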
- return self.session.execute( - self.get_po_stmt, - (collection, p, o, limit) - ) + def get_po(self, collection, p, o, g=None, limit=10): + """Query by predicate and object""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_gpo_stmt, (collection, DEFAULT_GRAPH, p, o, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_po_wildcard_stmt, (collection, p, o, limit)) + else: + return self.session.execute(self.get_gpo_stmt, (collection, g, p, o, limit)) - def get_os(self, collection, o, s, limit=10): - # Optimized: Use subject_table with clustering access (no more ALLOW FILTERING) - return self.session.execute( - self.get_os_stmt, - (collection, s, o, limit) - ) + def get_os(self, collection, o, s, g=None, limit=10): + """Query by object and subject""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_gos_stmt, (collection, DEFAULT_GRAPH, o, s, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_os_wildcard_stmt, (collection, o, s, limit)) + else: + return self.session.execute(self.get_gos_stmt, (collection, g, o, s, limit)) - def get_spo(self, collection, s, p, o, limit=10): - # Optimized: Use collection_table for exact key lookup - return self.session.execute( - self.get_spo_stmt, - (collection, s, p, o, limit) - ) + def get_spo(self, collection, s, p, o, g=None, limit=10): + """Query by subject, predicate, object (find which graphs)""" + if g is None or g == DEFAULT_GRAPH: + return self.session.execute(self.get_gspo_stmt, (collection, DEFAULT_GRAPH, s, p, o, limit)) + elif g == GRAPH_WILDCARD: + return self.session.execute(self.get_spo_wildcard_stmt, (collection, s, p, o, limit)) + else: + return self.session.execute(self.get_gspo_stmt, (collection, g, s, p, o, limit)) + + def get_g(self, collection, g, limit=50): + """Get all quads in a specific graph""" + if g is None: + g = DEFAULT_GRAPH + return self.session.execute(self.get_g_stmt, (collection, g, limit)) + + # ======================================================================== + # Collection management + # ======================================================================== def collection_exists(self, collection): - """Check if collection exists by querying collection_metadata table""" + """Check if collection exists""" try: result = self.session.execute( f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", @@ -301,63 +472,624 @@ class KnowledgeGraph: raise e def delete_collection(self, collection): - """Delete all triples for a specific collection - - Uses collection_table to enumerate all triples, then deletes from all 4 tables - using full partition keys for optimal performance with compound keys. 
- """ - # Step 1: Read all triples from collection_table (single partition read) + """Delete all quads for a collection from all 7 tables""" + # Read all quads from collection table rows = self.session.execute( - f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s", + f"SELECT g, s, p, o FROM {self.coll_table} WHERE collection = %s", (collection,) ) - # Step 2: Delete each triple from all 4 tables using full partition keys - # Batch deletions for efficiency batch = BatchStatement() count = 0 for row in rows: - s, p, o = row.s, row.p, row.o + g, s, p, o = row.g, row.s, row.p, row.o - # Delete from subject table (partition key: collection, s) - batch.add(self.delete_subject_stmt, (collection, s, p, o)) - - # Delete from predicate-object table (partition key: collection, p) - batch.add(self.delete_po_stmt, (collection, p, o, s)) - - # Delete from object table (partition key: collection, o) - batch.add(self.delete_object_stmt, (collection, o, s, p)) - - # Delete from collection table (partition key: collection only) - batch.add(self.delete_collection_stmt, (collection, s, p, o)) + # Delete from all 7 tables + batch.add(self.delete_spog_stmt, (collection, s, p, o, g)) + batch.add(self.delete_posg_stmt, (collection, p, o, s, g)) + batch.add(self.delete_ospg_stmt, (collection, o, s, p, g)) + batch.add(self.delete_gspo_stmt, (collection, g, s, p, o)) + batch.add(self.delete_gpos_stmt, (collection, g, p, o, s)) + batch.add(self.delete_gosp_stmt, (collection, g, o, s, p)) + batch.add(self.delete_coll_stmt, (collection, g, s, p, o)) count += 1 - # Execute batch every 25 triples to avoid oversized batches - # (Each triple adds ~4 statements, so 25 triples = ~100 statements) - if count % 25 == 0: + # Execute batch every 15 quads (7 deletes each = 105 statements) + if count % 15 == 0: self.session.execute(batch) batch = BatchStatement() - # Execute remaining deletions - if count % 25 != 0: + # Execute remaining + if count % 15 != 0: self.session.execute(batch) - # Step 3: Delete collection metadata + # Delete collection metadata self.session.execute( f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", (collection,) ) - logger.info(f"Deleted {count} triples from collection {collection}") + logger.info(f"Deleted {count} quads from collection {collection}") def close(self): - """Close the Cassandra session and cluster connections properly""" + """Close connections""" if hasattr(self, 'session') and self.session: self.session.shutdown() if hasattr(self, 'cluster') and self.cluster: self.cluster.shutdown() - # Remove from global tracking if self.cluster in _active_clusters: _active_clusters.remove(self.cluster) + + +class EntityCentricKnowledgeGraph: + """ + Entity-centric Cassandra-backed knowledge graph supporting quads (s, p, o, g). + + Uses 2 tables instead of 7: + - quads_by_entity: every entity knows every quad it participates in + - quads_by_collection: manifest for collection-level queries and deletion + + Supports all 16 query patterns with single-partition reads. 
+ """ + + def __init__( + self, hosts=None, + keyspace="trustgraph", username=None, password=None + ): + + if hosts is None: + hosts = ["localhost"] + + self.keyspace = keyspace + self.username = username + + # 2-table entity-centric schema + self.entity_table = "quads_by_entity" + self.collection_table = "quads_by_collection" + + # Collection metadata tracking + self.collection_metadata_table = "collection_metadata" + + if username and password: + ssl_context = SSLContext(PROTOCOL_TLSv1_2) + auth_provider = PlainTextAuthProvider(username=username, password=password) + self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) + else: + self.cluster = Cluster(hosts) + self.session = self.cluster.connect() + + # Track this cluster globally + _active_clusters.append(self.cluster) + + self.init() + self.prepare_statements() + + def clear(self): + self.session.execute(f""" + drop keyspace if exists {self.keyspace}; + """) + self.init() + + def init(self): + self.session.execute(f""" + create keyspace if not exists {self.keyspace} + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }}; + """) + + self.session.set_keyspace(self.keyspace) + self.init_entity_centric_schema() + + def init_entity_centric_schema(self): + """Initialize 2-table entity-centric schema""" + + # quads_by_entity: primary data table + # Every entity has a partition containing all quads it participates in + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.entity_table} ( + collection text, + entity text, + role text, + p text, + otype text, + s text, + o text, + d text, + dtype text, + lang text, + PRIMARY KEY ((collection, entity), role, p, otype, s, o, d) + ); + """) + + # quads_by_collection: manifest for collection-level queries and deletion + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.collection_table} ( + collection text, + d text, + s text, + p text, + o text, + otype text, + dtype text, + lang text, + PRIMARY KEY (collection, d, s, p, o) + ); + """) + + # Collection metadata tracking + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} ( + collection text, + created_at timestamp, + PRIMARY KEY (collection) + ); + """) + + logger.info("Entity-centric schema initialized (2 tables + metadata)") + + def prepare_statements(self): + """Prepare statements for entity-centric schema""" + + # Insert statement for quads_by_entity + self.insert_entity_stmt = self.session.prepare( + f"INSERT INTO {self.entity_table} " + "(collection, entity, role, p, otype, s, o, d, dtype, lang) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + ) + + # Insert statement for quads_by_collection + self.insert_collection_stmt = self.session.prepare( + f"INSERT INTO {self.collection_table} " + "(collection, d, s, p, o, otype, dtype, lang) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + ) + + # Query statements for quads_by_entity + + # Get all quads for an entity (any role) + self.get_entity_all_stmt = self.session.prepare( + f"SELECT role, p, otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? LIMIT ?" + ) + + # Get quads where entity is subject (role='S') + self.get_entity_as_s_stmt = self.session.prepare( + f"SELECT p, otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'S' LIMIT ?" 
+ ) + + # Get quads where entity is subject with specific predicate + self.get_entity_as_s_p_stmt = self.session.prepare( + f"SELECT otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'S' AND p = ? LIMIT ?" + ) + + # Get quads where entity is subject with specific predicate and otype + self.get_entity_as_s_p_otype_stmt = self.session.prepare( + f"SELECT s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'S' AND p = ? AND otype = ? LIMIT ?" + ) + + # Get quads where entity is predicate (role='P') + self.get_entity_as_p_stmt = self.session.prepare( + f"SELECT p, otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'P' LIMIT ?" + ) + + # Get quads where entity is object (role='O') + self.get_entity_as_o_stmt = self.session.prepare( + f"SELECT p, otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'O' LIMIT ?" + ) + + # Get quads where entity is object with specific predicate + self.get_entity_as_o_p_stmt = self.session.prepare( + f"SELECT otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'O' AND p = ? LIMIT ?" + ) + + # Get quads where entity is graph (role='G') + self.get_entity_as_g_stmt = self.session.prepare( + f"SELECT p, otype, s, o, d, dtype, lang FROM {self.entity_table} " + "WHERE collection = ? AND entity = ? AND role = 'G' LIMIT ?" + ) + + # Query statements for quads_by_collection + + # Get all quads in collection + self.get_collection_all_stmt = self.session.prepare( + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} " + "WHERE collection = ? LIMIT ?" + ) + + # Get all quads in a specific graph + self.get_collection_by_graph_stmt = self.session.prepare( + f"SELECT s, p, o, otype, dtype, lang FROM {self.collection_table} " + "WHERE collection = ? AND d = ? LIMIT ?" + ) + + # Delete statements + self.delete_entity_partition_stmt = self.session.prepare( + f"DELETE FROM {self.entity_table} WHERE collection = ? AND entity = ?" + ) + + self.delete_collection_row_stmt = self.session.prepare( + f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ?" + ) + + logger.info("Prepared statements initialized for entity-centric schema") + + def insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""): + """ + Insert a quad into entity-centric tables. + + Writes 4 rows to quads_by_entity (one for each entity role) + 1 row to + quads_by_collection. For literals, only 3 entity rows are written since + literals are not independently queryable entities. 
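The write fan-out this docstring describes can be stated as a standalone sketch (the insert() body follows below). It assumes, per that code, that the default graph is stored as the empty string and that otype uses the lowercase codes 'u'/'l'/'t'; all values here are hypothetical.

```python
# Illustration only: the quads_by_entity rows one quad fans out to.
# The real writes go through prepared statements in a batch.

def fan_out(collection, s, p, o, g="", otype="u", dtype="", lang=""):
    """Yield the entity-table rows for one quad (mirrors insert())."""
    rows = [
        (collection, s, "S", p, otype, s, o, g, dtype, lang),
        (collection, p, "P", p, otype, s, o, g, dtype, lang),
    ]
    if otype in ("u", "t"):          # literals get no object row
        rows.append((collection, o, "O", p, otype, s, o, g, dtype, lang))
    if g != "":                      # default graph gets no graph row
        rows.append((collection, g, "G", p, otype, s, o, g, dtype, lang))
    return rows

# URI object in a named graph -> 4 rows; literal in default graph -> 2 rows
print(len(fan_out("c1", "ex:alice", "ex:knows", "ex:bob", g="ex:g1")))  # 4
print(len(fan_out("c1", "ex:alice", "ex:age", "42", otype="l")))        # 2
```

Note the docstring's "3 rows for literals" figure assumes a named graph; a literal in the default graph fans out to only the S and P rows.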
+ + Args: + collection: Collection/tenant scope + s: Subject (string value) + p: Predicate (string value) + o: Object (string value) + g: Graph/dataset (None for default graph) + otype: Object type - 'u' (URI), 'l' (literal), 't' (triple) + Auto-detected from o value if not provided + dtype: XSD datatype (for literals) + lang: Language tag (for literals) + """ + # Default graph stored as empty string + if g is None: + g = DEFAULT_GRAPH + + # Auto-detect otype if not provided (backwards compatibility) + if otype is None: + if o.startswith("http://") or o.startswith("https://"): + otype = "u" + else: + otype = "l" + + batch = BatchStatement() + + # Write row for subject entity (role='S') + batch.add(self.insert_entity_stmt, ( + collection, s, 'S', p, otype, s, o, g, dtype, lang + )) + + # Write row for predicate entity (role='P') + batch.add(self.insert_entity_stmt, ( + collection, p, 'P', p, otype, s, o, g, dtype, lang + )) + + # Write row for object entity (role='O') - only for URIs, not literals + if otype == 'u' or otype == 't': + batch.add(self.insert_entity_stmt, ( + collection, o, 'O', p, otype, s, o, g, dtype, lang + )) + + # Write row for graph entity (role='G') - only for non-default graphs + if g != DEFAULT_GRAPH: + batch.add(self.insert_entity_stmt, ( + collection, g, 'G', p, otype, s, o, g, dtype, lang + )) + + # Write row to quads_by_collection + batch.add(self.insert_collection_stmt, ( + collection, g, s, p, o, otype, dtype, lang + )) + + self.session.execute(batch) + + # ======================================================================== + # Query methods + # g=None means default graph, g="*" means all graphs + # Results include otype, dtype, lang for proper Term reconstruction + # ======================================================================== + + def get_all(self, collection, limit=50): + """Get all quads in collection""" + return self.session.execute(self.get_collection_all_stmt, (collection, limit)) + + def get_s(self, collection, s, g=None, limit=10): + """ + Query by subject. Returns quads where s is the subject. 
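A usage sketch for the query methods that follow, with an invented collection name and URIs (it assumes a reachable Cassandra). One caveat worth noting: graph filtering happens client-side after the LIMIT-ed partition read, so a call can return fewer than `limit` rows even when more matches exist.

```python
# Hypothetical usage; collection name and URIs are invented.
kg = EntityCentricKnowledgeGraph(hosts=["localhost"])

alice = "http://example.org/alice"

rows = kg.get_s("docs", alice)                               # default graph only
rows = kg.get_s("docs", alice, g="http://example.org/g1")    # one named graph
rows = kg.get_s("docs", alice, g="*")                        # wildcard: all graphs

for r in rows:
    print(r.s, r.p, r.o, r.g, r.otype, r.dtype, r.lang)
```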
+ g=None: default graph, g='*': all graphs + """ + rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit)) + + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + # Filter by graph if specified + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_p(self, collection, p, g=None, limit=10): + """Query by predicate""" + rows = self.session.execute(self.get_entity_as_p_stmt, (collection, p, limit)) + + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_o(self, collection, o, g=None, limit=10): + """Query by object""" + rows = self.session.execute(self.get_entity_as_o_stmt, (collection, o, limit)) + + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_sp(self, collection, s, p, g=None, limit=10): + """Query by subject and predicate""" + rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit)) + + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=s, p=p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_po(self, collection, p, o, g=None, limit=10): + """Query by predicate and object""" + rows = self.session.execute(self.get_entity_as_o_p_stmt, (collection, o, p, limit)) + + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=row.s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_os(self, collection, o, s, g=None, limit=10): + """Query by object and subject""" + # Use subject partition with role='S', filter by o + rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit)) + + results = [] + for row in rows: + if row.o != o: + continue + + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=s, p=row.p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_spo(self, collection, s, p, o, g=None, limit=10): + """Query by subject, predicate, object (find which graphs)""" + rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit)) + + results = [] + for row in rows: + if row.o != o: + continue + + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is None or g == 
DEFAULT_GRAPH: + if d != DEFAULT_GRAPH: + continue + elif g != GRAPH_WILDCARD and d != g: + continue + + results.append(QuadResult( + s=s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + + return results + + def get_g(self, collection, g, limit=50): + """Get all quads in a specific graph""" + if g is None: + g = DEFAULT_GRAPH + + return self.session.execute(self.get_collection_by_graph_stmt, (collection, g, limit)) + + # ======================================================================== + # Collection management + # ======================================================================== + + def collection_exists(self, collection): + """Check if collection exists""" + try: + result = self.session.execute( + f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", + (collection,) + ) + return bool(list(result)) + except Exception as e: + logger.error(f"Error checking collection existence: {e}") + return False + + def create_collection(self, collection): + """Create collection by inserting metadata row""" + try: + import datetime + self.session.execute( + f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", + (collection, datetime.datetime.now()) + ) + logger.info(f"Created collection metadata for {collection}") + except Exception as e: + logger.error(f"Error creating collection: {e}") + raise e + + def delete_collection(self, collection): + """ + Delete all quads for a collection from both tables. + + Uses efficient partition-level deletes: + 1. Read quads from quads_by_collection to get all quads + 2. Extract unique entities (s, p, o for URIs, g for non-default) + 3. Delete entire entity partitions + 4. Delete collection rows + """ + # Read all quads from collection table + rows = self.session.execute( + f"SELECT d, s, p, o, otype FROM {self.collection_table} WHERE collection = %s", + (collection,) + ) + + # Collect unique entities and quad data for deletion + entities = set() + quads = [] + + for row in rows: + d, s, p, o, otype = row.d, row.s, row.p, row.o, row.otype + quads.append((d, s, p, o)) + + # Subject and predicate are always entities + entities.add(s) + entities.add(p) + + # Object is an entity only for URIs + if otype == 'u' or otype == 't': + entities.add(o) + + # Graph is an entity for non-default graphs + if d != DEFAULT_GRAPH: + entities.add(d) + + # Delete entity partitions (efficient partition-level deletes) + batch = BatchStatement() + count = 0 + + for entity in entities: + batch.add(self.delete_entity_partition_stmt, (collection, entity)) + count += 1 + + # Execute batch every 50 entities + if count % 50 == 0: + self.session.execute(batch) + batch = BatchStatement() + + # Execute remaining entity deletes + if count % 50 != 0: + self.session.execute(batch) + + # Delete collection rows + batch = BatchStatement() + count = 0 + + for d, s, p, o in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o)) + count += 1 + + # Execute batch every 50 quads + if count % 50 == 0: + self.session.execute(batch) + batch = BatchStatement() + + # Execute remaining collection row deletes + if count % 50 != 0: + self.session.execute(batch) + + # Delete collection metadata + self.session.execute( + f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", + (collection,) + ) + + logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + + def close(self): + """Close connections""" + if hasattr(self, 'session') 
and self.session: + self.session.shutdown() + if hasattr(self, 'cluster') and self.cluster: + self.cluster.shutdown() + if self.cluster in _active_clusters: + _active_clusters.remove(self.cluster) + + +class QuadResult: + """ + Result object for quad queries, including object type metadata. + + Attributes: + s: Subject value + p: Predicate value + o: Object value + g: Graph/dataset value + otype: Object type - 'u' (URI), 'l' (literal), 't' (triple) + dtype: XSD datatype (for literals) + lang: Language tag (for literals) + """ + + def __init__(self, s, p, o, g, otype='u', dtype='', lang=''): + self.s = s + self.p = p + self.o = o + self.g = g + self.otype = otype + self.dtype = dtype + self.lang = lang + + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index 4726be4d..1b63774d 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -16,12 +16,14 @@ import logging logger = logging.getLogger(__name__) default_ident = "graph-embeddings" +default_batch_size = 5 class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id") + self.batch_size = params.get("batch_size", default_batch_size) super(Processor, self).__init__( **params | { @@ -73,12 +75,14 @@ class Processor(FlowProcessor): ) ) - r = GraphEmbeddings( - metadata=v.metadata, - entities=entities, - ) - - await flow("output").send(r) + # Send in batches to avoid oversized messages + for i in range(0, len(entities), self.batch_size): + batch = entities[i:i + self.batch_size] + r = GraphEmbeddings( + metadata=v.metadata, + entities=batch, + ) + await flow("output").send(r) except Exception as e: logger.error("Exception occurred", exc_info=True) @@ -91,6 +95,13 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '--batch-size', + type=int, + default=default_batch_size, + help=f'Maximum entities per output message (default: {default_batch_size})' + ) + FlowProcessor.add_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py new file mode 100644 index 00000000..a48cc4d0 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py @@ -0,0 +1,6 @@ + +from . embeddings import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py new file mode 100644 index 00000000..84c41ff3 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -0,0 +1,263 @@ + +""" +Row embeddings processor. Calls the embeddings service to compute embeddings +for indexed field values in extracted row data. + +Input is ExtractedObject (structured row data with schema). +Output is RowEmbeddings (row data with embeddings for indexed fields). 
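Before the processor code, the index-text derivation it performs can be sketched standalone. Field names and row values here are hypothetical, and the (name, indexed) pairs collapse the real schema's separate primary/indexed flags into a single flag.

```python
# Standalone sketch of the index-text derivation described above.

def index_texts(fields, rows):
    """fields: [(name, indexed)] pairs; rows: list of dicts.

    Returns the unique texts that would be sent for embedding.
    """
    index_names = [name for name, indexed in fields if indexed]
    texts = []
    for row in rows:
        for name in index_names:
            # composite indexes are written as comma-separated field names
            values = [str(row.get(f.strip(), "")) for f in name.split(",")]
            if all(v == "" for v in values):
                continue                    # nothing to index
            text = " ".join(values)         # space-join composite values
            if text not in texts:
                texts.append(text)          # de-duplicate across rows
    return texts

fields = [("city", True), ("city,country", True), ("population", False)]
rows = [
    {"city": "Bristol", "country": "UK", "population": "472000"},
    {"city": "Bristol", "country": "UK"},   # duplicate texts embedded once
]
print(index_texts(fields, rows))
# ['Bristol', 'Bristol UK']
```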
+ +This follows the two-stage pattern used by graph-embeddings and document-embeddings: + Stage 1 (this processor): Compute embeddings + Stage 2 (row-embeddings-write-*): Store embeddings +""" + +import json +import logging +from typing import Dict, List, Set + +from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding +from ... schema import RowSchema, Field +from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec +from ... base import ProducerSpec, CollectionConfigHandler + +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings" +default_batch_size = 10 + + +class Processor(CollectionConfigHandler, FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + self.batch_size = params.get("batch_size", default_batch_size) + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + + super(Processor, self).__init__( + **params | { + "id": id, + "config_type": self.config_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="input", + schema=ExtractedObject, + handler=self.on_message, + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name="embeddings-request", + response_name="embeddings-response", + ) + ) + + self.register_specification( + ProducerSpec( + name="output", + schema=RowEmbeddings + ) + ) + + # Register config handlers + self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_collection_config) + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") + + # Clear existing schemas + self.schemas = {} + + # Check if our config type exists + if self.config_key not in config: + logger.warning(f"No '{self.config_key}' type in configuration") + return + + # Get the schemas dictionary for our type + schemas_config = config[self.config_key] + + # Process each schema in the schemas config + for schema_name, schema_json in schemas_config.items(): + try: + # Parse the JSON schema definition + schema_def = json.loads(schema_json) + + # Create Field objects + fields = [] + for field_def in schema_def.get("fields", []): + field = Field( + name=field_def["name"], + type=field_def["type"], + size=field_def.get("size", 0), + primary=field_def.get("primary_key", False), + description=field_def.get("description", ""), + required=field_def.get("required", False), + enum_values=field_def.get("enum", []), + indexed=field_def.get("indexed", False) + ) + fields.append(field) + + # Create RowSchema + row_schema = RowSchema( + name=schema_def.get("name", schema_name), + description=schema_def.get("description", ""), + fields=fields + ) + + self.schemas[schema_name] = row_schema + logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + + except Exception as e: + logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) + + logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + def get_index_names(self, schema: RowSchema) -> List[str]: + """Get all index names for a schema.""" + index_names = [] + for field in schema.fields: + if field.primary or field.indexed: + index_names.append(field.name) + return index_names + + def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]: + """Build the index_value list for a given index.""" + field_names = [f.strip() for f 
in index_name.split(',')] + values = [] + for field_name in field_names: + value = value_map.get(field_name) + values.append(str(value) if value is not None else "") + return values + + def build_text_for_embedding(self, index_value: List[str]) -> str: + """Build text representation for embedding from index values.""" + # Space-join the values for composite indexes + return " ".join(index_value) + + async def on_message(self, msg, consumer, flow): + """Process incoming ExtractedObject and compute embeddings""" + + obj = msg.value() + logger.info( + f"Computing embeddings for {len(obj.values)} rows, " + f"schema {obj.schema_name}, doc {obj.metadata.id}" + ) + + # Validate collection exists before processing + if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + logger.warning( + f"Collection {obj.metadata.collection} for user {obj.metadata.user} " + f"does not exist in config. Dropping message." + ) + return + + # Get schema definition + schema = self.schemas.get(obj.schema_name) + if not schema: + logger.warning(f"No schema found for {obj.schema_name} - skipping") + return + + # Get all index names for this schema + index_names = self.get_index_names(schema) + + if not index_names: + logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping") + return + + # Track unique texts to avoid duplicate embeddings + # text -> (index_name, index_value) + texts_to_embed: Dict[str, tuple] = {} + + # Collect all texts that need embeddings + for value_map in obj.values: + for index_name in index_names: + index_value = self.build_index_value(value_map, index_name) + + # Skip empty values + if not index_value or all(v == "" for v in index_value): + continue + + text = self.build_text_for_embedding(index_value) + if text and text not in texts_to_embed: + texts_to_embed[text] = (index_name, index_value) + + if not texts_to_embed: + logger.info("No texts to embed") + return + + # Compute embeddings + embeddings_list = [] + + try: + for text, (index_name, index_value) in texts_to_embed.items(): + vectors = await flow("embeddings-request").embed(text=text) + + embeddings_list.append( + RowIndexEmbedding( + index_name=index_name, + index_value=index_value, + text=text, + vectors=vectors + ) + ) + + # Send in batches to avoid oversized messages + for i in range(0, len(embeddings_list), self.batch_size): + batch = embeddings_list[i:i + self.batch_size] + result = RowEmbeddings( + metadata=obj.metadata, + schema_name=obj.schema_name, + embeddings=batch, + ) + await flow("output").send(result) + + logger.info( + f"Computed {len(embeddings_list)} embeddings for " + f"{len(obj.values)} rows ({len(index_names)} indexes)" + ) + + except Exception as e: + logger.error("Exception during embedding computation", exc_info=True) + raise e + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Collection creation notification - no action needed for embedding stage""" + logger.debug(f"Row embeddings collection notification for {user}/{collection}") + + async def delete_collection(self, user: str, collection: str): + """Collection deletion notification - no action needed for embedding stage""" + logger.debug(f"Row embeddings collection delete notification for {user}/{collection}") + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + + parser.add_argument( + '--batch-size', + type=int, + default=default_batch_size, + help=f'Maximum embeddings per output message (default: {default_batch_size})' + ) + + parser.add_argument( + 
'--config-type', + default='schema', + help='Configuration type prefix for schemas (default: schema)' + ) + + +def run(): + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index b7ef9259..d9057909 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -3,7 +3,7 @@ import json import urllib.parse import logging -from ....schema import Chunk, Triple, Triples, Metadata, Value +from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from ....schema import EntityContext, EntityContexts from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION @@ -126,16 +126,42 @@ class Processor(FlowProcessor): await pub.send(ecs) - def parse_json(self, text): - json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL) - - if json_match: - json_str = json_match.group(1).strip() - else: - # If no delimiters, assume the entire output is JSON - json_str = text.strip() + def parse_jsonl(self, text): + """ + Parse JSONL response, returning list of valid objects. - return json.loads(json_str) + Invalid lines (malformed JSON, empty lines) are skipped with warnings. + This provides truncation resilience - partial output yields partial results. + """ + results = [] + + # Strip markdown code fences if present + text = text.strip() + if text.startswith('```'): + # Remove opening fence (possibly with language hint) + text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text) + if text.endswith('```'): + text = text[:-3] + + for line_num, line in enumerate(text.strip().split('\n'), 1): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Skip any remaining fence markers + if line.startswith('```'): + continue + + try: + obj = json.loads(line) + results.append(obj) + except json.JSONDecodeError as e: + # Log warning but continue - this provides truncation resilience + logger.warning(f"JSONL parse error on line {line_num}: {e}") + + return results async def on_message(self, msg, consumer, flow): @@ -178,11 +204,12 @@ class Processor(FlowProcessor): question = prompt ) - # Parse JSON response - try: - extraction_data = self.parse_json(agent_response) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON response from agent: {e}") + # Parse JSONL response + extraction_data = self.parse_jsonl(agent_response) + + if not extraction_data: + logger.warning("JSONL parse returned no valid objects") + return # Process extraction data triples, entity_contexts = self.process_extraction_data( @@ -209,103 +236,113 @@ class Processor(FlowProcessor): raise def process_extraction_data(self, data, metadata): - """Process combined extraction data to generate triples and entity contexts""" + """Process JSONL extraction data to generate triples and entity contexts. 
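The truncation resilience that parse_jsonl provides is worth a standalone illustration (the per-object format is described just below). This sketch uses only the standard json module; the response text is invented.

```python
import json

# Each JSONL line parses independently, so a response cut off mid-object
# yields partial results instead of failing outright.

response = (
    '{"type": "definition", "entity": "Cassandra", "definition": "A wide-column store"}\n'
    '{"type": "relationship", "subject": "Cassandra", "predicate": "written-in",'
    ' "object": "Java", "object-entity": true}\n'
    '{"type": "definition", "entity": "CQL", "defini'   # truncated mid-object
)

objects = []
for line in response.splitlines():
    line = line.strip()
    if not line:
        continue
    try:
        objects.append(json.loads(line))
    except json.JSONDecodeError:
        pass  # skip the malformed line, keep the rest

print(len(objects))  # 2 -- the truncated third line is dropped
```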
+ + Data is a flat list of objects with 'type' discriminator field: + - {"type": "definition", "entity": "...", "definition": "..."} + - {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool} + """ triples = [] entity_contexts = [] + # Categorize items by type + definitions = [item for item in data if item.get("type") == "definition"] + relationships = [item for item in data if item.get("type") == "relationship"] + # Process definitions - for defn in data.get("definitions", []): + for defn in definitions: entity_uri = self.to_uri(defn["entity"]) - + # Add entity label triples.append(Triple( - s = Value(value=entity_uri, is_uri=True), - p = Value(value=RDF_LABEL, is_uri=True), - o = Value(value=defn["entity"], is_uri=False), + s = Term(type=IRI, iri=entity_uri), + p = Term(type=IRI, iri=RDF_LABEL), + o = Term(type=LITERAL, value=defn["entity"]), )) - + # Add definition triples.append(Triple( - s = Value(value=entity_uri, is_uri=True), - p = Value(value=DEFINITION, is_uri=True), - o = Value(value=defn["definition"], is_uri=False), + s = Term(type=IRI, iri=entity_uri), + p = Term(type=IRI, iri=DEFINITION), + o = Term(type=LITERAL, value=defn["definition"]), )) - + # Add subject-of relationship to document if metadata.id: triples.append(Triple( - s = Value(value=entity_uri, is_uri=True), - p = Value(value=SUBJECT_OF, is_uri=True), - o = Value(value=metadata.id, is_uri=True), + s = Term(type=IRI, iri=entity_uri), + p = Term(type=IRI, iri=SUBJECT_OF), + o = Term(type=IRI, iri=metadata.id), )) - + # Create entity context for embeddings entity_contexts.append(EntityContext( - entity=Value(value=entity_uri, is_uri=True), + entity=Term(type=IRI, iri=entity_uri), context=defn["definition"] )) # Process relationships - for rel in data.get("relationships", []): + for rel in relationships: subject_uri = self.to_uri(rel["subject"]) predicate_uri = self.to_uri(rel["predicate"]) - subject_value = Value(value=subject_uri, is_uri=True) - predicate_value = Value(value=predicate_uri, is_uri=True) - if data.get("object-entity", False): - object_value = Value(value=predicate_uri, is_uri=True) + subject_value = Term(type=IRI, iri=subject_uri) + predicate_value = Term(type=IRI, iri=predicate_uri) + if rel.get("object-entity", True): + object_uri = self.to_uri(rel["object"]) + object_value = Term(type=IRI, iri=object_uri) else: - object_value = Value(value=predicate_uri, is_uri=False) - + object_value = Term(type=LITERAL, value=rel["object"]) + # Add subject and predicate labels triples.append(Triple( s = subject_value, - p = Value(value=RDF_LABEL, is_uri=True), - o = Value(value=rel["subject"], is_uri=False), + p = Term(type=IRI, iri=RDF_LABEL), + o = Term(type=LITERAL, value=rel["subject"]), )) - + triples.append(Triple( s = predicate_value, - p = Value(value=RDF_LABEL, is_uri=True), - o = Value(value=rel["predicate"], is_uri=False), + p = Term(type=IRI, iri=RDF_LABEL), + o = Term(type=LITERAL, value=rel["predicate"]), )) - + # Handle object (entity vs literal) if rel.get("object-entity", True): triples.append(Triple( s = object_value, - p = Value(value=RDF_LABEL, is_uri=True), - o = Value(value=rel["object"], is_uri=True), + p = Term(type=IRI, iri=RDF_LABEL), + o = Term(type=LITERAL, value=rel["object"]), )) - + # Add the main relationship triple triples.append(Triple( s = subject_value, p = predicate_value, o = object_value )) - + # Add subject-of relationships to document if metadata.id: triples.append(Triple( s = subject_value, - p = Value(value=SUBJECT_OF, 
is_uri=True), - o = Value(value=metadata.id, is_uri=True), + p = Term(type=IRI, iri=SUBJECT_OF), + o = Term(type=IRI, iri=metadata.id), )) - + triples.append(Triple( s = predicate_value, - p = Value(value=SUBJECT_OF, is_uri=True), - o = Value(value=metadata.id, is_uri=True), + p = Term(type=IRI, iri=SUBJECT_OF), + o = Term(type=IRI, iri=metadata.id), )) - + if rel.get("object-entity", True): triples.append(Triple( s = object_value, - p = Value(value=SUBJECT_OF, is_uri=True), - o = Value(value=metadata.id, is_uri=True), + p = Term(type=IRI, iri=SUBJECT_OF), + o = Term(type=IRI, iri=metadata.id), )) return triples, entity_contexts diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 1d414b7e..72275a8c 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -9,7 +9,7 @@ import json import urllib.parse import logging -from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL # Module logger logger = logging.getLogger(__name__) @@ -20,12 +20,14 @@ from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import PromptClientSpec -DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) -RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) -SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) +DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION) +RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL) +SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF) default_ident = "kg-extract-definitions" default_concurrency = 1 +default_triples_batch_size = 50 +default_entity_batch_size = 5 class Processor(FlowProcessor): @@ -33,6 +35,8 @@ class Processor(FlowProcessor): id = params.get("id") concurrency = params.get("concurrency", 1) + self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size) + self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size) super(Processor, self).__init__( **params | { @@ -142,13 +146,13 @@ class Processor(FlowProcessor): s_uri = self.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) - o_value = Value(value=str(o), is_uri=False) + s_value = Term(type=IRI, iri=str(s_uri)) + o_value = Term(type=LITERAL, value=str(o)) triples.append(Triple( s=s_value, p=RDF_LABEL_VALUE, - o=Value(value=s, is_uri=False), + o=Term(type=LITERAL, value=s), )) triples.append(Triple( @@ -158,37 +162,48 @@ class Processor(FlowProcessor): triples.append(Triple( s=s_value, p=SUBJECT_OF_VALUE, - o=Value(value=v.metadata.id, is_uri=True) + o=Term(type=IRI, iri=v.metadata.id) )) - ec = EntityContext( + # Output entity name as context for direct name matching + entities.append(EntityContext( + entity=s_value, + context=s, + )) + + # Output definition as context for semantic matching + entities.append(EntityContext( entity=s_value, context=defn["definition"], + )) + + # Send triples in batches + for i in range(0, len(triples), self.triples_batch_size): + batch = triples[i:i + self.triples_batch_size] + await self.emit_triples( + flow("triples"), + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + batch ) - entities.append(ec) - - await self.emit_triples( - flow("triples"), - Metadata( - id=v.metadata.id, - metadata=[], - user=v.metadata.user, - 
collection=v.metadata.collection, - ), - triples - ) - - await self.emit_ecs( - flow("entity-contexts"), - Metadata( - id=v.metadata.id, - metadata=[], - user=v.metadata.user, - collection=v.metadata.collection, - ), - entities - ) + # Send entity contexts in batches + for i in range(0, len(entities), self.entity_batch_size): + batch = entities[i:i + self.entity_batch_size] + await self.emit_ecs( + flow("entity-contexts"), + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + batch + ) except Exception as e: logger.error(f"Definitions extraction exception: {e}", exc_info=True) @@ -205,6 +220,20 @@ class Processor(FlowProcessor): help=f'Concurrent processing threads (default: {default_concurrency})' ) + parser.add_argument( + '--triples-batch-size', + type=int, + default=default_triples_batch_size, + help=f'Maximum triples per output message (default: {default_triples_batch_size})' + ) + + parser.add_argument( + '--entity-batch-size', + type=int, + default=default_entity_batch_size, + help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})' + ) + FlowProcessor.add_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py b/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py index 712aadbe..093302a9 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py @@ -74,23 +74,27 @@ def build_entity_uri(entity_name: str, entity_type: str, ontology_id: str, Args: entity_name: Natural language entity name (e.g., "Cornish pasty") - entity_type: Ontology type (e.g., "fo/Recipe") + entity_type: Ontology type (e.g., "fo/Recipe" or "Recipe") ontology_id: Ontology identifier (e.g., "food") base_uri: Base URI for entity URIs (default: "https://trustgraph.ai") Returns: - Full entity URI (e.g., "https://trustgraph.ai/food/fo-recipe-cornish-pasty") + Full entity URI (e.g., "https://trustgraph.ai/food/recipe-cornish-pasty") Examples: >>> build_entity_uri("Cornish pasty", "fo/Recipe", "food") - 'https://trustgraph.ai/food/fo-recipe-cornish-pasty' + 'https://trustgraph.ai/food/recipe-cornish-pasty' - >>> build_entity_uri("Cornish pasty", "fo/Food", "food") - 'https://trustgraph.ai/food/fo-food-cornish-pasty' + >>> build_entity_uri("Cornish pasty", "Food", "food") + 'https://trustgraph.ai/food/food-cornish-pasty' >>> build_entity_uri("beef", "fo/Food", "food") - 'https://trustgraph.ai/food/fo-food-beef' + 'https://trustgraph.ai/food/food-beef' """ + # Strip ontology prefix from type if present (e.g., "fo/Recipe" -> "Recipe") + if "/" in entity_type: + entity_type = entity_type.split("/")[-1] + type_part = normalize_type_identifier(entity_type) name_part = normalize_entity_name(entity_name) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 335f07d2..a0d9a3fe 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -8,7 +8,7 @@ import logging import asyncio from typing import List, Dict, Any, Optional -from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from .... schema import EntityContext, EntityContexts from .... schema import PromptRequest, PromptResponse from .... 
rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION @@ -27,6 +27,8 @@ logger = logging.getLogger(__name__) default_ident = "kg-extract-ontology" default_concurrency = 1 +default_triples_batch_size = 50 +default_entity_batch_size = 5 # URI prefix mappings for common namespaces URI_PREFIXES = { @@ -39,12 +41,22 @@ URI_PREFIXES = { } +def make_term(v, is_uri): + """Helper to create Term from value and is_uri flag.""" + if is_uri: + return Term(type=IRI, iri=v) + else: + return Term(type=LITERAL, value=v) + + class Processor(FlowProcessor): """Main OntoRAG extraction processor.""" def __init__(self, **params): id = params.get("id", default_ident) concurrency = params.get("concurrency", default_concurrency) + self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size) + self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size) super(Processor, self).__init__( **params | { @@ -274,17 +286,6 @@ class Processor(FlowProcessor): if not ontology_subsets: logger.warning("No relevant ontology elements found for chunk") - # Emit empty outputs - await self.emit_triples( - flow("triples"), - v.metadata, - [] - ) - await self.emit_entity_contexts( - flow("entity-contexts"), - v.metadata, - [] - ) return # Merge subsets if multiple ontologies matched @@ -318,36 +319,29 @@ class Processor(FlowProcessor): # Build entity contexts from all triples (including ontology elements) entity_contexts = self.build_entity_contexts(all_triples) - # Emit all triples (extracted + ontology definitions) - await self.emit_triples( - flow("triples"), - v.metadata, - all_triples - ) + # Emit triples in batches + for i in range(0, len(all_triples), self.triples_batch_size): + batch = all_triples[i:i + self.triples_batch_size] + await self.emit_triples( + flow("triples"), + v.metadata, + batch + ) - # Emit entity contexts - await self.emit_entity_contexts( - flow("entity-contexts"), - v.metadata, - entity_contexts - ) + # Emit entity contexts in batches + for i in range(0, len(entity_contexts), self.entity_batch_size): + batch = entity_contexts[i:i + self.entity_batch_size] + await self.emit_entity_contexts( + flow("entity-contexts"), + v.metadata, + batch + ) logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples " f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts") except Exception as e: logger.error(f"OntoRAG extraction exception: {e}", exc_info=True) - # Emit empty outputs on error - await self.emit_triples( - flow("triples"), - v.metadata, - [] - ) - await self.emit_entity_contexts( - flow("entity-contexts"), - v.metadata, - [] - ) async def extract_with_simplified_format( self, @@ -446,9 +440,9 @@ class Processor(FlowProcessor): is_object_uri = False # Create Triple object with expanded URIs - s_value = Value(value=subject_uri, is_uri=True) - p_value = Value(value=predicate_uri, is_uri=True) - o_value = Value(value=object_uri, is_uri=is_object_uri) + s_value = make_term(subject_uri, is_uri=True) + p_value = make_term(predicate_uri, is_uri=True) + o_value = make_term(object_uri, is_uri=is_object_uri) validated_triples.append(Triple( s=s_value, @@ -609,9 +603,9 @@ class Processor(FlowProcessor): # rdf:type owl:Class ontology_triples.append(Triple( - s=Value(value=class_uri, is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://www.w3.org/2002/07/owl#Class", is_uri=True) + s=make_term(class_uri, is_uri=True), + 
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=make_term("http://www.w3.org/2002/07/owl#Class", is_uri=True) )) # rdfs:label (stored as 'labels' in OntologyClass.__dict__) @@ -620,18 +614,18 @@ class Processor(FlowProcessor): if isinstance(labels, list) and labels: label_val = labels[0].get('value', class_id) if isinstance(labels[0], dict) else str(labels[0]) ontology_triples.append(Triple( - s=Value(value=class_uri, is_uri=True), - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=label_val, is_uri=False) + s=make_term(class_uri, is_uri=True), + p=make_term(RDF_LABEL, is_uri=True), + o=make_term(label_val, is_uri=False) )) # rdfs:comment (stored as 'comment' in OntologyClass.__dict__) if isinstance(class_def, dict) and 'comment' in class_def and class_def['comment']: comment = class_def['comment'] ontology_triples.append(Triple( - s=Value(value=class_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), - o=Value(value=comment, is_uri=False) + s=make_term(class_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=make_term(comment, is_uri=False) )) # rdfs:subClassOf (stored as 'subclass_of' in OntologyClass.__dict__) @@ -648,9 +642,9 @@ class Processor(FlowProcessor): parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}" ontology_triples.append(Triple( - s=Value(value=class_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True), - o=Value(value=parent_uri, is_uri=True) + s=make_term(class_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True), + o=make_term(parent_uri, is_uri=True) )) # Generate triples for object properties @@ -663,9 +657,9 @@ class Processor(FlowProcessor): # rdf:type owl:ObjectProperty ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=make_term("http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True) )) # rdfs:label (stored as 'labels' in OntologyProperty.__dict__) @@ -674,18 +668,18 @@ class Processor(FlowProcessor): if isinstance(labels, list) and labels: label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0]) ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=label_val, is_uri=False) + s=make_term(prop_uri, is_uri=True), + p=make_term(RDF_LABEL, is_uri=True), + o=make_term(label_val, is_uri=False) )) # rdfs:comment (stored as 'comment' in OntologyProperty.__dict__) if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']: comment = prop_def['comment'] ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), - o=Value(value=comment, is_uri=False) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=make_term(comment, is_uri=False) )) # rdfs:domain (stored as 'domain' in OntologyProperty.__dict__) @@ -702,9 +696,9 @@ class Processor(FlowProcessor): domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}" 
ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True), - o=Value(value=domain_uri, is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True), + o=make_term(domain_uri, is_uri=True) )) # rdfs:range (stored as 'range' in OntologyProperty.__dict__) @@ -721,9 +715,9 @@ class Processor(FlowProcessor): range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}" ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True), - o=Value(value=range_uri, is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True), + o=make_term(range_uri, is_uri=True) )) # Generate triples for datatype properties @@ -736,9 +730,9 @@ class Processor(FlowProcessor): # rdf:type owl:DatatypeProperty ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), - o=Value(value="http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=make_term("http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True) )) # rdfs:label (stored as 'labels' in OntologyProperty.__dict__) @@ -747,18 +741,18 @@ class Processor(FlowProcessor): if isinstance(labels, list) and labels: label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0]) ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=label_val, is_uri=False) + s=make_term(prop_uri, is_uri=True), + p=make_term(RDF_LABEL, is_uri=True), + o=make_term(label_val, is_uri=False) )) # rdfs:comment (stored as 'comment' in OntologyProperty.__dict__) if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']: comment = prop_def['comment'] ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), - o=Value(value=comment, is_uri=False) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=make_term(comment, is_uri=False) )) # rdfs:domain (stored as 'domain' in OntologyProperty.__dict__) @@ -775,9 +769,9 @@ class Processor(FlowProcessor): domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}" ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True), - o=Value(value=domain_uri, is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True), + o=make_term(domain_uri, is_uri=True) )) # rdfs:range (datatype) @@ -790,9 +784,9 @@ class Processor(FlowProcessor): range_uri = range_val ontology_triples.append(Triple( - s=Value(value=prop_uri, is_uri=True), - p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True), - o=Value(value=range_uri, is_uri=True) + s=make_term(prop_uri, is_uri=True), + p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True), + o=make_term(range_uri, is_uri=True) )) logger.info(f"Generated {len(ontology_triples)} triples describing ontology 
elements") @@ -814,9 +808,9 @@ class Processor(FlowProcessor): entity_data = {} # subject_uri -> {labels: [], definitions: []} for triple in triples: - subject_uri = triple.s.value - predicate_uri = triple.p.value - object_val = triple.o.value + subject_uri = triple.s.iri if triple.s.type == IRI else triple.s.value + predicate_uri = triple.p.iri if triple.p.type == IRI else triple.p.value + object_val = triple.o.value if triple.o.type == LITERAL else triple.o.iri # Initialize entity data if not exists if subject_uri not in entity_data: @@ -824,12 +818,12 @@ class Processor(FlowProcessor): # Collect labels (rdfs:label) if predicate_uri == RDF_LABEL: - if not triple.o.is_uri: # Labels are literals + if triple.o.type == LITERAL: # Labels are literals entity_data[subject_uri]['labels'].append(object_val) # Collect definitions (skos:definition, schema:description) elif predicate_uri == DEFINITION or predicate_uri == "https://schema.org/description": - if not triple.o.is_uri: + if triple.o.type == LITERAL: entity_data[subject_uri]['definitions'].append(object_val) # Build EntityContext objects @@ -848,7 +842,7 @@ class Processor(FlowProcessor): if context_parts: context_text = ". ".join(context_parts) entity_contexts.append(EntityContext( - entity=Value(value=subject_uri, is_uri=True), + entity=make_term(subject_uri, is_uri=True), context=context_text )) @@ -876,6 +870,18 @@ class Processor(FlowProcessor): default=0.3, help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' ) + parser.add_argument( + '--triples-batch-size', + type=int, + default=default_triples_batch_size, + help=f'Maximum triples per output message (default: {default_triples_batch_size})' + ) + parser.add_argument( + '--entity-batch-size', + type=int, + default=default_entity_batch_size, + help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})' + ) FlowProcessor.add_args(parser) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py b/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py index 3131d977..1f54222d 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py @@ -49,8 +49,17 @@ class ExtractionResult: def parse_extraction_response(response: Any) -> Optional[ExtractionResult]: """Parse LLM extraction response into structured format. + Supports two formats: + 1. JSONL format (list): Flat list of objects with 'type' discriminator field + [{"type": "entity", ...}, {"type": "relationship", ...}, {"type": "attribute", ...}] + 2. 
Legacy format (dict): Nested structure with separate arrays + {"entities": [...], "relationships": [...], "attributes": [...]} + Args: - response: LLM response (string JSON or already parsed dict) + response: LLM response - can be: + - string (JSON to parse) + - dict (legacy nested format) + - list (JSONL format - flat list with type discriminators) Returns: ExtractionResult with parsed entities/relationships/attributes, @@ -64,17 +73,89 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]: logger.error(f"Failed to parse JSON response: {e}") logger.debug(f"Response was: {response[:500]}") return None - elif isinstance(response, dict): + elif isinstance(response, (dict, list)): data = response else: logger.error(f"Unexpected response type: {type(response)}") return None - # Validate structure - if not isinstance(data, dict): - logger.error(f"Expected dict, got {type(data)}") - return None + # Handle JSONL format (flat list with type discriminators) + if isinstance(data, list): + return parse_jsonl_format(data) + # Handle legacy format (nested dict) + if isinstance(data, dict): + return parse_legacy_format(data) + + logger.error(f"Expected dict or list, got {type(data)}") + return None + + +def parse_jsonl_format(data: List[Dict[str, Any]]) -> ExtractionResult: + """Parse JSONL format response (flat list with type discriminators). + + Each item has a 'type' field: 'entity', 'relationship', or 'attribute'. + + Args: + data: List of dicts with type discriminator + + Returns: + ExtractionResult with categorized items + """ + entities = [] + relationships = [] + attributes = [] + + for item in data: + if not isinstance(item, dict): + logger.warning(f"Skipping non-dict item: {type(item)}") + continue + + item_type = item.get('type') + + if item_type == 'entity': + try: + entity = parse_entity_jsonl(item) + if entity: + entities.append(entity) + except Exception as e: + logger.warning(f"Failed to parse entity {item}: {e}") + + elif item_type == 'relationship': + try: + relationship = parse_relationship(item) + if relationship: + relationships.append(relationship) + except Exception as e: + logger.warning(f"Failed to parse relationship {item}: {e}") + + elif item_type == 'attribute': + try: + attribute = parse_attribute(item) + if attribute: + attributes.append(attribute) + except Exception as e: + logger.warning(f"Failed to parse attribute {item}: {e}") + + else: + logger.warning(f"Unknown item type '{item_type}': {item}") + + return ExtractionResult( + entities=entities, + relationships=relationships, + attributes=attributes + ) + + +def parse_legacy_format(data: Dict[str, Any]) -> ExtractionResult: + """Parse legacy format response (nested dict with arrays). + + Args: + data: Dict with 'entities', 'relationships', 'attributes' arrays + + Returns: + ExtractionResult with parsed items + """ # Parse entities entities = [] entities_data = data.get('entities', []) @@ -127,6 +208,37 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]: ) +def parse_entity_jsonl(data: Dict[str, Any]) -> Optional[Entity]: + """Parse entity from JSONL format dict. + + JSONL format uses 'entity_type' instead of 'type' for the entity's type + (since 'type' is the discriminator field). 
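A side-by-side of the two entity encodings the parser accepts, with invented values; both should yield the same Entity:

```python
# Hypothetical items; both parse to Entity(entity="Cornish pasty", type="Recipe").

jsonl_item = {
    "type": "entity",           # discriminator consumed by parse_jsonl_format
    "entity": "Cornish pasty",
    "entity_type": "Recipe",    # renamed: 'type' is taken by the discriminator
}

legacy_item = {
    "entity": "Cornish pasty",
    "type": "Recipe",           # legacy nested format uses 'type' directly
}
```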
+ + Args: + data: Entity dict with 'entity' and 'entity_type' fields + + Returns: + Entity object or None if invalid + """ + if not isinstance(data, dict): + logger.warning(f"Entity data is not a dict: {type(data)}") + return None + + entity = data.get('entity') + # JSONL format uses 'entity_type' since 'type' is the discriminator + entity_type = data.get('entity_type') + + if not entity or not entity_type: + logger.warning(f"Missing required fields in entity: {data}") + return None + + if not isinstance(entity, str) or not isinstance(entity_type, str): + logger.warning(f"Entity fields must be strings: {data}") + return None + + return Entity(entity=entity, type=entity_type) + + def parse_entity(data: Dict[str, Any]) -> Optional[Entity]: """Parse entity from dict. diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py index 2eb43b19..06fff4f4 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py @@ -8,7 +8,7 @@ with full URIs and correct is_uri flags. import logging from typing import List, Optional -from .... schema import Triple, Value +from .... schema import Triple, Term, IRI, LITERAL from .... rdf import RDF_TYPE, RDF_LABEL from .simplified_parser import Entity, Relationship, Attribute, ExtractionResult @@ -87,17 +87,17 @@ class TripleConverter: # Generate type triple: entity rdf:type ClassURI type_triple = Triple( - s=Value(value=entity_uri, is_uri=True), - p=Value(value=RDF_TYPE, is_uri=True), - o=Value(value=class_uri, is_uri=True) + s=Term(type=IRI, iri=entity_uri), + p=Term(type=IRI, iri=RDF_TYPE), + o=Term(type=IRI, iri=class_uri) ) triples.append(type_triple) # Generate label triple: entity rdfs:label "entity name" label_triple = Triple( - s=Value(value=entity_uri, is_uri=True), - p=Value(value=RDF_LABEL, is_uri=True), - o=Value(value=entity.entity, is_uri=False) # Literal! + s=Term(type=IRI, iri=entity_uri), + p=Term(type=IRI, iri=RDF_LABEL), + o=Term(type=LITERAL, value=entity.entity) # Literal! ) triples.append(label_triple) @@ -131,9 +131,9 @@ class TripleConverter: # Generate triple: subject property object return Triple( - s=Value(value=subject_uri, is_uri=True), - p=Value(value=property_uri, is_uri=True), - o=Value(value=object_uri, is_uri=True) + s=Term(type=IRI, iri=subject_uri), + p=Term(type=IRI, iri=property_uri), + o=Term(type=IRI, iri=object_uri) ) def convert_attribute(self, attribute: Attribute) -> Optional[Triple]: @@ -159,9 +159,9 @@ class TripleConverter: # Generate triple: entity property "literal value" return Triple( - s=Value(value=entity_uri, is_uri=True), - p=Value(value=property_uri, is_uri=True), - o=Value(value=attribute.value, is_uri=False) # Literal! + s=Term(type=IRI, iri=entity_uri), + p=Term(type=IRI, iri=property_uri), + o=Term(type=LITERAL, value=attribute.value) # Literal! ) def _get_class_uri(self, class_id: str) -> Optional[str]: diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 6d461997..7ab51555 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -13,18 +13,19 @@ import urllib.parse logger = logging.getLogger(__name__) from .... schema import Chunk, Triple, Triples -from .... schema import Metadata, Value +from .... 
schema import Metadata, Term, IRI, LITERAL from .... schema import PromptRequest, PromptResponse from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import PromptClientSpec -RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) -SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) +RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL) +SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF) default_ident = "kg-extract-relationships" default_concurrency = 1 +default_triples_batch_size = 50 class Processor(FlowProcessor): @@ -32,6 +33,7 @@ class Processor(FlowProcessor): id = params.get("id") concurrency = params.get("concurrency", 1) + self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size) super(Processor, self).__init__( **params | { @@ -127,16 +129,16 @@ class Processor(FlowProcessor): if o is None: continue s_uri = self.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) + s_value = Term(type=IRI, iri=str(s_uri)) p_uri = self.to_uri(p) - p_value = Value(value=str(p_uri), is_uri=True) + p_value = Term(type=IRI, iri=str(p_uri)) - if rel["object-entity"]: + if rel["object-entity"]: o_uri = self.to_uri(o) - o_value = Value(value=str(o_uri), is_uri=True) + o_value = Term(type=IRI, iri=str(o_uri)) else: - o_value = Value(value=str(o), is_uri=False) + o_value = Term(type=LITERAL, value=str(o)) triples.append(Triple( s=s_value, @@ -148,14 +150,14 @@ class Processor(FlowProcessor): triples.append(Triple( s=s_value, p=RDF_LABEL_VALUE, - o=Value(value=str(s), is_uri=False) + o=Term(type=LITERAL, value=str(s)) )) # Label for p triples.append(Triple( s=p_value, p=RDF_LABEL_VALUE, - o=Value(value=str(p), is_uri=False) + o=Term(type=LITERAL, value=str(p)) )) if rel["object-entity"]: @@ -163,14 +165,14 @@ class Processor(FlowProcessor): triples.append(Triple( s=o_value, p=RDF_LABEL_VALUE, - o=Value(value=str(o), is_uri=False) + o=Term(type=LITERAL, value=str(o)) )) # 'Subject of' for s triples.append(Triple( s=s_value, p=SUBJECT_OF_VALUE, - o=Value(value=v.metadata.id, is_uri=True) + o=Term(type=IRI, iri=v.metadata.id) )) if rel["object-entity"]: @@ -178,19 +180,22 @@ class Processor(FlowProcessor): triples.append(Triple( s=o_value, p=SUBJECT_OF_VALUE, - o=Value(value=v.metadata.id, is_uri=True) + o=Term(type=IRI, iri=v.metadata.id) )) - await self.emit_triples( - flow("triples"), - Metadata( - id=v.metadata.id, - metadata=[], - user=v.metadata.user, - collection=v.metadata.collection, - ), - triples - ) + # Send triples in batches + for i in range(0, len(triples), self.triples_batch_size): + batch = triples[i:i + self.triples_batch_size] + await self.emit_triples( + flow("triples"), + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + batch + ) except Exception as e: logger.error(f"Relationship extraction exception: {e}", exc_info=True) @@ -207,6 +212,13 @@ class Processor(FlowProcessor): help=f'Concurrent processing threads (default: {default_concurrency})' ) + parser.add_argument( + '--triples-batch-size', + type=int, + default=default_triples_batch_size, + help=f'Maximum triples per output message (default: {default_triples_batch_size})' + ) + FlowProcessor.add_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/__init__.py b/trustgraph-flow/trustgraph/extract/kg/rows/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/extract/kg/objects/__init__.py rename to 
trustgraph-flow/trustgraph/extract/kg/rows/__init__.py diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/__main__.py b/trustgraph-flow/trustgraph/extract/kg/rows/__main__.py similarity index 100% rename from trustgraph-flow/trustgraph/extract/kg/objects/__main__.py rename to trustgraph-flow/trustgraph/extract/kg/rows/__main__.py diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py similarity index 98% rename from trustgraph-flow/trustgraph/extract/kg/objects/processor.py rename to trustgraph-flow/trustgraph/extract/kg/rows/processor.py index b3483240..bd7bc802 100644 --- a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -1,5 +1,5 @@ """ -Object extraction service - extracts structured objects from text chunks +Row extraction service - extracts structured rows from text chunks based on configured schemas. """ @@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import PromptClientSpec from .... messaging.translators import row_schema_translator -default_ident = "kg-extract-objects" +default_ident = "kg-extract-rows" def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]: @@ -310,5 +310,5 @@ class Processor(FlowProcessor): FlowProcessor.add_args(parser) def run(): - """Entry point for kg-extract-objects command""" + """Entry point for kg-extract-rows command""" Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index 129cc64c..206d14d0 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -11,7 +11,7 @@ import logging # Module logger logger = logging.getLogger(__name__) -from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue @@ -20,7 +20,7 @@ from .... clients.prompt_client import PromptClient from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION from .... base import ConsumerProducer -DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) +DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION) module = "kg-extract-topics" @@ -106,8 +106,8 @@ class Processor(ConsumerProducer): s_uri = self.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) - o_value = Value(value=str(o), is_uri=False) + s_value = Term(type=IRI, iri=str(s_uri)) + o_value = Term(type=LITERAL, value=str(o)) await self.emit_edge( v.metadata, s_value, DEFINITION_VALUE, o_value diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py new file mode 100644 index 00000000..650d4f40 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py @@ -0,0 +1,31 @@ + +from ... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse +from ... messaging import TranslatorRegistry + +from . 
requestor import ServiceRequestor + +class DocumentEmbeddingsQueryRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(DocumentEmbeddingsQueryRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=DocumentEmbeddingsRequest, + response_schema=DocumentEmbeddingsResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("document-embeddings-query") + self.response_translator = TranslatorRegistry.get_response_translator("document-embeddings-query") + + def to_request(self, body): + return self.request_translator.to_pulsar(body) + + def from_response(self, message): + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 0766e232..35edad76 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -20,12 +20,14 @@ from . prompt import PromptRequestor from . graph_rag import GraphRagRequestor from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor -from . objects_query import ObjectsQueryRequestor +from . rows_query import RowsQueryRequestor from . nlp_query import NLPQueryRequestor from . structured_query import StructuredQueryRequestor from . structured_diag import StructuredDiagRequestor from . embeddings import EmbeddingsRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor +from . document_embeddings_query import DocumentEmbeddingsQueryRequestor +from . row_embeddings_query import RowEmbeddingsQueryRequestor from . mcp_tool import McpToolRequestor from . text_load import TextLoad from . document_load import DocumentLoad @@ -39,7 +41,7 @@ from . triples_import import TriplesImport from . graph_embeddings_import import GraphEmbeddingsImport from . document_embeddings_import import DocumentEmbeddingsImport from . entity_contexts_import import EntityContextsImport -from . objects_import import ObjectsImport +from . rows_import import RowsImport from . core_export import CoreExport from . core_import import CoreImport @@ -55,11 +57,13 @@ request_response_dispatchers = { "document-rag": DocumentRagRequestor, "embeddings": EmbeddingsRequestor, "graph-embeddings": GraphEmbeddingsQueryRequestor, + "document-embeddings": DocumentEmbeddingsQueryRequestor, "triples": TriplesQueryRequestor, - "objects": ObjectsQueryRequestor, + "rows": RowsQueryRequestor, "nlp-query": NLPQueryRequestor, "structured-query": StructuredQueryRequestor, "structured-diag": StructuredDiagRequestor, + "row-embeddings": RowEmbeddingsQueryRequestor, } global_dispatchers = { @@ -87,7 +91,7 @@ import_dispatchers = { "graph-embeddings": GraphEmbeddingsImport, "document-embeddings": DocumentEmbeddingsImport, "entity-contexts": EntityContextsImport, - "objects": ObjectsImport, + "rows": RowsImport, } class DispatcherWrapper: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py new file mode 100644 index 00000000..8b139fc2 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py @@ -0,0 +1,31 @@ + +from ... 
schema import RowEmbeddingsRequest, RowEmbeddingsResponse +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + +class RowEmbeddingsQueryRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(RowEmbeddingsQueryRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=RowEmbeddingsRequest, + response_schema=RowEmbeddingsResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("row-embeddings-query") + self.response_translator = TranslatorRegistry.get_response_translator("row-embeddings-query") + + def to_request(self, body): + return self.request_translator.to_pulsar(body) + + def from_response(self, message): + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py similarity index 97% rename from trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py rename to trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index fc982b69..6606dc1a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -12,7 +12,7 @@ from . serialize import to_subgraph # Module logger logger = logging.getLogger(__name__) -class ObjectsImport: +class RowsImport: def __init__( self, ws, running, backend, queue @@ -20,7 +20,7 @@ class ObjectsImport: self.ws = ws self.running = running - + self.publisher = Publisher( backend, topic = queue, schema = ExtractedObject ) @@ -73,4 +73,4 @@ class ObjectsImport: if self.ws: await self.ws.close() - self.ws = None \ No newline at end of file + self.ws = None diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py similarity index 69% rename from trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py rename to trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py index fb8dc81d..57435be8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py @@ -1,30 +1,30 @@ -from ... schema import ObjectsQueryRequest, ObjectsQueryResponse +from ... schema import RowsQueryRequest, RowsQueryResponse from ... messaging import TranslatorRegistry from . 
requestor import ServiceRequestor -class ObjectsQueryRequestor(ServiceRequestor): +class RowsQueryRequestor(ServiceRequestor): def __init__( self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): - super(ObjectsQueryRequestor, self).__init__( + super(RowsQueryRequestor, self).__init__( backend=backend, request_queue=request_queue, response_queue=response_queue, - request_schema=ObjectsQueryRequest, - response_schema=ObjectsQueryResponse, + request_schema=RowsQueryRequest, + response_schema=RowsQueryResponse, subscription = subscriber, consumer_name = consumer, timeout=timeout, ) - self.request_translator = TranslatorRegistry.get_request_translator("objects-query") - self.response_translator = TranslatorRegistry.get_response_translator("objects-query") + self.request_translator = TranslatorRegistry.get_request_translator("rows-query") + self.response_translator = TranslatorRegistry.get_response_translator("rows-query") def to_request(self, body): return self.request_translator.to_pulsar(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ No newline at end of file + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 653ecfd9..8f1cdece 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -1,46 +1,37 @@ import base64 -from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata +from ... schema import Term, Triple, DocumentMetadata, ProcessingMetadata +from ... messaging.translators.primitives import TermTranslator, TripleTranslator + +# Singleton translator instances +_term_translator = TermTranslator() +_triple_translator = TripleTranslator() -# DEPRECATED: These functions have been moved to trustgraph.... messaging.translators -# Use the new messaging translation system instead for consistency and reusability. -# Examples: -# from trustgraph.... messaging.translators.primitives import ValueTranslator -# value_translator = ValueTranslator() -# pulsar_value = value_translator.to_pulsar({"v": "example", "e": True}) def to_value(x): - return Value(value=x["v"], is_uri=x["e"]) + """Convert dict to Term. Delegates to TermTranslator.""" + return _term_translator.to_pulsar(x) + def to_subgraph(x): - return [ - Triple( - s=to_value(t["s"]), - p=to_value(t["p"]), - o=to_value(t["o"]) - ) - for t in x - ] + """Convert list of dicts to list of Triples. Delegates to TripleTranslator.""" + return [_triple_translator.to_pulsar(t) for t in x] + def serialize_value(v): - return { - "v": v.value, - "e": v.is_uri, - } + """Convert Term to dict. Delegates to TermTranslator.""" + return _term_translator.from_pulsar(v) + def serialize_triple(t): - return { - "s": serialize_value(t.s), - "p": serialize_value(t.p), - "o": serialize_value(t.o) - } + """Convert Triple to dict. 
Delegates to TripleTranslator.""" + return _triple_translator.from_pulsar(t) + def serialize_subgraph(sg): - return [ - serialize_triple(t) - for t in sg - ] + """Convert list of Triples to list of dicts.""" + return [serialize_triple(t) for t in sg] def serialize_triples(message): return { diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py index 614c1362..4e3db7f9 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py @@ -75,6 +75,7 @@ class Processor(LlmService): if stream: data["stream"] = True + data["stream_options"] = {"include_usage": True} body = json.dumps(data) @@ -191,6 +192,9 @@ class Processor(LlmService): if response.status_code != 200: raise RuntimeError("LLM failure") + total_input_tokens = 0 + total_output_tokens = 0 + # Parse SSE stream for line in response.iter_lines(): if line: @@ -215,15 +219,21 @@ class Processor(LlmService): model=model_name, is_final=False ) + + # Capture usage from final chunk + if 'usage' in chunk_data and chunk_data['usage']: + total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0) + total_output_tokens = chunk_data['usage'].get('completion_tokens', 0) + except json.JSONDecodeError: logger.warning(f"Failed to parse chunk: {data}") continue - # Send final chunk + # Send final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py index 950c006a..4ab0b302 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -161,9 +161,13 @@ class Processor(LlmService): temperature=effective_temperature, max_tokens=self.max_output, top_p=1, - stream=True # Enable streaming + stream=True, + stream_options={"include_usage": True} ) + total_input_tokens = 0 + total_output_tokens = 0 + # Stream chunks for chunk in response: if chunk.choices and chunk.choices[0].delta.content: @@ -175,11 +179,16 @@ class Processor(LlmService): is_final=False ) - # Send final chunk + # Capture usage from final chunk + if chunk.usage: + total_input_tokens = chunk.usage.prompt_tokens + total_output_tokens = chunk.usage.completion_tokens + + # Send final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py index 801ed067..276727b5 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py @@ -126,9 +126,13 @@ class Processor(LlmService): frequency_penalty=0, presence_penalty=0, response_format={"type": "text"}, - stream=True + stream=True, + stream_options={"include_usage": True} ) + total_input_tokens = 0 + total_output_tokens = 0 + for chunk in response: if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( @@ -139,10 +143,15 @@ class Processor(LlmService): is_final=False ) + # Capture usage from final chunk + if chunk.usage: + total_input_tokens = 
chunk.usage.prompt_tokens + total_output_tokens = chunk.usage.completion_tokens + yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py index 555d5c94..b057f58d 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py @@ -40,7 +40,7 @@ class Processor(LlmService): ) self.default_model = model - self.url = url + "v1/" + self.url = url.rstrip('/') + "/v1/" self.temperature = temperature self.max_output = max_output self.openai = OpenAI( @@ -130,9 +130,13 @@ class Processor(LlmService): frequency_penalty=0, presence_penalty=0, response_format={"type": "text"}, - stream=True + stream=True, + stream_options={"include_usage": True} ) + total_input_tokens = 0 + total_output_tokens = 0 + for chunk in response: if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( @@ -143,10 +147,15 @@ class Processor(LlmService): is_final=False ) + # Capture usage from final chunk + if chunk.usage: + total_input_tokens = chunk.usage.prompt_tokens + total_output_tokens = chunk.usage.completion_tokens + yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py index 7952b1df..fab41ecd 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py @@ -156,6 +156,9 @@ class Processor(LlmService): response_format={"type": "text"} ) + total_input_tokens = 0 + total_output_tokens = 0 + for chunk in stream: if chunk.data.choices and chunk.data.choices[0].delta.content: yield LlmChunk( @@ -166,11 +169,16 @@ class Processor(LlmService): is_final=False ) - # Send final chunk + # Capture usage data when available (typically in final chunk) + if chunk.data.usage: + total_input_tokens = chunk.data.usage.prompt_tokens + total_output_tokens = chunk.data.usage.completion_tokens + + # Send final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 4da1378b..d65e27bf 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -153,9 +153,13 @@ class Processor(LlmService): ], temperature=effective_temperature, max_tokens=self.max_output, - stream=True # Enable streaming + stream=True, + stream_options={"include_usage": True} ) + total_input_tokens = 0 + total_output_tokens = 0 + # Stream chunks for chunk in response: if chunk.choices and chunk.choices[0].delta.content: @@ -167,12 +171,16 @@ class Processor(LlmService): is_final=False ) - # Note: OpenAI doesn't provide token counts in streaming mode - # Send final chunk without token counts + # Capture usage from final chunk + if chunk.usage: + total_input_tokens = chunk.usage.prompt_tokens + total_output_tokens = chunk.usage.completion_tokens + + # Send 
final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py b/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py index 63f8dbc4..5caeb9be 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py @@ -83,7 +83,7 @@ class Processor(LlmService): try: - url = f"{self.base_url}/chat/completions" + url = f"{self.base_url.rstrip('/')}/chat/completions" async with self.session.post( url, @@ -152,10 +152,14 @@ class Processor(LlmService): "max_tokens": self.max_output, "temperature": effective_temperature, "stream": True, + "stream_options": {"include_usage": True}, } try: - url = f"{self.base_url}/chat/completions" + url = f"{self.base_url.rstrip('/')}/chat/completions" + + total_input_tokens = 0 + total_output_tokens = 0 async with self.session.post( url, @@ -196,15 +200,21 @@ class Processor(LlmService): model=model_name, is_final=False ) + + # Capture usage from final chunk + if 'usage' in chunk_data and chunk_data['usage']: + total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0) + total_output_tokens = chunk_data['usage'].get('completion_tokens', 0) + except json.JSONDecodeError: logger.warning(f"Failed to parse chunk: {data}") continue - # Send final chunk + # Send final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py index af27830c..2dd4576e 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py @@ -75,7 +75,7 @@ class Processor(LlmService): try: - url = f"{self.base_url}/completions" + url = f"{self.base_url.rstrip('/')}/completions" async with self.session.post( url, @@ -135,10 +135,14 @@ class Processor(LlmService): "max_tokens": self.max_output, "temperature": effective_temperature, "stream": True, + "stream_options": {"include_usage": True}, } try: - url = f"{self.base_url}/completions" + url = f"{self.base_url.rstrip('/')}/completions" + + total_input_tokens = 0 + total_output_tokens = 0 async with self.session.post( url, @@ -177,15 +181,21 @@ class Processor(LlmService): model=model_name, is_final=False ) + + # Capture usage from final chunk + if 'usage' in chunk_data and chunk_data['usage']: + total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0) + total_output_tokens = chunk_data['usage'].get('completion_tokens', 0) + except json.JSONDecodeError: logger.warning(f"Failed to parse chunk: {data}") continue - # Send final chunk + # Send final chunk with token counts yield LlmChunk( text="", - in_token=None, - out_token=None, + in_token=total_input_tokens, + out_token=total_output_tokens, model=model_name, is_final=True ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 2915184c..03c98ad3 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -8,13 +8,13 @@ import logging from .... 
direct.milvus_doc_embeddings import DocVectors from .... schema import DocumentEmbeddingsResponse -from .... schema import Error, Value +from .... schema import Error from .... base import DocumentEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "de-query" +default_ident = "doc-embeddings-query" default_store_uri = 'http://localhost:19530' class Processor(DocumentEmbeddingsQueryService): diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index f0d66021..1c3f8d1b 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -16,7 +16,7 @@ from .... base import DocumentEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "de-query" +default_ident = "doc-embeddings-query" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") class Processor(DocumentEmbeddingsQueryService): diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 46e9e687..e84372cb 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams from .... schema import DocumentEmbeddingsResponse -from .... schema import Error, Value +from .... schema import Error from .... base import DocumentEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "de-query" +default_ident = "doc-embeddings-query" default_store_uri = 'http://localhost:6333' diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index cb9255c2..c5cdb6d8 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -8,13 +8,13 @@ import logging from .... direct.milvus_graph_embeddings import EntityVectors from .... schema import GraphEmbeddingsResponse -from .... schema import Error, Value +from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-query" +default_ident = "graph-embeddings-query" default_store_uri = 'http://localhost:19530' class Processor(GraphEmbeddingsQueryService): @@ -33,9 +33,9 @@ class Processor(GraphEmbeddingsQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_graph_embeddings(self, msg): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index f6277e4f..5882f21c 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -12,13 +12,13 @@ from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig from .... 
schema import GraphEmbeddingsResponse -from .... schema import Error, Value +from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-query" +default_ident = "graph-embeddings-query" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") class Processor(GraphEmbeddingsQueryService): @@ -51,9 +51,9 @@ class Processor(GraphEmbeddingsQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_graph_embeddings(self, msg): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 513fd2e4..a76059ef 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams from .... schema import GraphEmbeddingsResponse -from .... schema import Error, Value +from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-query" +default_ident = "graph-embeddings-query" default_store_uri = 'http://localhost:6333' @@ -67,9 +67,9 @@ class Processor(GraphEmbeddingsQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_graph_embeddings(self, msg): diff --git a/trustgraph-flow/trustgraph/query/graphql/__init__.py b/trustgraph-flow/trustgraph/query/graphql/__init__.py new file mode 100644 index 00000000..32dc6a97 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/__init__.py @@ -0,0 +1,22 @@ +""" +Shared GraphQL utilities for row query services. + +This module provides reusable GraphQL components including: +- Filter types (IntFilter, StringFilter, FloatFilter) +- Dynamic schema generation from RowSchema definitions +- Filter parsing utilities +""" + +from .types import IntFilter, StringFilter, FloatFilter, SortDirection +from .schema import GraphQLSchemaBuilder +from .filters import parse_filter_key, parse_where_clause + +__all__ = [ + "IntFilter", + "StringFilter", + "FloatFilter", + "SortDirection", + "GraphQLSchemaBuilder", + "parse_filter_key", + "parse_where_clause", +] diff --git a/trustgraph-flow/trustgraph/query/graphql/filters.py b/trustgraph-flow/trustgraph/query/graphql/filters.py new file mode 100644 index 00000000..7788e20d --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/filters.py @@ -0,0 +1,104 @@ +""" +Filter parsing utilities for GraphQL row queries. + +Provides functions to parse GraphQL filter objects into a normalized +format that can be used by different query backends. +""" + +import logging +from typing import Dict, Any, Tuple + +logger = logging.getLogger(__name__) + + +def parse_filter_key(filter_key: str) -> Tuple[str, str]: + """ + Parse GraphQL filter key into field name and operator. 
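+
+    Note: a field whose own name happens to end in an operator suffix
+    (for example, a column literally named "opt_in") will be parsed as
+    ("opt", "in"), so avoid such names on filterable fields.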
+ + Supports common GraphQL filter patterns: + - field_name -> (field_name, "eq") + - field_name_gt -> (field_name, "gt") + - field_name_gte -> (field_name, "gte") + - field_name_lt -> (field_name, "lt") + - field_name_lte -> (field_name, "lte") + - field_name_in -> (field_name, "in") + + Args: + filter_key: The filter key string from GraphQL + + Returns: + Tuple of (field_name, operator) + """ + if not filter_key: + return ("", "eq") + + operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"] + + for op_suffix in operators: + if filter_key.endswith(op_suffix): + field_name = filter_key[:-len(op_suffix)] + operator = op_suffix[1:] # Remove the leading underscore + return (field_name, operator) + + # Default to equality if no operator suffix found + return (filter_key, "eq") + + +def parse_where_clause(where_obj) -> Dict[str, Any]: + """ + Parse the idiomatic nested GraphQL filter structure into a flat dict. + + Converts Strawberry filter objects (StringFilter, IntFilter, etc.) + into a dictionary mapping field names with operators to values. + + Example: + Input: where_obj with email.eq = "foo@bar.com" + Output: {"email": "foo@bar.com"} + + Input: where_obj with age.gt = 21 + Output: {"age_gt": 21} + + Args: + where_obj: The GraphQL where clause object + + Returns: + Dictionary mapping field_operator keys to values + """ + if not where_obj: + return {} + + conditions = {} + + logger.debug(f"Parsing where clause: {where_obj}") + + for field_name, filter_obj in where_obj.__dict__.items(): + if filter_obj is None: + continue + + logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}") + + if hasattr(filter_obj, '__dict__'): + # This is a filter object (StringFilter, IntFilter, etc.) + for operator, value in filter_obj.__dict__.items(): + if value is not None: + logger.debug(f"Found operator {operator} with value {value}") + # Map GraphQL operators to our internal format + if operator == "eq": + conditions[field_name] = value + elif operator in ["gt", "gte", "lt", "lte"]: + conditions[f"{field_name}_{operator}"] = value + elif operator == "in_": + conditions[f"{field_name}_in"] = value + elif operator == "contains": + conditions[f"{field_name}_contains"] = value + elif operator == "startsWith": + conditions[f"{field_name}_startsWith"] = value + elif operator == "endsWith": + conditions[f"{field_name}_endsWith"] = value + elif operator == "not_": + conditions[f"{field_name}_not"] = value + elif operator == "not_in": + conditions[f"{field_name}_not_in"] = value + + logger.debug(f"Final parsed conditions: {conditions}") + return conditions diff --git a/trustgraph-flow/trustgraph/query/graphql/schema.py b/trustgraph-flow/trustgraph/query/graphql/schema.py new file mode 100644 index 00000000..0c97b1d9 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/schema.py @@ -0,0 +1,251 @@ +""" +Dynamic GraphQL schema generation from RowSchema definitions. + +Provides a builder class that creates Strawberry GraphQL schemas +from TrustGraph RowSchema definitions, with pluggable query backends. 
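+
+For example, a hypothetical schema named "person" would be exposed as a
+"person" query field accepting where, order_by, direction, and limit
+arguments and returning a list of generated PersonType objects.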
+""" + +import logging +from typing import Dict, Any, Optional, List, Callable, Awaitable + +import strawberry +from strawberry import Schema +from strawberry.types import Info + +from .types import IntFilter, StringFilter, FloatFilter, SortDirection + +logger = logging.getLogger(__name__) + +# Type alias for query callback function +QueryCallback = Callable[ + [str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]], + Awaitable[List[Dict[str, Any]]] +] + + +class GraphQLSchemaBuilder: + """ + Builds GraphQL schemas from RowSchema definitions. + + This class extracts the GraphQL schema generation logic so it can be + reused across different query backends (Cassandra, etc.). + + Usage: + builder = GraphQLSchemaBuilder() + + # Add schemas + for name, row_schema in schemas.items(): + builder.add_schema(name, row_schema) + + # Build with a query callback + schema = builder.build(query_callback) + """ + + def __init__(self): + self.schemas: Dict[str, Any] = {} # name -> RowSchema + self.graphql_types: Dict[str, type] = {} + self.filter_types: Dict[str, type] = {} + + def add_schema(self, name: str, row_schema) -> None: + """ + Add a RowSchema to the builder. + + Args: + name: The schema name (used as the GraphQL query field name) + row_schema: The RowSchema object defining fields + """ + self.schemas[name] = row_schema + self.graphql_types[name] = self._create_graphql_type(name, row_schema) + self.filter_types[name] = self._create_filter_type(name, row_schema) + logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields") + + def clear(self) -> None: + """Clear all schemas from the builder.""" + self.schemas = {} + self.graphql_types = {} + self.filter_types = {} + + def build(self, query_callback: QueryCallback) -> Optional[Schema]: + """ + Build the GraphQL schema with the provided query callback. + + The query callback will be invoked when resolving queries, with: + - user: str + - collection: str + - schema_name: str + - row_schema: RowSchema + - filters: Dict[str, Any] + - limit: int + - order_by: Optional[str] + - direction: Optional[SortDirection] + + It should return a list of row dictionaries. 
+ + Args: + query_callback: Async function to execute queries + + Returns: + Strawberry Schema, or None if no schemas are loaded + """ + if not self.schemas: + logger.warning("No schemas loaded, cannot generate GraphQL schema") + return None + + # Create the Query class with resolvers + query_dict = {'__annotations__': {}} + + for schema_name, row_schema in self.schemas.items(): + graphql_type = self.graphql_types[schema_name] + filter_type = self.filter_types[schema_name] + + # Create resolver function for this schema + resolver_func = self._make_resolver( + schema_name, row_schema, graphql_type, filter_type, query_callback + ) + + # Add field to query dictionary + query_dict[schema_name] = strawberry.field(resolver=resolver_func) + query_dict['__annotations__'][schema_name] = List[graphql_type] + + # Create the Query class + Query = type('Query', (), query_dict) + Query = strawberry.type(Query) + + # Create the schema with auto_camel_case disabled to keep snake_case field names + schema = strawberry.Schema( + query=Query, + config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False) + ) + logger.info(f"Generated GraphQL schema with {len(self.schemas)} types") + return schema + + def _get_python_type(self, field_type: str): + """Convert schema field type to Python type for GraphQL.""" + type_mapping = { + "string": str, + "integer": int, + "float": float, + "boolean": bool, + "timestamp": str, # Use string for timestamps in GraphQL + "date": str, + "time": str, + "uuid": str + } + return type_mapping.get(field_type, str) + + def _create_graphql_type(self, schema_name: str, row_schema) -> type: + """Create a GraphQL output type from a RowSchema.""" + # Create annotations for the GraphQL type + annotations = {} + defaults = {} + + for field in row_schema.fields: + python_type = self._get_python_type(field.type) + + # Make field optional if not required + if not field.required and not field.primary: + annotations[field.name] = Optional[python_type] + defaults[field.name] = None + else: + annotations[field.name] = python_type + + # Create the class dynamically + type_name = f"{schema_name.capitalize()}Type" + graphql_class = type( + type_name, + (), + { + "__annotations__": annotations, + **defaults + } + ) + + # Apply strawberry decorator + return strawberry.type(graphql_class) + + def _create_filter_type(self, schema_name: str, row_schema) -> type: + """Create a dynamic filter input type for a schema.""" + filter_type_name = f"{schema_name.capitalize()}Filter" + + # Add __annotations__ and defaults for the fields + annotations = {} + defaults = {} + + logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}") + + for field in row_schema.fields: + logger.debug( + f"Field {field.name}: type={field.type}, " + f"indexed={field.indexed}, primary={field.primary}" + ) + + # Allow filtering on any field + if field.type == "integer": + annotations[field.name] = Optional[IntFilter] + defaults[field.name] = None + elif field.type == "float": + annotations[field.name] = Optional[FloatFilter] + defaults[field.name] = None + elif field.type == "string": + annotations[field.name] = Optional[StringFilter] + defaults[field.name] = None + + logger.debug( + f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}" + ) + + # Create the class dynamically + FilterType = type( + filter_type_name, + (), + { + "__annotations__": annotations, + **defaults + } + ) + + # Apply strawberry input decorator + FilterType = strawberry.input(FilterType) + + return 
FilterType + + def _make_resolver( + self, + schema_name: str, + row_schema, + graphql_type: type, + filter_type: type, + query_callback: QueryCallback + ): + """Create a resolver function for a schema.""" + from .filters import parse_where_clause + + async def resolver( + info: Info, + where: Optional[filter_type] = None, + order_by: Optional[str] = None, + direction: Optional[SortDirection] = None, + limit: Optional[int] = 100 + ) -> List[graphql_type]: + # Get context values + user = info.context["user"] + collection = info.context["collection"] + + # Parse the where clause + filters = parse_where_clause(where) + + # Call the query backend + results = await query_callback( + user, collection, schema_name, row_schema, + filters, limit, order_by, direction + ) + + # Convert to GraphQL types + graphql_results = [] + for row in results: + graphql_obj = graphql_type(**row) + graphql_results.append(graphql_obj) + + return graphql_results + + return resolver diff --git a/trustgraph-flow/trustgraph/query/graphql/types.py b/trustgraph-flow/trustgraph/query/graphql/types.py new file mode 100644 index 00000000..4d288bb6 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/types.py @@ -0,0 +1,56 @@ +""" +GraphQL filter and sort types for row queries. + +These types are used to build dynamic GraphQL schemas for querying +structured row data. +""" + +from typing import Optional, List +from enum import Enum + +import strawberry + + +@strawberry.input +class IntFilter: + """Filter type for integer fields.""" + eq: Optional[int] = None + gt: Optional[int] = None + gte: Optional[int] = None + lt: Optional[int] = None + lte: Optional[int] = None + in_: Optional[List[int]] = strawberry.field(name="in", default=None) + not_: Optional[int] = strawberry.field(name="not", default=None) + not_in: Optional[List[int]] = None + + +@strawberry.input +class StringFilter: + """Filter type for string fields.""" + eq: Optional[str] = None + contains: Optional[str] = None + startsWith: Optional[str] = None + endsWith: Optional[str] = None + in_: Optional[List[str]] = strawberry.field(name="in", default=None) + not_: Optional[str] = strawberry.field(name="not", default=None) + not_in: Optional[List[str]] = None + + +@strawberry.input +class FloatFilter: + """Filter type for float fields.""" + eq: Optional[float] = None + gt: Optional[float] = None + gte: Optional[float] = None + lt: Optional[float] = None + lte: Optional[float] = None + in_: Optional[List[float]] = strawberry.field(name="in", default=None) + not_: Optional[float] = strawberry.field(name="not", default=None) + not_in: Optional[List[float]] = None + + +@strawberry.enum +class SortDirection(Enum): + """Sort direction for query results.""" + ASC = "asc" + DESC = "desc" diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py b/trustgraph-flow/trustgraph/query/objects/cassandra/service.py deleted file mode 100644 index a6683c40..00000000 --- a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py +++ /dev/null @@ -1,738 +0,0 @@ -""" -Objects query service using GraphQL. Input is a GraphQL query with variables. -Output is GraphQL response data with any errors. 
-""" - -import json -import logging -import asyncio -from typing import Dict, Any, Optional, List, Set -from enum import Enum -from dataclasses import dataclass, field -from cassandra.cluster import Cluster -from cassandra.auth import PlainTextAuthProvider - -import strawberry -from strawberry import Schema -from strawberry.types import Info -from strawberry.scalars import JSON -from strawberry.tools import create_type - -from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError -from .... schema import Error, RowSchema, Field as SchemaField -from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config - -# Module logger -logger = logging.getLogger(__name__) - -default_ident = "objects-query" - -# GraphQL filter input types -@strawberry.input -class IntFilter: - eq: Optional[int] = None - gt: Optional[int] = None - gte: Optional[int] = None - lt: Optional[int] = None - lte: Optional[int] = None - in_: Optional[List[int]] = strawberry.field(name="in", default=None) - not_: Optional[int] = strawberry.field(name="not", default=None) - not_in: Optional[List[int]] = None - -@strawberry.input -class StringFilter: - eq: Optional[str] = None - contains: Optional[str] = None - startsWith: Optional[str] = None - endsWith: Optional[str] = None - in_: Optional[List[str]] = strawberry.field(name="in", default=None) - not_: Optional[str] = strawberry.field(name="not", default=None) - not_in: Optional[List[str]] = None - -@strawberry.input -class FloatFilter: - eq: Optional[float] = None - gt: Optional[float] = None - gte: Optional[float] = None - lt: Optional[float] = None - lte: Optional[float] = None - in_: Optional[List[float]] = strawberry.field(name="in", default=None) - not_: Optional[float] = strawberry.field(name="not", default=None) - not_in: Optional[List[float]] = None - - -class Processor(FlowProcessor): - - def __init__(self, **params): - - id = params.get("id", default_ident) - - # Get Cassandra parameters - cassandra_host = params.get("cassandra_host") - cassandra_username = params.get("cassandra_username") - cassandra_password = params.get("cassandra_password") - - # Resolve configuration with environment variable fallback - hosts, username, password, keyspace = resolve_cassandra_config( - host=cassandra_host, - username=cassandra_username, - password=cassandra_password - ) - - # Store resolved configuration with proper names - self.cassandra_host = hosts # Store as list - self.cassandra_username = username - self.cassandra_password = password - - # Config key for schemas - self.config_key = params.get("config_type", "schema") - - super(Processor, self).__init__( - **params | { - "id": id, - "config_type": self.config_key, - } - ) - - self.register_specification( - ConsumerSpec( - name = "request", - schema = ObjectsQueryRequest, - handler = self.on_message - ) - ) - - self.register_specification( - ProducerSpec( - name = "response", - schema = ObjectsQueryResponse, - ) - ) - - # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) - - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - - # GraphQL schema - self.graphql_schema: Optional[Schema] = None - - # GraphQL types cache - self.graphql_types: Dict[str, type] = {} - - # Cassandra session - self.cluster = None - self.session = None - - # Known keyspaces and tables - self.known_keyspaces: Set[str] = set() - self.known_tables: Dict[str, Set[str]] = {} - - 
def connect_cassandra(self): - """Connect to Cassandra cluster""" - if self.session: - return - - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) - - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise - - def sanitize_name(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - if safe_name and not safe_name[0].isalpha(): - safe_name = 'o_' + safe_name - return safe_name.lower() - - def sanitize_table(self, name: str) -> str: - """Sanitize table names for Cassandra compatibility""" - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - safe_name = 'o_' + safe_name - return safe_name.lower() - - def parse_filter_key(self, filter_key: str) -> tuple[str, str]: - """Parse GraphQL filter key into field name and operator""" - if not filter_key: - return ("", "eq") - - # Support common GraphQL filter patterns: - # field_name -> (field_name, "eq") - # field_name_gt -> (field_name, "gt") - # field_name_gte -> (field_name, "gte") - # field_name_lt -> (field_name, "lt") - # field_name_lte -> (field_name, "lte") - # field_name_in -> (field_name, "in") - - operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"] - - for op_suffix in operators: - if filter_key.endswith(op_suffix): - field_name = filter_key[:-len(op_suffix)] - operator = op_suffix[1:] # Remove the leading underscore - return (field_name, operator) - - # Default to equality if no operator suffix found - return (filter_key, "eq") - - async def on_schema_config(self, config, version): - """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - self.graphql_types = {} - - # Check if our config type exists - if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") - return - - # Get the schemas dictionary for our type - schemas_config = config[self.config_key] - - # Process each schema in the schemas config - for schema_name, schema_json in schemas_config.items(): - try: - # Parse the JSON schema definition - schema_def = json.loads(schema_json) - - # Create Field objects - fields = [] - for field_def in schema_def.get("fields", []): - field = SchemaField( - name=field_def["name"], - type=field_def["type"], - size=field_def.get("size", 0), - primary=field_def.get("primary_key", False), - description=field_def.get("description", ""), - required=field_def.get("required", False), - enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) - ) - fields.append(field) - - # Create RowSchema - row_schema = RowSchema( - name=schema_def.get("name", schema_name), - description=schema_def.get("description", ""), - fields=fields - ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - - except Exception as e: - logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") - - # Regenerate 
GraphQL schema - self.generate_graphql_schema() - - def get_python_type(self, field_type: str): - """Convert schema field type to Python type for GraphQL""" - type_mapping = { - "string": str, - "integer": int, - "float": float, - "boolean": bool, - "timestamp": str, # Use string for timestamps in GraphQL - "date": str, - "time": str, - "uuid": str - } - return type_mapping.get(field_type, str) - - def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type: - """Create a GraphQL type from a RowSchema""" - - # Create annotations for the GraphQL type - annotations = {} - defaults = {} - - for field in row_schema.fields: - python_type = self.get_python_type(field.type) - - # Make field optional if not required - if not field.required and not field.primary: - annotations[field.name] = Optional[python_type] - defaults[field.name] = None - else: - annotations[field.name] = python_type - - # Create the class dynamically - type_name = f"{schema_name.capitalize()}Type" - graphql_class = type( - type_name, - (), - { - "__annotations__": annotations, - **defaults - } - ) - - # Apply strawberry decorator - return strawberry.type(graphql_class) - - def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema): - """Create a dynamic filter input type for a schema""" - # Create the filter type dynamically - filter_type_name = f"{schema_name.capitalize()}Filter" - - # Add __annotations__ and defaults for the fields - annotations = {} - defaults = {} - - logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}") - - for field in row_schema.fields: - logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}") - - # Allow filtering on any field for now, not just indexed/primary - # if field.indexed or field.primary: - if field.type == "integer": - annotations[field.name] = Optional[IntFilter] - defaults[field.name] = None - logger.info(f"Added IntFilter for {field.name}") - elif field.type == "float": - annotations[field.name] = Optional[FloatFilter] - defaults[field.name] = None - logger.info(f"Added FloatFilter for {field.name}") - elif field.type == "string": - annotations[field.name] = Optional[StringFilter] - defaults[field.name] = None - logger.info(f"Added StringFilter for {field.name}") - - logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}") - - # Create the class dynamically - FilterType = type( - filter_type_name, - (), - { - "__annotations__": annotations, - **defaults - } - ) - - # Apply strawberry input decorator - FilterType = strawberry.input(FilterType) - - return FilterType - - def create_sort_direction_enum(self): - """Create sort direction enum""" - @strawberry.enum - class SortDirection(Enum): - ASC = "asc" - DESC = "desc" - - return SortDirection - - def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]: - """Parse the idiomatic nested filter structure""" - if not where_obj: - return {} - - conditions = {} - - logger.info(f"Parsing where clause: {where_obj}") - - for field_name, filter_obj in where_obj.__dict__.items(): - if filter_obj is None: - continue - - logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}") - - if hasattr(filter_obj, '__dict__'): - # This is a filter object (StringFilter, IntFilter, etc.) 
- for operator, value in filter_obj.__dict__.items(): - if value is not None: - logger.info(f"Found operator {operator} with value {value}") - # Map GraphQL operators to our internal format - if operator == "eq": - conditions[field_name] = value - elif operator in ["gt", "gte", "lt", "lte"]: - conditions[f"{field_name}_{operator}"] = value - elif operator == "in_": - conditions[f"{field_name}_in"] = value - elif operator == "contains": - conditions[f"{field_name}_contains"] = value - - logger.info(f"Final parsed conditions: {conditions}") - return conditions - - def generate_graphql_schema(self): - """Generate GraphQL schema from loaded schemas using dynamic filter types""" - if not self.schemas: - logger.warning("No schemas loaded, cannot generate GraphQL schema") - self.graphql_schema = None - return - - # Create GraphQL types and filter types for each schema - filter_types = {} - sort_direction_enum = self.create_sort_direction_enum() - - for schema_name, row_schema in self.schemas.items(): - graphql_type = self.create_graphql_type(schema_name, row_schema) - filter_type = self.create_filter_type_for_schema(schema_name, row_schema) - - self.graphql_types[schema_name] = graphql_type - filter_types[schema_name] = filter_type - - # Create the Query class with resolvers - query_dict = {'__annotations__': {}} - - for schema_name, row_schema in self.schemas.items(): - graphql_type = self.graphql_types[schema_name] - filter_type = filter_types[schema_name] - - # Create resolver function for this schema - def make_resolver(s_name, r_schema, g_type, f_type, sort_enum): - async def resolver( - info: Info, - where: Optional[f_type] = None, - order_by: Optional[str] = None, - direction: Optional[sort_enum] = None, - limit: Optional[int] = 100 - ) -> List[g_type]: - # Get the processor instance from context - processor = info.context["processor"] - user = info.context["user"] - collection = info.context["collection"] - - # Parse the idiomatic where clause - filters = processor.parse_idiomatic_where_clause(where) - - # Query Cassandra - results = await processor.query_cassandra( - user, collection, s_name, r_schema, - filters, limit, order_by, direction - ) - - # Convert to GraphQL types - graphql_results = [] - for row in results: - graphql_obj = g_type(**row) - graphql_results.append(graphql_obj) - - return graphql_results - - return resolver - - # Add resolver to query - resolver_name = schema_name - resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum) - - # Add field to query dictionary - query_dict[resolver_name] = strawberry.field(resolver=resolver_func) - query_dict['__annotations__'][resolver_name] = List[graphql_type] - - # Create the Query class - Query = type('Query', (), query_dict) - Query = strawberry.type(Query) - - # Create the schema with auto_camel_case disabled to keep snake_case field names - self.graphql_schema = strawberry.Schema( - query=Query, - config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False) - ) - logger.info(f"Generated GraphQL schema with {len(self.schemas)} types") - - async def query_cassandra( - self, - user: str, - collection: str, - schema_name: str, - row_schema: RowSchema, - filters: Dict[str, Any], - limit: int, - order_by: Optional[str] = None, - direction: Optional[Any] = None - ) -> List[Dict[str, Any]]: - """Execute a query against Cassandra""" - - # Connect if needed - self.connect_cassandra() - - # Build the query - keyspace = self.sanitize_name(user) - table = 
self.sanitize_table(schema_name) - - # Start with basic SELECT - query = f"SELECT * FROM {keyspace}.{table}" - - # Add WHERE clauses - where_clauses = [f"collection = %s"] - params = [collection] - - # Add filters for indexed or primary key fields - for filter_key, value in filters.items(): - if value is not None: - # Parse field name and operator from filter key - logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})") - result = self.parse_filter_key(filter_key) - logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})") - - if not result or len(result) != 2: - logger.error(f"parse_filter_key returned invalid result: {result}") - continue # Skip this filter - - field_name, operator = result - - # Find the field in schema - schema_field = None - for f in row_schema.fields: - if f.name == field_name: - schema_field = f - break - - if schema_field: - safe_field = self.sanitize_name(field_name) - - # Build WHERE clause based on operator - if operator == "eq": - where_clauses.append(f"{safe_field} = %s") - params.append(value) - elif operator == "gt": - where_clauses.append(f"{safe_field} > %s") - params.append(value) - elif operator == "gte": - where_clauses.append(f"{safe_field} >= %s") - params.append(value) - elif operator == "lt": - where_clauses.append(f"{safe_field} < %s") - params.append(value) - elif operator == "lte": - where_clauses.append(f"{safe_field} <= %s") - params.append(value) - elif operator == "in": - if isinstance(value, list): - placeholders = ",".join(["%s"] * len(value)) - where_clauses.append(f"{safe_field} IN ({placeholders})") - params.extend(value) - else: - # Default to equality for unknown operators - where_clauses.append(f"{safe_field} = %s") - params.append(value) - - if where_clauses: - query += " WHERE " + " AND ".join(where_clauses) - - # Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort) - cassandra_order_by_added = False - if order_by and direction: - # Validate that order_by field exists in schema - order_field_exists = any(f.name == order_by for f in row_schema.fields) - if order_field_exists: - safe_order_field = self.sanitize_name(order_by) - direction_str = "ASC" if direction.value == "asc" else "DESC" - # Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution - query += f" ORDER BY {safe_order_field} {direction_str}" - - # Add limit first (must come before ALLOW FILTERING) - if limit: - query += f" LIMIT {limit}" - - # Add ALLOW FILTERING for now (should optimize with proper indexes later) - query += " ALLOW FILTERING" - - # Execute query - try: - result = self.session.execute(query, params) - cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY - except Exception as e: - # If ORDER BY fails, try without it - if order_by and direction and "ORDER BY" in query: - logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}") - # Remove ORDER BY clause and retry - query_parts = query.split(" ORDER BY ") - if len(query_parts) == 2: - query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING" - result = self.session.execute(query_without_order, params) - cassandra_order_by_added = False - else: - raise - else: - raise - - # Convert rows to dicts - results = [] - for row in result: - row_dict = {} - for field in row_schema.fields: - safe_field = self.sanitize_name(field.name) - if hasattr(row, 
safe_field): - value = getattr(row, safe_field) - # Use original field name in result - row_dict[field.name] = value - results.append(row_dict) - - # Post-query sorting if Cassandra didn't handle ORDER BY - if order_by and direction and not cassandra_order_by_added: - reverse_order = (direction.value == "desc") - try: - results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order) - except Exception as e: - logger.warning(f"Failed to sort results by {order_by}: {e}") - - return results - - async def execute_graphql_query( - self, - query: str, - variables: Dict[str, Any], - operation_name: Optional[str], - user: str, - collection: str - ) -> Dict[str, Any]: - """Execute a GraphQL query""" - - if not self.graphql_schema: - raise RuntimeError("No GraphQL schema available - no schemas loaded") - - # Create context for the query - context = { - "processor": self, - "user": user, - "collection": collection - } - - # Execute the query - result = await self.graphql_schema.execute( - query, - variable_values=variables, - operation_name=operation_name, - context_value=context - ) - - # Build response - response = {} - - if result.data: - response["data"] = result.data - else: - response["data"] = None - - if result.errors: - response["errors"] = [ - { - "message": str(error), - "path": getattr(error, "path", []), - "extensions": getattr(error, "extensions", {}) - } - for error in result.errors - ] - else: - response["errors"] = [] - - # Add extensions if any - if hasattr(result, "extensions") and result.extensions: - response["extensions"] = result.extensions - - return response - - async def on_message(self, msg, consumer, flow): - """Handle incoming query request""" - - try: - request = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - logger.debug(f"Handling objects query request {id}...") - - # Execute GraphQL query - result = await self.execute_graphql_query( - query=request.query, - variables=dict(request.variables) if request.variables else {}, - operation_name=request.operation_name, - user=request.user, - collection=request.collection - ) - - # Create response - graphql_errors = [] - if "errors" in result and result["errors"]: - for err in result["errors"]: - graphql_error = GraphQLError( - message=err.get("message", ""), - path=err.get("path", []), - extensions=err.get("extensions", {}) - ) - graphql_errors.append(graphql_error) - - response = ObjectsQueryResponse( - error=None, - data=json.dumps(result.get("data")) if result.get("data") else "null", - errors=graphql_errors, - extensions=result.get("extensions", {}) - ) - - logger.debug("Sending objects query response...") - await flow("response").send(response, properties={"id": id}) - - logger.debug("Objects query request completed") - - except Exception as e: - - logger.error(f"Exception in objects query service: {e}", exc_info=True) - - logger.info("Sending error response...") - - response = ObjectsQueryResponse( - error = Error( - type = "objects-query-error", - message = str(e), - ), - data = None, - errors = [], - extensions = {} - ) - - await flow("response").send(response, properties={"id": id}) - - def close(self): - """Clean up Cassandra connections""" - if self.cluster: - self.cluster.shutdown() - logger.info("Closed Cassandra connection") - - @staticmethod - def add_args(parser): - """Add command-line arguments""" - - FlowProcessor.add_args(parser) - add_cassandra_args(parser) - - parser.add_argument( - '--config-type', - default='schema', - help='Configuration type prefix for schemas (default: 
schema)' - ) - -def run(): - """Entry point for objects-query-graphql-cassandra command""" - Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py new file mode 100644 index 00000000..6c6391f5 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py @@ -0,0 +1,3 @@ +""" +Row embeddings query modules. +""" diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py new file mode 100644 index 00000000..a4ca1c85 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py @@ -0,0 +1,5 @@ +""" +Qdrant row embeddings query service. +""" + +from .service import Processor, run, default_ident diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py new file mode 100644 index 00000000..66f42e76 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py @@ -0,0 +1,4 @@ + +from .service import run + +run() diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py new file mode 100644 index 00000000..7ed6192f --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -0,0 +1,209 @@ +""" +Row embeddings query service for Qdrant. + +Input is query vectors plus user/collection/schema context. +Output is matching row index information (index_name, index_value) for +use in subsequent Cassandra lookups. +""" + +import logging +import re +from typing import Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue + +from .... schema import ( + RowEmbeddingsRequest, RowEmbeddingsResponse, + RowIndexMatch, Error +) +from .... 
base import FlowProcessor, ConsumerSpec, ProducerSpec + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings-query" +default_store_uri = 'http://localhost:6333' + + +class Processor(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) + + super(Processor, self).__init__( + **params | { + "id": id, + "store_uri": store_uri, + "api_key": api_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="request", + schema=RowEmbeddingsRequest, + handler=self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name="response", + schema=RowEmbeddingsResponse + ) + ) + + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Qdrant collection naming""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]: + """Find the Qdrant collection for a given user/collection/schema""" + prefix = ( + f"rows_{self.sanitize_name(user)}_" + f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" + ) + + try: + all_collections = self.qdrant.get_collections().collections + matching = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if matching: + # Return first match (there should typically be only one per dimension) + return matching[0] + + except Exception as e: + logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True) + + return None + + async def query_row_embeddings(self, request: RowEmbeddingsRequest): + """Execute row embeddings query""" + + matches = [] + + # Find the collection for this user/collection/schema + qdrant_collection = self.find_collection( + request.user, request.collection, request.schema_name + ) + + if not qdrant_collection: + logger.info( + f"No Qdrant collection found for " + f"{request.user}/{request.collection}/{request.schema_name}" + ) + return matches + + for vec in request.vectors: + try: + # Build optional filter for index_name + query_filter = None + if request.index_name: + query_filter = Filter( + must=[ + FieldCondition( + key="index_name", + match=MatchValue(value=request.index_name) + ) + ] + ) + + # Query Qdrant + search_result = self.qdrant.query_points( + collection_name=qdrant_collection, + query=vec, + limit=request.limit, + with_payload=True, + query_filter=query_filter, + ).points + + # Convert to RowIndexMatch objects + for point in search_result: + payload = point.payload or {} + match = RowIndexMatch( + index_name=payload.get("index_name", ""), + index_value=payload.get("index_value", []), + text=payload.get("text", ""), + score=point.score if hasattr(point, 'score') else 0.0 + ) + matches.append(match) + + except Exception as e: + logger.error(f"Failed to query Qdrant: {e}", exc_info=True) + raise + + return matches + + async def on_message(self, msg, consumer, flow): + """Handle incoming query request""" + + try: + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.debug( + f"Handling row embeddings query for " + f"{request.user}/{request.collection}/{request.schema_name}..." 
+ ) + + # Execute query + matches = await self.query_row_embeddings(request) + + response = RowEmbeddingsResponse( + error=None, + matches=matches + ) + + logger.debug(f"Returning {len(matches)} matches") + await flow("response").send(response, properties={"id": id}) + + except Exception as e: + logger.error(f"Exception in row embeddings query: {e}", exc_info=True) + + response = RowEmbeddingsResponse( + error=Error( + type="row-embeddings-query-error", + message=str(e) + ), + matches=[] + ) + + await flow("response").send(response, properties={"id": id}) + + @staticmethod + def add_args(parser): + """Add command-line arguments""" + + FlowProcessor.add_args(parser) + + parser.add_argument( + '-t', '--store-uri', + default=default_store_uri, + help=f'Qdrant store URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help='API key for Qdrant (default: None)' + ) + + +def run(): + """Entry point for row-embeddings-query-qdrant command""" + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/objects/__init__.py b/trustgraph-flow/trustgraph/query/rows/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/__init__.py rename to trustgraph-flow/trustgraph/query/rows/__init__.py diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/__init__.py b/trustgraph-flow/trustgraph/query/rows/cassandra/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/cassandra/__init__.py rename to trustgraph-flow/trustgraph/query/rows/cassandra/__init__.py diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/__main__.py b/trustgraph-flow/trustgraph/query/rows/cassandra/__main__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/cassandra/__main__.py rename to trustgraph-flow/trustgraph/query/rows/cassandra/__main__.py diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py new file mode 100644 index 00000000..3808cdb0 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -0,0 +1,523 @@ +""" +Row query service using GraphQL. Input is a GraphQL query with variables. +Output is GraphQL response data with any errors. + +Queries against the unified 'rows' table with schema: + - collection: text + - schema_name: text + - index_name: text + - index_value: frozen<list<text>> + - data: map<text, text> + - source: text +""" + +import json +import logging +import re +from typing import Dict, Any, Optional, List, Set + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider + +from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError +from .... schema import Error, RowSchema, Field as SchemaField +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config + +from ... 
graphql import GraphQLSchemaBuilder, SortDirection + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "rows-query" + + +class Processor(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + # Get Cassandra parameters + cassandra_host = params.get("cassandra_host") + cassandra_username = params.get("cassandra_username") + cassandra_password = params.get("cassandra_password") + + # Resolve configuration with environment variable fallback + hosts, username, password, keyspace = resolve_cassandra_config( + host=cassandra_host, + username=cassandra_username, + password=cassandra_password + ) + + # Store resolved configuration with proper names + self.cassandra_host = hosts # Store as list + self.cassandra_username = username + self.cassandra_password = password + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + + super(Processor, self).__init__( + **params | { + "id": id, + "config_type": self.config_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="request", + schema=RowsQueryRequest, + handler=self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name="response", + schema=RowsQueryResponse, + ) + ) + + # Register config handler for schema updates + self.register_config_handler(self.on_schema_config) + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + # GraphQL schema builder and generated schema + self.schema_builder = GraphQLSchemaBuilder() + self.graphql_schema = None + + # Cassandra session + self.cluster = None + self.session = None + + # Known keyspaces + self.known_keyspaces: Set[str] = set() + + def connect_cassandra(self): + """Connect to Cassandra cluster""" + if self.session: + return + + try: + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + self.cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + self.cluster = Cluster(contact_points=self.cassandra_host) + + self.session = self.cluster.connect() + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Cassandra compatibility""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") + + # Clear existing schemas + self.schemas = {} + self.schema_builder.clear() + + # Check if our config type exists + if self.config_key not in config: + logger.warning(f"No '{self.config_key}' type in configuration") + return + + # Get the schemas dictionary for our type + schemas_config = config[self.config_key] + + # Process each schema in the schemas config + for schema_name, schema_json in schemas_config.items(): + try: + # Parse the JSON schema definition + schema_def = json.loads(schema_json) + + # Create Field objects + fields = [] + for field_def in schema_def.get("fields", []): + field = SchemaField( + name=field_def["name"], + type=field_def["type"], + size=field_def.get("size", 0), + primary=field_def.get("primary_key", False), + 
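# NB: 'primary_key' in the stored schema JSON populates Field.primary; + # primary/indexed fields are the ones find_matching_index below + # can serve with a direct index lookup. +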
description=field_def.get("description", ""), + required=field_def.get("required", False), + enum_values=field_def.get("enum", []), + indexed=field_def.get("indexed", False) + ) + fields.append(field) + + # Create RowSchema + row_schema = RowSchema( + name=schema_def.get("name", schema_name), + description=schema_def.get("description", ""), + fields=fields + ) + + self.schemas[schema_name] = row_schema + self.schema_builder.add_schema(schema_name, row_schema) + logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + + except Exception as e: + logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) + + logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + # Regenerate GraphQL schema + self.graphql_schema = self.schema_builder.build(self.query_cassandra) + + def get_index_names(self, schema: RowSchema) -> List[str]: + """Get all index names for a schema.""" + index_names = [] + for field in schema.fields: + if field.primary or field.indexed: + index_names.append(field.name) + return index_names + + def find_matching_index( + self, + schema: RowSchema, + filters: Dict[str, Any] + ) -> Optional[tuple]: + """ + Find an index that can satisfy the query filters. + Returns (index_name, index_value) if found, None otherwise. + + For exact match queries, we need a filter on an indexed field. + """ + index_names = self.get_index_names(schema) + + # Look for an exact match filter on an indexed field + for index_name in index_names: + if index_name in filters: + value = filters[index_name] + # Single field index -> single element list + index_value = [str(value)] + return (index_name, index_value) + + return None + + async def query_cassandra( + self, + user: str, + collection: str, + schema_name: str, + row_schema: RowSchema, + filters: Dict[str, Any], + limit: int, + order_by: Optional[str] = None, + direction: Optional[SortDirection] = None + ) -> List[Dict[str, Any]]: + """ + Execute a query against the unified Cassandra rows table. + + For exact match queries on indexed fields, we can query directly. + For other queries, we need to scan and post-filter. 
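+ + For example (hypothetical values): with a schema whose 'email' field + is indexed, the filter {"email": "alice@example.com"} resolves to + index_name='email', index_value=['alice@example.com'] and is served + by a direct lookup; a filter with no indexed match, e.g. {"age_gt": 30}, + falls back to the scan path and post-filtering in _matches_filters.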
+ """ + # Connect if needed + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Try to find an index that matches the filters + index_match = self.find_matching_index(row_schema, filters) + + results = [] + + if index_match: + # Direct query using index + index_name, index_value = index_match + + query = f""" + SELECT data, source FROM {safe_keyspace}.rows + WHERE collection = %s + AND schema_name = %s + AND index_name = %s + AND index_value = %s + """ + params = [collection, schema_name, index_name, index_value] + + if limit: + query += f" LIMIT {limit}" + + try: + rows = self.session.execute(query, params) + for row in rows: + # Convert data map to dict with proper field names + row_dict = dict(row.data) if row.data else {} + results.append(row_dict) + except Exception as e: + logger.error(f"Failed to query rows: {e}", exc_info=True) + raise + + else: + # No direct index match - scan all rows for this schema + # This is less efficient but necessary for non-indexed queries + logger.warning( + f"No index match for filters {filters} - scanning all indexes" + ) + + # Get all index names for this schema + index_names = self.get_index_names(row_schema) + + if not index_names: + logger.warning(f"Schema {schema_name} has no indexes") + return [] + + # Query using the first index (arbitrary choice for scan) + primary_index = index_names[0] + + # We need to scan all values for this index + # This requires ALLOW FILTERING or a different approach + query = f""" + SELECT data, source FROM {safe_keyspace}.rows + WHERE collection = %s + AND schema_name = %s + AND index_name = %s + ALLOW FILTERING + """ + params = [collection, schema_name, primary_index] + + try: + rows = self.session.execute(query, params) + + for row in rows: + row_dict = dict(row.data) if row.data else {} + + # Apply post-filters + if self._matches_filters(row_dict, filters, row_schema): + results.append(row_dict) + + if limit and len(results) >= limit: + break + + except Exception as e: + logger.error(f"Failed to scan rows: {e}", exc_info=True) + raise + + # Post-query sorting if requested + if order_by and results: + reverse_order = direction and direction.value == "desc" + try: + results.sort( + key=lambda x: x.get(order_by, ""), + reverse=reverse_order + ) + except Exception as e: + logger.warning(f"Failed to sort results by {order_by}: {e}") + + return results + + def _matches_filters( + self, + row_dict: Dict[str, Any], + filters: Dict[str, Any], + row_schema: RowSchema + ) -> bool: + """Check if a row matches the given filters.""" + for filter_key, filter_value in filters.items(): + if filter_value is None: + continue + + # Parse filter key for operator + if '_' in filter_key: + parts = filter_key.rsplit('_', 1) + if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']: + field_name = parts[0] + operator = parts[1] + else: + field_name = filter_key + operator = 'eq' + else: + field_name = filter_key + operator = 'eq' + + row_value = row_dict.get(field_name) + if row_value is None: + return False + + # Convert types for comparison + try: + if operator == 'eq': + if str(row_value) != str(filter_value): + return False + elif operator == 'gt': + if float(row_value) <= float(filter_value): + return False + elif operator == 'gte': + if float(row_value) < float(filter_value): + return False + elif operator == 'lt': + if float(row_value) >= float(filter_value): + return False + elif operator == 'lte': + if float(row_value) > float(filter_value): + return False + elif operator == 'contains': + if 
str(filter_value) not in str(row_value): + return False + elif operator == 'in': + if str(row_value) not in [str(v) for v in filter_value]: + return False + except (ValueError, TypeError): + return False + + return True + + async def execute_graphql_query( + self, + query: str, + variables: Dict[str, Any], + operation_name: Optional[str], + user: str, + collection: str + ) -> Dict[str, Any]: + """Execute a GraphQL query""" + + if not self.graphql_schema: + raise RuntimeError("No GraphQL schema available - no schemas loaded") + + # Create context for the query + context = { + "processor": self, + "user": user, + "collection": collection + } + + # Execute the query + result = await self.graphql_schema.execute( + query, + variable_values=variables, + operation_name=operation_name, + context_value=context + ) + + # Build response + response = {} + + if result.data: + response["data"] = result.data + else: + response["data"] = None + + if result.errors: + response["errors"] = [ + { + "message": str(error), + "path": getattr(error, "path", []), + "extensions": getattr(error, "extensions", {}) + } + for error in result.errors + ] + else: + response["errors"] = [] + + # Add extensions if any + if hasattr(result, "extensions") and result.extensions: + response["extensions"] = result.extensions + + return response + + async def on_message(self, msg, consumer, flow): + """Handle incoming query request""" + + try: + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.debug(f"Handling rows query request {id}...") + + # Execute GraphQL query + result = await self.execute_graphql_query( + query=request.query, + variables=dict(request.variables) if request.variables else {}, + operation_name=request.operation_name, + user=request.user, + collection=request.collection + ) + + # Create response + graphql_errors = [] + if "errors" in result and result["errors"]: + for err in result["errors"]: + graphql_error = GraphQLError( + message=err.get("message", ""), + path=err.get("path", []), + extensions=err.get("extensions", {}) + ) + graphql_errors.append(graphql_error) + + response = RowsQueryResponse( + error=None, + data=json.dumps(result.get("data")) if result.get("data") else "null", + errors=graphql_errors, + extensions=result.get("extensions", {}) + ) + + logger.debug("Sending rows query response...") + await flow("response").send(response, properties={"id": id}) + + logger.debug("Rows query request completed") + + except Exception as e: + + logger.error(f"Exception in rows query service: {e}", exc_info=True) + + logger.info("Sending error response...") + + response = RowsQueryResponse( + error=Error( + type="rows-query-error", + message=str(e), + ), + data=None, + errors=[], + extensions={} + ) + + await flow("response").send(response, properties={"id": id}) + + def close(self): + """Clean up Cassandra connections""" + if self.cluster: + self.cluster.shutdown() + logger.info("Closed Cassandra connection") + + @staticmethod + def add_args(parser): + """Add command-line arguments""" + + FlowProcessor.add_args(parser) + add_cassandra_args(parser) + + parser.add_argument( + '--config-type', + default='schema', + help='Configuration type prefix for schemas (default: schema)' + ) + + +def run(): + """Entry point for rows-query-cassandra command""" + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 13726ac3..eac33dde 100755 --- 
a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -1,14 +1,16 @@ """ -Triples query service. Input is a (s, p, o) triple, some values may be -null. Output is a list of triples. +Triples query service. Input is a (s, p, o, g) quad pattern, some values may be +null. Output is a list of quads. """ import logging -from .... direct.cassandra_kg import KnowledgeGraph +from .... direct.cassandra_kg import ( + EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH +) from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Value, Triple +from .... schema import Term, Triple, IRI, LITERAL from .... base import TriplesQueryService from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config @@ -18,6 +20,56 @@ logger = logging.getLogger(__name__) default_ident = "triples-query" +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + +def create_term(value, otype=None, dtype=None, lang=None): + """ + Create a Term from a string value, optionally using type metadata. + + Args: + value: The string value + otype: Object type - 'u' (URI), 'l' (literal), 't' (triple) + dtype: XSD datatype (for literals) + lang: Language tag (for literals) + + If otype is provided, uses it to determine Term type. + Otherwise falls back to URL detection heuristic. + """ + if otype is not None: + if otype == 'u': + return Term(type=IRI, iri=value) + elif otype == 'l': + return Term( + type=LITERAL, + value=value, + datatype=dtype or "", + language=lang or "" + ) + elif otype == 't': + # Triple/reification - treat as IRI for now + return Term(type=IRI, iri=value) + else: + # Unknown otype, fall back to heuristic + pass + + # Heuristic fallback for backwards compatibility + if value.startswith("http://") or value.startswith("https://"): + return Term(type=IRI, iri=value) + else: + return Term(type=LITERAL, value=value) + + class Processor(TriplesQueryService): def __init__(self, **params): @@ -46,12 +98,6 @@ class Processor(TriplesQueryService): self.cassandra_password = password self.table = None - def create_value(self, ent): - if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) - else: - return Value(value=ent, is_uri=False) - async def query_triples(self, query): try: @@ -59,90 +105,137 @@ class Processor(TriplesQueryService): user = query.user if user != self.table: + # Use factory function to select implementation + KGClass = EntityCentricKnowledgeGraph + if self.cassandra_username and self.cassandra_password: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=query.user, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=query.user, ) self.table = user - triples = [] + # Extract values from query + s_val = get_term_value(query.s) + p_val = get_term_value(query.p) + o_val = get_term_value(query.o) + g_val = query.g # Already a string or None - if query.s is not None: - if query.p is not None: - if query.o is not None: + # Helper to extract object metadata from result row + def get_o_metadata(t): + """Extract otype/dtype/lang from result row if available""" 
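+ # otype/dtype/lang come from the entity-centric quad tables; result + # rows from other code paths may not carry these attributes, hence + # the defaulting getattr calls below.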
+ otype = getattr(t, 'otype', None) + dtype = getattr(t, 'dtype', None) + lang = getattr(t, 'lang', None) + return otype, dtype, lang + + quads = [] + + # Route to appropriate query method based on which fields are specified + if s_val is not None: + if p_val is not None: + if o_val is not None: + # SPO specified - find matching graphs resp = self.tg.get_spo( - query.collection, query.s.value, query.p.value, query.o.value, + query.collection, s_val, p_val, o_val, g=g_val, limit=query.limit ) - triples.append((query.s.value, query.p.value, query.o.value)) + for t in resp: + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((s_val, p_val, o_val, g, otype, dtype, lang)) else: + # SP specified resp = self.tg.get_sp( - query.collection, query.s.value, query.p.value, + query.collection, s_val, p_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((query.s.value, query.p.value, t.o)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((s_val, p_val, t.o, g, otype, dtype, lang)) else: - if query.o is not None: + if o_val is not None: + # SO specified resp = self.tg.get_os( - query.collection, query.o.value, query.s.value, + query.collection, o_val, s_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((query.s.value, t.p, query.o.value)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((s_val, t.p, o_val, g, otype, dtype, lang)) else: + # S only resp = self.tg.get_s( - query.collection, query.s.value, + query.collection, s_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((query.s.value, t.p, t.o)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((s_val, t.p, t.o, g, otype, dtype, lang)) else: - if query.p is not None: - if query.o is not None: + if p_val is not None: + if o_val is not None: + # PO specified resp = self.tg.get_po( - query.collection, query.p.value, query.o.value, + query.collection, p_val, o_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((t.s, query.p.value, query.o.value)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((t.s, p_val, o_val, g, otype, dtype, lang)) else: + # P only resp = self.tg.get_p( - query.collection, query.p.value, + query.collection, p_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((t.s, query.p.value, t.o)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((t.s, p_val, t.o, g, otype, dtype, lang)) else: - if query.o is not None: + if o_val is not None: + # O only resp = self.tg.get_o( - query.collection, query.o.value, + query.collection, o_val, g=g_val, limit=query.limit ) for t in resp: - triples.append((t.s, t.p, query.o.value)) + g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((t.s, t.p, o_val, g, otype, dtype, lang)) else: + # Nothing specified - get all resp = self.tg.get_all( query.collection, limit=query.limit ) for t in resp: - triples.append((t.s, t.p, t.o)) + # Note: quads_by_collection uses 'd' for graph field + g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(t) + quads.append((t.s, t.p, t.o, g, otype, dtype, lang)) + # Convert to Triple objects (with g field) + # Use otype/dtype/lang for proper Term reconstruction if available triples = [ Triple( - 
s=self.create_value(t[0]), - p=self.create_value(t[1]), - o=self.create_value(t[2]) + s=create_term(q[0]), + p=create_term(q[1]), + o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]), + g=q[3] if q[3] != DEFAULT_GRAPH else None ) - for t in triples + for q in quads ] return triples @@ -162,4 +255,3 @@ class Processor(TriplesQueryService): def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index d1c7be7d..14b24d52 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -10,12 +10,24 @@ import logging from falkordb import FalkorDB from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Value, Triple +from .... schema import Term, Triple, IRI, LITERAL from .... base import TriplesQueryService # Module logger logger = logging.getLogger(__name__) + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + return term.id or term.value + default_ident = "triples-query" default_graph_url = 'falkor://falkordb:6379' @@ -42,9 +54,9 @@ class Processor(TriplesQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_triples(self, query): @@ -63,28 +75,28 @@ class Processor(TriplesQueryService): "RETURN $src as src " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "rel": query.p.value, - "value": query.o.value, + "src": get_term_value(query.s), + "rel": get_term_value(query.p), + "value": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " "RETURN $src as src " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "rel": query.p.value, - "uri": query.o.value, + "src": get_term_value(query.s), + "rel": get_term_value(query.p), + "uri": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) else: @@ -95,26 +107,26 @@ class Processor(TriplesQueryService): "RETURN dest.value as dest " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "rel": query.p.value, + "src": get_term_value(query.s), + "rel": get_term_value(query.p), }, ).result_set for rec in records: - triples.append((query.s.value, query.p.value, rec[0])) + triples.append((get_term_value(query.s), get_term_value(query.p), rec[0])) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "rel": query.p.value, + "src": get_term_value(query.s), + "rel": get_term_value(query.p), }, ).result_set for rec in records: - triples.append((query.s.value, query.p.value, rec[0])) + triples.append((get_term_value(query.s), get_term_value(query.p), 
rec[0])) else: @@ -127,26 +139,26 @@ class Processor(TriplesQueryService): "RETURN rel.uri as rel " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "value": query.o.value, + "src": get_term_value(query.s), + "value": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((query.s.value, rec[0], query.o.value)) + triples.append((get_term_value(query.s), rec[0], get_term_value(query.o))) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), params={ - "src": query.s.value, - "uri": query.o.value, + "src": get_term_value(query.s), + "uri": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((query.s.value, rec[0], query.o.value)) + triples.append((get_term_value(query.s), rec[0], get_term_value(query.o))) else: @@ -157,24 +169,24 @@ class Processor(TriplesQueryService): "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), params={ - "src": query.s.value, + "src": get_term_value(query.s), }, ).result_set for rec in records: - triples.append((query.s.value, rec[0], rec[1])) + triples.append((get_term_value(query.s), rec[0], rec[1])) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), params={ - "src": query.s.value, + "src": get_term_value(query.s), }, ).result_set for rec in records: - triples.append((query.s.value, rec[0], rec[1])) + triples.append((get_term_value(query.s), rec[0], rec[1])) else: @@ -190,26 +202,26 @@ class Processor(TriplesQueryService): "RETURN src.uri as src " "LIMIT " + str(query.limit), params={ - "uri": query.p.value, - "value": query.o.value, + "uri": get_term_value(query.p), + "value": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((rec[0], query.p.value, query.o.value)) + triples.append((rec[0], get_term_value(query.p), get_term_value(query.o))) records = self.io.query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), params={ - "uri": query.p.value, - "dest": query.o.value, + "uri": get_term_value(query.p), + "dest": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((rec[0], query.p.value, query.o.value)) + triples.append((rec[0], get_term_value(query.p), get_term_value(query.o))) else: @@ -220,24 +232,24 @@ class Processor(TriplesQueryService): "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), params={ - "uri": query.p.value, + "uri": get_term_value(query.p), }, ).result_set for rec in records: - triples.append((rec[0], query.p.value, rec[1])) + triples.append((rec[0], get_term_value(query.p), rec[1])) records = self.io.query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), params={ - "uri": query.p.value, + "uri": get_term_value(query.p), }, ).result_set for rec in records: - triples.append((rec[0], query.p.value, rec[1])) + triples.append((rec[0], get_term_value(query.p), rec[1])) else: @@ -250,24 +262,24 @@ class Processor(TriplesQueryService): "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), params={ - "value": query.o.value, + "value": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((rec[0], rec[1], query.o.value)) + triples.append((rec[0], rec[1], get_term_value(query.o))) records = self.io.query( "MATCH 
(src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), params={ - "uri": query.o.value, + "uri": get_term_value(query.o), }, ).result_set for rec in records: - triples.append((rec[0], rec[1], query.o.value)) + triples.append((rec[0], rec[1], get_term_value(query.o))) else: diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 262f89ab..37633f34 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -10,12 +10,24 @@ import logging from neo4j import GraphDatabase from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Value, Triple +from .... schema import Term, Triple, IRI, LITERAL from .... base import TriplesQueryService # Module logger logger = logging.getLogger(__name__) + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + return term.id or term.value + default_ident = "triples-query" default_graph_host = 'bolt://memgraph:7687' @@ -47,9 +59,9 @@ class Processor(TriplesQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_triples(self, query): @@ -73,13 +85,13 @@ class Processor(TriplesQueryService): "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), - src=query.s.value, rel=query.p.value, value=query.o.value, + src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" @@ -87,13 +99,13 @@ class Processor(TriplesQueryService): "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), - src=query.s.value, rel=query.p.value, uri=query.o.value, + src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) else: @@ -105,14 +117,14 @@ class Processor(TriplesQueryService): "(dest:Literal {user: $user, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), - src=query.s.value, rel=query.p.value, + src=get_term_value(query.s), rel=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, query.p.value, data["dest"])) + triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" @@ -120,14 +132,14 @@ class 
Processor(TriplesQueryService): "(dest:Node {user: $user, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), - src=query.s.value, rel=query.p.value, + src=get_term_value(query.s), rel=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, query.p.value, data["dest"])) + triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) else: @@ -141,14 +153,14 @@ class Processor(TriplesQueryService): "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), - src=query.s.value, value=query.o.value, + src=get_term_value(query.s), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], query.o.value)) + triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" @@ -156,14 +168,14 @@ class Processor(TriplesQueryService): "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), - src=query.s.value, uri=query.o.value, + src=get_term_value(query.s), uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], query.o.value)) + triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) else: @@ -175,14 +187,14 @@ class Processor(TriplesQueryService): "(dest:Literal {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), - src=query.s.value, + src=get_term_value(query.s), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], data["dest"])) + triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" @@ -190,14 +202,14 @@ class Processor(TriplesQueryService): "(dest:Node {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), - src=query.s.value, + src=get_term_value(query.s), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], data["dest"])) + triples.append((get_term_value(query.s), data["rel"], data["dest"])) else: @@ -214,14 +226,14 @@ class Processor(TriplesQueryService): "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), - uri=query.p.value, value=query.o.value, + uri=get_term_value(query.p), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, query.o.value)) + triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" @@ -229,14 +241,14 @@ class Processor(TriplesQueryService): "(dest:Node {uri: $dest, user: $user, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), - uri=query.p.value, 
dest=query.o.value, + uri=get_term_value(query.p), dest=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, query.o.value)) + triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) else: @@ -248,14 +260,14 @@ class Processor(TriplesQueryService): "(dest:Literal {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), - uri=query.p.value, + uri=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, data["dest"])) + triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" @@ -263,14 +275,14 @@ class Processor(TriplesQueryService): "(dest:Node {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), - uri=query.p.value, + uri=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, data["dest"])) + triples.append((data["src"], get_term_value(query.p), data["dest"])) else: @@ -284,14 +296,14 @@ class Processor(TriplesQueryService): "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), - value=query.o.value, + value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], query.o.value)) + triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" @@ -299,14 +311,14 @@ class Processor(TriplesQueryService): "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), - uri=query.o.value, + uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], query.o.value)) + triples.append((data["src"], data["rel"], get_term_value(query.o))) else: diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 0e84d733..4cb1ab21 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -10,12 +10,24 @@ import logging from neo4j import GraphDatabase from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error -from .... schema import Value, Triple +from .... schema import Term, Triple, IRI, LITERAL from .... 
base import TriplesQueryService # Module logger logger = logging.getLogger(__name__) + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + return term.id or term.value + default_ident = "triples-query" default_graph_host = 'bolt://neo4j:7687' @@ -47,9 +59,9 @@ class Processor(TriplesQueryService): def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): - return Value(value=ent, is_uri=True) + return Term(type=IRI, iri=ent) else: - return Value(value=ent, is_uri=False) + return Term(type=LITERAL, value=ent) async def query_triples(self, query): @@ -71,27 +83,29 @@ class Processor(TriplesQueryService): "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN $src as src", - src=query.s.value, rel=query.p.value, value=query.o.value, + "RETURN $src as src " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Node {uri: $uri, user: $user, collection: $collection}) " - "RETURN $src as src", - src=query.s.value, rel=query.p.value, uri=query.o.value, + "RETURN $src as src " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: - triples.append((query.s.value, query.p.value, query.o.value)) + triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) else: @@ -101,29 +115,31 @@ class Processor(TriplesQueryService): "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN dest.value as dest", - src=query.s.value, rel=query.p.value, + "RETURN dest.value as dest " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), rel=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, query.p.value, data["dest"])) + triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "(dest:Node {user: $user, collection: $collection}) " - "RETURN dest.uri as dest", - src=query.s.value, rel=query.p.value, + "RETURN dest.uri as dest " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), rel=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, query.p.value, data["dest"])) + triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) else: @@ -135,29 +151,31 @@ 
class Processor(TriplesQueryService): "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN rel.uri as rel", - src=query.s.value, value=query.o.value, + "RETURN rel.uri as rel " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], query.o.value)) + triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Node {uri: $uri, user: $user, collection: $collection}) " - "RETURN rel.uri as rel", - src=query.s.value, uri=query.o.value, + "RETURN rel.uri as rel " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], query.o.value)) + triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) else: @@ -167,29 +185,31 @@ class Processor(TriplesQueryService): "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN rel.uri as rel, dest.value as dest", - src=query.s.value, + "RETURN rel.uri as rel, dest.value as dest " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], data["dest"])) + triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Node {user: $user, collection: $collection}) " - "RETURN rel.uri as rel, dest.uri as dest", - src=query.s.value, + "RETURN rel.uri as rel, dest.uri as dest " + "LIMIT " + str(query.limit), + src=get_term_value(query.s), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((query.s.value, data["rel"], data["dest"])) + triples.append((get_term_value(query.s), data["rel"], data["dest"])) else: @@ -204,29 +224,31 @@ class Processor(TriplesQueryService): "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN src.uri as src", - uri=query.p.value, value=query.o.value, + "RETURN src.uri as src " + "LIMIT " + str(query.limit), + uri=get_term_value(query.p), value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, query.o.value)) + triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Node {uri: $dest, user: $user, collection: $collection}) " - "RETURN src.uri as src", - 
uri=query.p.value, dest=query.o.value, + "RETURN src.uri as src " + "LIMIT " + str(query.limit), + uri=get_term_value(query.p), dest=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, query.o.value)) + triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) else: @@ -236,29 +258,31 @@ class Processor(TriplesQueryService): "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN src.uri as src, dest.value as dest", - uri=query.p.value, + "RETURN src.uri as src, dest.value as dest " + "LIMIT " + str(query.limit), + uri=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, data["dest"])) + triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "(dest:Node {user: $user, collection: $collection}) " - "RETURN src.uri as src, dest.uri as dest", - uri=query.p.value, + "RETURN src.uri as src, dest.uri as dest " + "LIMIT " + str(query.limit), + uri=get_term_value(query.p), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], query.p.value, data["dest"])) + triples.append((data["src"], get_term_value(query.p), data["dest"])) else: @@ -270,29 +294,31 @@ class Processor(TriplesQueryService): "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {value: $value, user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel", - value=query.o.value, + "RETURN src.uri as src, rel.uri as rel " + "LIMIT " + str(query.limit), + value=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], query.o.value)) + triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Node {uri: $uri, user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel", - uri=query.o.value, + "RETURN src.uri as src, rel.uri as rel " + "LIMIT " + str(query.limit), + uri=get_term_value(query.o), user=user, collection=collection, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], query.o.value)) + triples.append((data["src"], data["rel"], get_term_value(query.o))) else: @@ -302,7 +328,8 @@ class Processor(TriplesQueryService): "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" "(dest:Literal {user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel, dest.value as dest", + "RETURN src.uri as src, rel.uri as rel, dest.value as dest " + "LIMIT " + str(query.limit), user=user, collection=collection, database_=self.db, ) @@ -315,7 +342,8 @@ class Processor(TriplesQueryService): "MATCH (src:Node {user: $user, collection: $collection})-" "[rel:Rel {user: $user, collection: $collection}]->" 
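Note on the `"LIMIT " + str(query.limit)` additions running through these hunks: interpolating the limit into the Cypher text is only safe while `query.limit` is a trusted integer. A defensive variant (a sketch, not what the patch does) coerces the value before concatenating; newer Neo4j releases should also accept a `LIMIT $limit` parameter, which would avoid interpolation entirely:

```python
def with_limit(cypher: str, limit) -> str:
    """Append a LIMIT clause, coercing the value to a non-negative int so a
    malformed or string-typed limit cannot alter the statement."""
    n = int(limit)            # raises on non-numeric input rather than splicing it in
    if n < 0:
        raise ValueError("limit must be non-negative")
    return f"{cypher} LIMIT {n}"

assert with_limit("MATCH (n) RETURN n", 20) == "MATCH (n) RETURN n LIMIT 20"
```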
"(dest:Node {user: $user, collection: $collection}) " - "RETURN src.uri as src, rel.uri as rel, dest.uri as dest", + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " + "LIMIT " + str(query.limit), user=user, collection=collection, database_=self.db, ) @@ -327,10 +355,10 @@ class Processor(TriplesQueryService): triples = [ Triple( s=self.create_value(t[0]), - p=self.create_value(t[1]), + p=self.create_value(t[1]), o=self.create_value(t[2]) ) - for t in triples + for t in triples[:query.limit] ] return triples diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py index 4b1a04a4..e39f9041 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -1,6 +1,6 @@ """ Structured Query Service - orchestrates natural language question processing. -Takes a question, converts it to GraphQL via nlp-query, executes via objects-query, +Takes a question, converts it to GraphQL via nlp-query, executes via rows-query, and returns the results. """ @@ -10,7 +10,7 @@ from typing import Dict, Any, Optional from ...schema import StructuredQueryRequest, StructuredQueryResponse from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse -from ...schema import ObjectsQueryRequest, ObjectsQueryResponse +from ...schema import RowsQueryRequest, RowsQueryResponse from ...schema import Error from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec @@ -57,13 +57,13 @@ class Processor(FlowProcessor): ) ) - # Client spec for calling objects query service + # Client spec for calling rows query service self.register_specification( RequestResponseSpec( - request_name = "objects-query-request", - response_name = "objects-query-response", - request_schema = ObjectsQueryRequest, - response_schema = ObjectsQueryResponse + request_name = "rows-query-request", + response_name = "rows-query-response", + request_schema = RowsQueryRequest, + response_schema = RowsQueryResponse ) ) @@ -112,7 +112,7 @@ class Processor(FlowProcessor): variables_as_strings[key] = str(value) # Use user/collection values from request - objects_request = ObjectsQueryRequest( + objects_request = RowsQueryRequest( user=request.user, collection=request.collection, query=nlp_response.graphql_query, @@ -120,12 +120,12 @@ class Processor(FlowProcessor): operation_name=None ) - objects_response = await flow("objects-query-request").request(objects_request) - + objects_response = await flow("rows-query-request").request(objects_request) + if objects_response.error is not None: - raise Exception(f"Objects query service error: {objects_response.error.message}") - - # Handle GraphQL errors from the objects query service + raise Exception(f"Rows query service error: {objects_response.error.message}") + + # Handle GraphQL errors from the rows query service graphql_errors = [] if objects_response.errors: for gql_error in objects_response.errors: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 07dbf0eb..ae869413 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -13,7 +13,7 @@ from .... 
base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_store_uri = 'http://localhost:19530' class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 6d1b23ba..a0e52253 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index edfa8aa9..cb978048 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_store_uri = 'http://localhost:6333' diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 2e192cd6..21aa21e6 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -9,11 +9,25 @@ from .... direct.milvus_graph_embeddings import EntityVectors from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + +default_ident = "graph-embeddings-write" default_store_uri = 'http://localhost:19530' class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): @@ -36,11 +50,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): async def store_graph_embeddings(self, message): for entity in message.entities: + entity_value = get_term_value(entity.entity) - if entity.entity.value != "" and entity.entity.value is not None: + if entity_value != "" and entity_value is not None: for vec in entity.vectors: self.vecstore.insert( - vec, entity.entity.value, + vec, entity_value, message.metadata.user, message.metadata.collection ) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 0bee6ceb..c4b0065b 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -14,11 +14,25 @@ import logging from .... 
base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + +default_ident = "graph-embeddings-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" @@ -100,8 +114,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): return for entity in message.entities: + entity_value = get_term_value(entity.entity) - if entity.entity.value == "" or entity.entity.value is None: + if entity_value == "" or entity_value is None: continue for vec in entity.vectors: @@ -126,7 +141,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): { "id": vector_id, "values": vec, - "metadata": { "entity": entity.entity.value }, + "metadata": { "entity": entity_value }, } ] diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index e3c2b6bc..0da59bb9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -12,11 +12,26 @@ import logging from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... 
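`get_term_value` is now defined verbatim in the Milvus, Pinecone, and Qdrant graph-embeddings writers; a plausible follow-up (not part of this patch) would hoist it into a shared module. The emptiness guard these writers apply (`value == "" or value is None`) also reduces to a single truthiness test, since both `""` and `None` are falsy. A self-contained sketch, with toy entities standing in for the real message objects:

```python
from types import SimpleNamespace

def iter_named_entities(entities, get_term_value):
    """Yield (entity, value) for entities whose Term resolves to a non-empty
    string; equivalent to the patch's `== "" or is None` guard."""
    for entity in entities:
        value = get_term_value(entity.entity)
        if not value:               # covers both None and ""
            continue
        yield entity, value

ents = [SimpleNamespace(entity="cat"), SimpleNamespace(entity=""), SimpleNamespace(entity=None)]
assert [v for _, v in iter_named_entities(ents, lambda t: t)] == ["cat"]
```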
schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) -default_ident = "ge-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + +default_ident = "graph-embeddings-write" default_store_uri = 'http://localhost:6333' @@ -51,8 +66,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): return for entity in message.entities: + entity_value = get_term_value(entity.entity) - if entity.entity.value == "" or entity.entity.value is None: return + if entity_value == "" or entity_value is None: + continue for vec in entity.vectors: @@ -80,7 +97,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): id=str(uuid.uuid4()), vector=vec, payload={ - "entity": entity.entity.value, + "entity": entity_value, } ) ] diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index a79b7b83..475604b6 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -64,12 +64,14 @@ class Processor(FlowProcessor): async def on_triples(self, msg, consumer, flow): v = msg.value() - await self.table_store.add_triples(v) + if v.triples: + await self.table_store.add_triples(v) async def on_graph_embeddings(self, msg, consumer, flow): v = msg.value() - await self.table_store.add_graph_embeddings(v) + if v.entities: + await self.table_store.add_graph_embeddings(v) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/__init__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/trustgraph-flow/trustgraph/storage/objects/__init__.py b/trustgraph-flow/trustgraph/storage/objects/__init__.py deleted file mode 100644 index 56f5f66a..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Objects storage module \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py deleted file mode 100644 index 01adc061..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . write import * diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py deleted file mode 100644 index 95376fee..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . write import run - -run() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py deleted file mode 100644 index bcb0d57f..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py +++ /dev/null @@ -1,538 +0,0 @@ -""" -Object writer for Cassandra. Input is ExtractedObject. -Writes structured objects to Cassandra tables based on schema definitions. 
-""" - -import json -import logging -from typing import Dict, Set, Optional, Any -from cassandra.cluster import Cluster -from cassandra.auth import PlainTextAuthProvider -from cassandra.cqlengine import connection -from cassandra import ConsistencyLevel - -from .... schema import ExtractedObject -from .... schema import RowSchema, Field -from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base import CollectionConfigHandler -from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config - -# Module logger -logger = logging.getLogger(__name__) - -default_ident = "objects-write" - -class Processor(CollectionConfigHandler, FlowProcessor): - - def __init__(self, **params): - - id = params.get("id", default_ident) - - # Get Cassandra parameters - cassandra_host = params.get("cassandra_host") - cassandra_username = params.get("cassandra_username") - cassandra_password = params.get("cassandra_password") - - # Resolve configuration with environment variable fallback - hosts, username, password, keyspace = resolve_cassandra_config( - host=cassandra_host, - username=cassandra_username, - password=cassandra_password - ) - - # Store resolved configuration with proper names - self.cassandra_host = hosts # Store as list - self.cassandra_username = username - self.cassandra_password = password - - # Config key for schemas - self.config_key = params.get("config_type", "schema") - - super(Processor, self).__init__( - **params | { - "id": id, - "config_type": self.config_key, - } - ) - - self.register_specification( - ConsumerSpec( - name = "input", - schema = ExtractedObject, - handler = self.on_object - ) - ) - - # Register config handlers - self.register_config_handler(self.on_schema_config) - self.register_config_handler(self.on_collection_config) - - # Cache of known keyspaces/tables - self.known_keyspaces: Set[str] = set() - self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables - - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - - # Cassandra session - self.cluster = None - self.session = None - - def connect_cassandra(self): - """Connect to Cassandra cluster""" - if self.session: - return - - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) - - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise - - async def on_schema_config(self, config, version): - """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - - # Check if our config type exists - if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") - return - - # Get the schemas dictionary for our type - schemas_config = config[self.config_key] - - # Process each schema in the schemas config - for schema_name, schema_json in schemas_config.items(): - try: - # Parse the JSON schema definition - schema_def = json.loads(schema_json) - - # Create Field objects - fields = [] - for field_def in schema_def.get("fields", []): - field = Field( - name=field_def["name"], - 
type=field_def["type"], - size=field_def.get("size", 0), - primary=field_def.get("primary_key", False), - description=field_def.get("description", ""), - required=field_def.get("required", False), - enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) - ) - fields.append(field) - - # Create RowSchema - row_schema = RowSchema( - name=schema_def.get("name", schema_name), - description=schema_def.get("description", ""), - fields=fields - ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - - except Exception as e: - logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") - - def ensure_keyspace(self, keyspace: str): - """Ensure keyspace exists in Cassandra""" - if keyspace in self.known_keyspaces: - return - - # Connect if needed - self.connect_cassandra() - - # Sanitize keyspace name - safe_keyspace = self.sanitize_name(keyspace) - - # Create keyspace if not exists - create_keyspace_cql = f""" - CREATE KEYSPACE IF NOT EXISTS {safe_keyspace} - WITH REPLICATION = {{ - 'class': 'SimpleStrategy', - 'replication_factor': 1 - }} - """ - - try: - self.session.execute(create_keyspace_cql) - self.known_keyspaces.add(keyspace) - self.known_tables[keyspace] = set() - logger.info(f"Ensured keyspace exists: {safe_keyspace}") - except Exception as e: - logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True) - raise - - def get_cassandra_type(self, field_type: str, size: int = 0) -> str: - """Convert schema field type to Cassandra type""" - # Handle None size - if size is None: - size = 0 - - type_mapping = { - "string": "text", - "integer": "bigint" if size > 4 else "int", - "float": "double" if size > 4 else "float", - "boolean": "boolean", - "timestamp": "timestamp", - "date": "date", - "time": "time", - "uuid": "uuid" - } - - return type_mapping.get(field_type, "text") - - def sanitize_name(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - # Replace non-alphanumeric characters with underscore - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - # Ensure it starts with a letter - if safe_name and not safe_name[0].isalpha(): - safe_name = 'o_' + safe_name - return safe_name.lower() - - def sanitize_table(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - # Replace non-alphanumeric characters with underscore - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - # Ensure it starts with a letter - safe_name = 'o_' + safe_name - return safe_name.lower() - - def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema): - """Ensure table exists with proper structure""" - table_key = f"{keyspace}.{table_name}" - if table_key in self.known_tables.get(keyspace, set()): - return - - # Ensure keyspace exists first - self.ensure_keyspace(keyspace) - - safe_keyspace = self.sanitize_name(keyspace) - safe_table = self.sanitize_table(table_name) - - # Build column definitions - columns = ["collection text"] # Collection is always part of table - primary_key_fields = [] - clustering_fields = [] - - for field in schema.fields: - safe_field_name = self.sanitize_name(field.name) - cassandra_type = self.get_cassandra_type(field.type, field.size) - columns.append(f"{safe_field_name} {cassandra_type}") - - if field.primary: - primary_key_fields.append(safe_field_name) - - # Build primary key - collection is always first in 
partition key - if primary_key_fields: - primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))" - else: - # If no primary key defined, use collection and a synthetic id - columns.append("synthetic_id uuid") - primary_key = "PRIMARY KEY ((collection, synthetic_id))" - - # Create table - create_table_cql = f""" - CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} ( - {', '.join(columns)}, - {primary_key} - ) - """ - - try: - self.session.execute(create_table_cql) - if keyspace not in self.known_tables: - self.known_tables[keyspace] = set() - self.known_tables[keyspace].add(table_key) - logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}") - - # Create secondary indexes for indexed fields - for field in schema.fields: - if field.indexed and not field.primary: - safe_field_name = self.sanitize_name(field.name) - index_name = f"{safe_table}_{safe_field_name}_idx" - create_index_cql = f""" - CREATE INDEX IF NOT EXISTS {index_name} - ON {safe_keyspace}.{safe_table} ({safe_field_name}) - """ - try: - self.session.execute(create_index_cql) - logger.info(f"Created index: {index_name}") - except Exception as e: - logger.warning(f"Failed to create index {index_name}: {e}") - - except Exception as e: - logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True) - raise - - def convert_value(self, value: Any, field_type: str) -> Any: - """Convert value to appropriate type for Cassandra""" - if value is None: - return None - - try: - if field_type == "integer": - return int(value) - elif field_type == "float": - return float(value) - elif field_type == "boolean": - if isinstance(value, str): - return value.lower() in ('true', '1', 'yes') - return bool(value) - elif field_type == "timestamp": - # Handle timestamp conversion if needed - return value - else: - return str(value) - except Exception as e: - logger.warning(f"Failed to convert value {value} to type {field_type}: {e}") - return str(value) - async def on_object(self, msg, consumer, flow): - """Process incoming ExtractedObject and store in Cassandra""" - - obj = msg.value() - logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}") - - # Validate collection exists before accepting writes - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): - error_msg = ( - f"Collection {obj.metadata.collection} does not exist. " - f"Create it first via collection management API." 
- ) - logger.error(error_msg) - raise ValueError(error_msg) - - # Get schema definition - schema = self.schemas.get(obj.schema_name) - if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") - return - - # Ensure table exists - keyspace = obj.metadata.user - table_name = obj.schema_name - self.ensure_table(keyspace, table_name, schema) - - # Prepare data for insertion - safe_keyspace = self.sanitize_name(keyspace) - safe_table = self.sanitize_table(table_name) - - # Process each object in the batch - for obj_index, value_map in enumerate(obj.values): - # Build column names and values for this object - columns = ["collection"] - values = [obj.metadata.collection] - placeholders = ["%s"] - - # Check if we need a synthetic ID - has_primary_key = any(field.primary for field in schema.fields) - if not has_primary_key: - import uuid - columns.append("synthetic_id") - values.append(uuid.uuid4()) - placeholders.append("%s") - - # Process fields for this object - skip_object = False - for field in schema.fields: - safe_field_name = self.sanitize_name(field.name) - raw_value = value_map.get(field.name) - - # Handle required fields - if field.required and raw_value is None: - logger.warning(f"Required field {field.name} is missing in object {obj_index}") - # Continue anyway - Cassandra doesn't enforce NOT NULL - - # Check if primary key field is NULL - if field.primary and raw_value is None: - logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}") - skip_object = True - break - - # Convert value to appropriate type - converted_value = self.convert_value(raw_value, field.type) - - columns.append(safe_field_name) - values.append(converted_value) - placeholders.append("%s") - - # Skip this object if primary key validation failed - if skip_object: - continue - - # Build and execute insert query for this object - insert_cql = f""" - INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)}) - VALUES ({', '.join(placeholders)}) - """ - - # Debug: Show data being inserted - logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}") - - if len(columns) != len(values) or len(columns) != len(placeholders): - raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}") - - try: - # Convert to tuple - Cassandra driver requires tuple for parameters - self.session.execute(insert_cql, tuple(values)) - except Exception as e: - logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True) - raise - - async def create_collection(self, user: str, collection: str, metadata: dict): - """Create/verify collection exists in Cassandra object store""" - # Connect if not already connected - self.connect_cassandra() - - # Sanitize names for safety - safe_keyspace = self.sanitize_name(user) - - # Ensure keyspace exists - if safe_keyspace not in self.known_keyspaces: - self.ensure_keyspace(safe_keyspace) - self.known_keyspaces.add(safe_keyspace) - - # For Cassandra objects, collection is just a property in rows - # No need to create separate tables per collection - # Just mark that we've seen this collection - logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})") - - async def delete_collection(self, user: str, collection: str): - """Delete all data for a specific collection using schema information""" - # Connect if not already connected - self.connect_cassandra() - - # Sanitize names for safety - safe_keyspace = 
self.sanitize_name(user) - - # Check if keyspace exists - if safe_keyspace not in self.known_keyspaces: - # Query to verify keyspace exists - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) - if not result.one(): - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(safe_keyspace) - - # Iterate over schemas we manage to delete from relevant tables - tables_deleted = 0 - - for schema_name, schema in self.schemas.items(): - safe_table = self.sanitize_table(schema_name) - - # Check if table exists - table_key = f"{user}.{schema_name}" - if table_key not in self.known_tables.get(user, set()): - logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping") - continue - - try: - # Get primary key fields from schema - primary_key_fields = [field for field in schema.fields if field.primary] - - if primary_key_fields: - # Schema has primary keys: need to query for partition keys first - # Build SELECT query for primary key fields - pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields] - select_cql = f""" - SELECT {', '.join(pk_field_names)} - FROM {safe_keyspace}.{safe_table} - WHERE collection = %s - ALLOW FILTERING - """ - - rows = self.session.execute(select_cql, (collection,)) - - # Delete each row using full partition key - for row in rows: - where_clauses = ["collection = %s"] - values = [collection] - - for field_name in pk_field_names: - where_clauses.append(f"{field_name} = %s") - values.append(getattr(row, field_name)) - - delete_cql = f""" - DELETE FROM {safe_keyspace}.{safe_table} - WHERE {' AND '.join(where_clauses)} - """ - - self.session.execute(delete_cql, tuple(values)) - else: - # No primary keys, uses synthetic_id - # Need to query for synthetic_ids first - select_cql = f""" - SELECT synthetic_id - FROM {safe_keyspace}.{safe_table} - WHERE collection = %s - ALLOW FILTERING - """ - - rows = self.session.execute(select_cql, (collection,)) - - # Delete each row using collection and synthetic_id - for row in rows: - delete_cql = f""" - DELETE FROM {safe_keyspace}.{safe_table} - WHERE collection = %s AND synthetic_id = %s - """ - self.session.execute(delete_cql, (collection, row.synthetic_id)) - - tables_deleted += 1 - logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}") - - except Exception as e: - logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}") - raise - - logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}") - - def close(self): - """Clean up Cassandra connections""" - if self.cluster: - self.cluster.shutdown() - logger.info("Closed Cassandra connection") - - @staticmethod - def add_args(parser): - """Add command-line arguments""" - - FlowProcessor.add_args(parser) - add_cassandra_args(parser) - - parser.add_argument( - '--config-type', - default='schema', - help='Configuration type prefix for schemas (default: schema)' - ) - -def run(): - """Entry point for objects-write-cassandra command""" - Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py new file mode 100644 index 00000000..16b2f154 --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py @@ -0,0 +1,3 @@ +""" 
+Row embeddings storage modules. +""" diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py new file mode 100644 index 00000000..65c5c514 --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py @@ -0,0 +1,5 @@ +""" +Qdrant storage for row embeddings. +""" + +from .write import Processor, run, default_ident diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py new file mode 100644 index 00000000..a349475c --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py @@ -0,0 +1,4 @@ + +from .write import run + +run() diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py new file mode 100644 index 00000000..29848c4c --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -0,0 +1,264 @@ +""" +Row embeddings writer for Qdrant (Stage 2). + +Consumes RowEmbeddings messages (which already contain computed vectors) +and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair. + +This follows the two-stage pattern used by graph-embeddings and document-embeddings: + Stage 1 (row-embeddings): Compute embeddings + Stage 2 (this processor): Store embeddings + +Collection naming: rows_{user}_{collection}_{schema_name}_{dimension} + +Payload structure: + - index_name: The indexed field(s) this embedding represents + - index_value: The original list of values (for Cassandra lookup) + - text: The text that was embedded (for debugging/display) +""" + +import logging +import re +import uuid +from typing import Set, Tuple + +from qdrant_client import QdrantClient +from qdrant_client.models import PointStruct, Distance, VectorParams + +from .... schema import RowEmbeddings +from .... base import FlowProcessor, ConsumerSpec +from .... 
base import CollectionConfigHandler + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings-write" +default_store_uri = 'http://localhost:6333' + + +class Processor(CollectionConfigHandler, FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) + + super(Processor, self).__init__( + **params | { + "id": id, + "store_uri": store_uri, + "api_key": api_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="input", + schema=RowEmbeddings, + handler=self.on_embeddings + ) + ) + + # Register config handler for collection management + self.register_config_handler(self.on_collection_config) + + # Cache of created Qdrant collections + self.created_collections: Set[str] = set() + + # Qdrant client + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Qdrant collection naming""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + def get_collection_name( + self, user: str, collection: str, schema_name: str, dimension: int + ) -> str: + """Generate Qdrant collection name""" + safe_user = self.sanitize_name(user) + safe_collection = self.sanitize_name(collection) + safe_schema = self.sanitize_name(schema_name) + return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" + + def ensure_collection(self, collection_name: str, dimension: int): + """Create Qdrant collection if it doesn't exist""" + if collection_name in self.created_collections: + return + + if not self.qdrant.collection_exists(collection_name): + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" + ) + self.qdrant.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ) + ) + + self.created_collections.add(collection_name) + + async def on_embeddings(self, msg, consumer, flow): + """Process incoming RowEmbeddings and write to Qdrant""" + + embeddings = msg.value() + logger.info( + f"Writing {len(embeddings.embeddings)} embeddings for schema " + f"{embeddings.schema_name} from {embeddings.metadata.id}" + ) + + # Validate collection exists in config before processing + if not self.collection_exists( + embeddings.metadata.user, embeddings.metadata.collection + ): + logger.warning( + f"Collection {embeddings.metadata.collection} for user " + f"{embeddings.metadata.user} does not exist in config. " + f"Dropping message." 
+ ) + return + + user = embeddings.metadata.user + collection = embeddings.metadata.collection + schema_name = embeddings.schema_name + + embeddings_written = 0 + qdrant_collection = None + + for row_emb in embeddings.embeddings: + if not row_emb.vectors: + logger.warning( + f"No vectors for index {row_emb.index_name} - skipping" + ) + continue + + # Use first vector (there may be multiple from different models) + for vector in row_emb.vectors: + dimension = len(vector) + + # Create/get collection name (lazily on first vector) + if qdrant_collection is None: + qdrant_collection = self.get_collection_name( + user, collection, schema_name, dimension + ) + self.ensure_collection(qdrant_collection, dimension) + + # Write to Qdrant + self.qdrant.upsert( + collection_name=qdrant_collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "index_name": row_emb.index_name, + "index_value": row_emb.index_value, + "text": row_emb.text + } + ) + ] + ) + embeddings_written += 1 + + logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Collection creation via config push - collections created lazily on first write""" + logger.info( + f"Row embeddings collection create request for {user}/{collection} - " + f"will be created lazily on first write" + ) + + async def delete_collection(self, user: str, collection: str): + """Delete all Qdrant collections for a given user/collection""" + try: + prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_" + + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No Qdrant collections found matching prefix {prefix}") + else: + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + self.created_collections.discard(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + logger.info( + f"Deleted {len(matching_collections)} collection(s) " + f"for {user}/{collection}" + ) + + except Exception as e: + logger.error( + f"Failed to delete collection {user}/{collection}: {e}", + exc_info=True + ) + raise + + async def delete_collection_schema( + self, user: str, collection: str, schema_name: str + ): + """Delete Qdrant collection for a specific user/collection/schema""" + try: + prefix = ( + f"rows_{self.sanitize_name(user)}_" + f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" + ) + + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No Qdrant collections found matching prefix {prefix}") + else: + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + self.created_collections.discard(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + + except Exception as e: + logger.error( + f"Failed to delete collection {user}/{collection}/{schema_name}: {e}", + exc_info=True + ) + raise + + @staticmethod + def add_args(parser): + """Add command-line arguments""" + + FlowProcessor.add_args(parser) + + parser.add_argument( + '-t', '--store-uri', + 
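To make the stored shape concrete: each iteration of the loop above upserts one point per vector (note the loop writes every vector in `row_emb.vectors`, not just the first). A sketch of a single point, using the same public `qdrant_client` API the patch imports; the field names and values below are illustrative only:

```python
import uuid
from qdrant_client.models import PointStruct

# One embedding for an indexed field of a hypothetical "invoice" row.
vector = [0.12, -0.03, 0.44]            # dimension 3, for brevity
point = PointStruct(
    id=str(uuid.uuid4()),
    vector=vector,
    payload={
        "index_name": "invoice_number",        # which index this embedding serves
        "index_value": ["INV-2024-0001"],      # original values, for Cassandra lookup
        "text": "Invoice INV-2024-0001 ...",   # the text that was embedded
    },
)
# Written to a per-(user, collection, schema, dimension) Qdrant collection,
# e.g. "rows_alice_research_invoice_3" for this 3-dimensional sketch.
```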
default=default_store_uri, + help=f'Qdrant URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help='Qdrant API key (default: None)' + ) + + +def run(): + """Entry point for row-embeddings-write-qdrant command""" + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index 1576b70c..d15916b6 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -1,46 +1,49 @@ - """ -Graph writer. Input is graph edge. Writes edges to Cassandra graph. +Row writer for Cassandra. Input is ExtractedObject. +Writes structured rows to a unified Cassandra table with multi-index support. + +Uses a single 'rows' table with the schema: + - collection: text + - schema_name: text + - index_name: text + - index_value: frozen> + - data: map + - source: text + +Each row is written multiple times - once per indexed field defined in the schema. """ -raise RuntimeError("This code is no longer in use") - -import pulsar -import base64 -import os -import argparse -import time +import json import logging +import re +from typing import Dict, Set, Optional, Any, List, Tuple + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider -from ssl import SSLContext, PROTOCOL_TLSv1_2 -from .... schema import Rows -from .... log_level import LogLevel -from .... base import Consumer +from .... schema import ExtractedObject +from .... schema import RowSchema, Field +from .... base import FlowProcessor, ConsumerSpec +from .... base import CollectionConfigHandler from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config # Module logger logger = logging.getLogger(__name__) -module = "rows-write" -ssl_context = SSLContext(PROTOCOL_TLSv1_2) +default_ident = "rows-write" -default_input_queue = "rows-store" # Default queue name -default_subscriber = module -class Processor(Consumer): +class Processor(CollectionConfigHandler, FlowProcessor): def __init__(self, **params): - - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - + + id = params.get("id", default_ident) + # Get Cassandra parameters cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") - + # Resolve configuration with environment variable fallback hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, @@ -48,99 +51,549 @@ class Processor(Consumer): password=cassandra_password ) + # Store resolved configuration with proper names + self.cassandra_host = hosts # Store as list + self.cassandra_username = username + self.cassandra_password = password + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Rows, - "cassandra_host": ','.join(hosts), - "cassandra_username": username, - "cassandra_password": password, + "id": id, + "config_type": self.config_key, } ) - - if username and password: - auth_provider = PlainTextAuthProvider(username=username, password=password) - self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) - else: - self.cluster = Cluster(hosts) - self.session = self.cluster.connect() - 
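The new docstring above describes the unified `rows` table (collection, schema_name, index_name, index_value, data, source), with each logical row written once per indexed field. As a concrete picture of that fan-out, here is what one hypothetical "invoice" row becomes; schema and values are invented for illustration:

```python
# One extracted row, duplicated per indexed field; the data map is identical
# in every copy, only (index_name, index_value) differ.
row_entries = [
    {
        "collection": "research",
        "schema_name": "invoice",
        "index_name": "invoice_number",        # first indexed field
        "index_value": ["INV-2024-0001"],      # frozen<list<text>> in CQL
        "data": {"invoice_number": "INV-2024-0001", "customer": "ACME", "total": "99.50"},
        "source": "doc-123",
    },
    {
        "collection": "research",
        "schema_name": "invoice",
        "index_name": "customer",              # second indexed field, same data map
        "index_value": ["ACME"],
        "data": {"invoice_number": "INV-2024-0001", "customer": "ACME", "total": "99.50"},
        "source": "doc-123",
    },
]
```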
self.tables = set() + self.register_specification( + ConsumerSpec( + name="input", + schema=ExtractedObject, + handler=self.on_object + ) + ) - self.session.execute(""" - create keyspace if not exists trustgraph - with replication = { - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 - }; - """); + # Register config handlers + self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_collection_config) - self.session.execute("use trustgraph"); + # Cache of known keyspaces and whether tables exist + self.known_keyspaces: Set[str] = set() + self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables - async def handle(self, msg): + # Cache of registered (collection, schema_name) pairs + self.registered_partitions: Set[Tuple[str, str]] = set() + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + # Cassandra session + self.cluster = None + self.session = None + + def connect_cassandra(self): + """Connect to Cassandra cluster""" + if self.session: + return try: - - v = msg.value() - name = v.row_schema.name - - if name not in self.tables: - - # FIXME: SQL injection? - - pkey = [] - - stmt = "create table if not exists " + name + " ( " - - for field in v.row_schema.fields: - - stmt += field.name + " text, " - - if field.primary: - pkey.append(field.name) - - stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));" - - self.session.execute(stmt) - - self.tables.add(name); - - for row in v.rows: - - field_names = [] - values = [] - - for field in v.row_schema.fields: - field_names.append(field.name) - values.append(row[field.name]) - - # FIXME: SQL injection? - stmt = ( - "insert into " + name + " (" + ", ".join(field_names) + - ") values (" + ",".join(["%s"] * len(values)) + ")" + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password ) + self.cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + self.cluster = Cluster(contact_points=self.cassandra_host) - self.session.execute(stmt, values) + self.session = self.cluster.connect() + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise - logger.error(f"Exception: {str(e)}", exc_info=True) + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") - # If there's an error make sure to do table creation etc. 
- self.tables.remove(name) + # Track which schemas changed so we can clear partition cache + old_schema_names = set(self.schemas.keys()) - raise e + # Clear existing schemas + self.schemas = {} + + # Check if our config type exists + if self.config_key not in config: + logger.warning(f"No '{self.config_key}' type in configuration") + return + + # Get the schemas dictionary for our type + schemas_config = config[self.config_key] + + # Process each schema in the schemas config + for schema_name, schema_json in schemas_config.items(): + try: + # Parse the JSON schema definition + schema_def = json.loads(schema_json) + + # Create Field objects + fields = [] + for field_def in schema_def.get("fields", []): + field = Field( + name=field_def["name"], + type=field_def["type"], + size=field_def.get("size", 0), + primary=field_def.get("primary_key", False), + description=field_def.get("description", ""), + required=field_def.get("required", False), + enum_values=field_def.get("enum", []), + indexed=field_def.get("indexed", False) + ) + fields.append(field) + + # Create RowSchema + row_schema = RowSchema( + name=schema_def.get("name", schema_name), + description=schema_def.get("description", ""), + fields=fields + ) + + self.schemas[schema_name] = row_schema + logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + + except Exception as e: + logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) + + logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + # Clear partition cache for schemas that changed + # This ensures next write will re-register partitions + new_schema_names = set(self.schemas.keys()) + changed_schemas = old_schema_names.symmetric_difference(new_schema_names) + if changed_schemas: + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if sch not in changed_schemas + } + logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}") + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Cassandra compatibility""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + # Ensure it starts with a letter + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + def ensure_keyspace(self, keyspace: str): + """Ensure keyspace exists in Cassandra""" + if keyspace in self.known_keyspaces: + return + + # Connect if needed + self.connect_cassandra() + + # Sanitize keyspace name + safe_keyspace = self.sanitize_name(keyspace) + + # Create keyspace if not exists + create_keyspace_cql = f""" + CREATE KEYSPACE IF NOT EXISTS {safe_keyspace} + WITH REPLICATION = {{ + 'class': 'SimpleStrategy', + 'replication_factor': 1 + }} + """ + + try: + self.session.execute(create_keyspace_cql) + self.known_keyspaces.add(keyspace) + logger.info(f"Ensured keyspace exists: {safe_keyspace}") + except Exception as e: + logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True) + raise + + def ensure_tables(self, keyspace: str): + """Ensure unified rows and row_partitions tables exist""" + if keyspace in self.tables_initialized: + return + + # Ensure keyspace exists first + self.ensure_keyspace(keyspace) + + safe_keyspace = self.sanitize_name(keyspace) + + # Create unified rows table + create_rows_cql = f""" + CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows ( + collection text, + schema_name text, + index_name text, + index_value frozen>, + data map, + source text, + PRIMARY KEY ((collection, schema_name, index_name), index_value) 
+ ) + """ + + # Create row_partitions tracking table + create_partitions_cql = f""" + CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions ( + collection text, + schema_name text, + index_name text, + PRIMARY KEY ((collection), schema_name, index_name) + ) + """ + + try: + self.session.execute(create_rows_cql) + logger.info(f"Ensured rows table exists: {safe_keyspace}.rows") + + self.session.execute(create_partitions_cql) + logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions") + + self.tables_initialized.add(keyspace) + + except Exception as e: + logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True) + raise + + def get_index_names(self, schema: RowSchema) -> List[str]: + """ + Get all index names for a schema. + Returns list of index_name strings (single field names or comma-joined composites). + """ + index_names = [] + + for field in schema.fields: + # Primary key fields are treated as indexes + if field.primary: + index_names.append(field.name) + # Indexed fields + elif field.indexed: + index_names.append(field.name) + + # TODO: Support composite indexes in the future + # For now, each indexed field is a single-field index + + return index_names + + def register_partitions(self, keyspace: str, collection: str, schema_name: str): + """ + Register partition entries for a (collection, schema_name) pair. + Called once on first row for each pair. + """ + cache_key = (collection, schema_name) + if cache_key in self.registered_partitions: + return + + schema = self.schemas.get(schema_name) + if not schema: + logger.warning(f"Cannot register partitions - schema {schema_name} not found") + return + + safe_keyspace = self.sanitize_name(keyspace) + index_names = self.get_index_names(schema) + + # Insert partition entries for each index + insert_cql = f""" + INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name) + VALUES (%s, %s, %s) + """ + + for index_name in index_names: + try: + self.session.execute(insert_cql, (collection, schema_name, index_name)) + except Exception as e: + logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}") + + self.registered_partitions.add(cache_key) + logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}") + + def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]: + """ + Build the index_value list for a given index. + For single-field indexes, returns a single-element list. + For composite indexes (comma-separated), returns multiple elements. + """ + field_names = [f.strip() for f in index_name.split(',')] + values = [] + + for field_name in field_names: + value = value_map.get(field_name) + # Convert to string for storage + values.append(str(value) if value is not None else "") + + return values + + async def on_object(self, msg, consumer, flow): + """Process incoming ExtractedObject and store in Cassandra""" + + obj = msg.value() + logger.info( + f"Storing {len(obj.values)} rows for schema {obj.schema_name} " + f"from {obj.metadata.id}" + ) + + # Validate collection exists before accepting writes + if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + error_msg = ( + f"Collection {obj.metadata.collection} does not exist. " + f"Create it first via collection management API." 
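`get_index_names` currently emits one single-field index per primary or indexed field, while `build_index_value` already understands comma-separated composite names (flagged as a TODO above). A self-contained sketch of that pairing, mirroring the patch's behaviour; the field names are hypothetical:

```python
def build_index_value(value_map, index_name):
    """Split a (possibly composite) index name and collect the row's values;
    missing fields become "" so the writer's empty-value check can skip them."""
    field_names = [f.strip() for f in index_name.split(",")]
    values = []
    for name in field_names:
        value = value_map.get(name)
        values.append(str(value) if value is not None else "")
    return values

row = {"invoice_number": "INV-2024-0001", "customer": "ACME"}
assert build_index_value(row, "invoice_number") == ["INV-2024-0001"]
# A future composite index would fan out like this:
assert build_index_value(row, "customer, invoice_number") == ["ACME", "INV-2024-0001"]
assert build_index_value(row, "missing_field") == [""]   # skipped by the empty check
```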
+ ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Get schema definition + schema = self.schemas.get(obj.schema_name) + if not schema: + logger.warning(f"No schema found for {obj.schema_name} - skipping") + return + + keyspace = obj.metadata.user + collection = obj.metadata.collection + schema_name = obj.schema_name + source = getattr(obj.metadata, 'source', '') or '' + + # Ensure tables exist + self.ensure_tables(keyspace) + + # Register partitions if first time seeing this (collection, schema_name) + self.register_partitions(keyspace, collection, schema_name) + + safe_keyspace = self.sanitize_name(keyspace) + + # Get all index names for this schema + index_names = self.get_index_names(schema) + + if not index_names: + logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable") + return + + # Prepare insert statement + insert_cql = f""" + INSERT INTO {safe_keyspace}.rows + (collection, schema_name, index_name, index_value, data, source) + VALUES (%s, %s, %s, %s, %s, %s) + """ + + # Process each row in the batch + rows_written = 0 + for row_index, value_map in enumerate(obj.values): + # Convert all values to strings for the data map + data_map = {} + for field in schema.fields: + raw_value = value_map.get(field.name) + if raw_value is not None: + data_map[field.name] = str(raw_value) + + # Write one copy per index + for index_name in index_names: + index_value = self.build_index_value(value_map, index_name) + + # Skip if index value is empty/null + if not index_value or all(v == "" for v in index_value): + logger.debug( + f"Skipping index {index_name} for row {row_index} - " + f"empty index value" + ) + continue + + try: + self.session.execute( + insert_cql, + (collection, schema_name, index_name, index_value, data_map, source) + ) + rows_written += 1 + except Exception as e: + logger.error( + f"Failed to insert row {row_index} for index {index_name}: {e}", + exc_info=True + ) + raise + + logger.info( + f"Wrote {rows_written} index entries for {len(obj.values)} rows " + f"({len(index_names)} indexes per row)" + ) + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create/verify collection exists in Cassandra row store""" + # Connect if not already connected + self.connect_cassandra() + + # Ensure tables exist + self.ensure_tables(user) + + logger.info(f"Collection {collection} ready for user {user}") + + async def delete_collection(self, user: str, collection: str): + """Delete all data for a specific collection using partition tracking""" + # Connect if not already connected + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Check if keyspace exists + if user not in self.known_keyspaces: + check_keyspace_cql = """ + SELECT keyspace_name FROM system_schema.keyspaces + WHERE keyspace_name = %s + """ + result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) + if not result.one(): + logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") + return + self.known_keyspaces.add(user) + + # Discover all partitions for this collection + select_partitions_cql = f""" + SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions + WHERE collection = %s + """ + + try: + partitions = self.session.execute(select_partitions_cql, (collection,)) + partition_list = list(partitions) + except Exception as e: + logger.error(f"Failed to query partitions for collection {collection}: {e}") + raise + + # Delete each partition from rows table + delete_rows_cql = f""" + DELETE 
FROM {safe_keyspace}.rows + WHERE collection = %s AND schema_name = %s AND index_name = %s + """ + + partitions_deleted = 0 + for partition in partition_list: + try: + self.session.execute( + delete_rows_cql, + (collection, partition.schema_name, partition.index_name) + ) + partitions_deleted += 1 + except Exception as e: + logger.error( + f"Failed to delete partition {collection}/{partition.schema_name}/" + f"{partition.index_name}: {e}" + ) + raise + + # Clean up row_partitions entries + delete_partitions_cql = f""" + DELETE FROM {safe_keyspace}.row_partitions + WHERE collection = %s + """ + + try: + self.session.execute(delete_partitions_cql, (collection,)) + except Exception as e: + logger.error(f"Failed to clean up row_partitions for {collection}: {e}") + raise + + # Clear from local cache + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } + + logger.info( + f"Deleted collection {collection}: {partitions_deleted} partitions " + f"from keyspace {safe_keyspace}" + ) + + async def delete_collection_schema(self, user: str, collection: str, schema_name: str): + """Delete all data for a specific collection + schema combination""" + # Connect if not already connected + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Discover partitions for this collection + schema + select_partitions_cql = f""" + SELECT index_name FROM {safe_keyspace}.row_partitions + WHERE collection = %s AND schema_name = %s + """ + + try: + partitions = self.session.execute(select_partitions_cql, (collection, schema_name)) + partition_list = list(partitions) + except Exception as e: + logger.error( + f"Failed to query partitions for {collection}/{schema_name}: {e}" + ) + raise + + # Delete each partition from rows table + delete_rows_cql = f""" + DELETE FROM {safe_keyspace}.rows + WHERE collection = %s AND schema_name = %s AND index_name = %s + """ + + partitions_deleted = 0 + for partition in partition_list: + try: + self.session.execute( + delete_rows_cql, + (collection, schema_name, partition.index_name) + ) + partitions_deleted += 1 + except Exception as e: + logger.error( + f"Failed to delete partition {collection}/{schema_name}/" + f"{partition.index_name}: {e}" + ) + raise + + # Clean up row_partitions entries for this schema + delete_partitions_cql = f""" + DELETE FROM {safe_keyspace}.row_partitions + WHERE collection = %s AND schema_name = %s + """ + + try: + self.session.execute(delete_partitions_cql, (collection, schema_name)) + except Exception as e: + logger.error( + f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}" + ) + raise + + # Clear from local cache + self.registered_partitions.discard((collection, schema_name)) + + logger.info( + f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " + f"from keyspace {safe_keyspace}" + ) + + def close(self): + """Clean up Cassandra connections""" + if self.cluster: + self.cluster.shutdown() + logger.info("Closed Cassandra connection") @staticmethod def add_args(parser): + """Add command-line arguments""" - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + FlowProcessor.add_args(parser) add_cassandra_args(parser) + parser.add_argument( + '--config-type', + default='schema', + help='Configuration type prefix for schemas (default: schema)' + ) + + def run(): - - Processor.launch(module, __doc__) - + """Entry point for rows-write-cassandra command""" + Processor.launch(default_ident, __doc__) diff --git 
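The deletion code above is the payoff of the `row_partitions` registry: a collection can be removed with one cheap partition-level delete per tracked `(schema_name, index_name)` pair, no full-table scan. A standalone sketch of that pattern, assuming a reachable Cassandra cluster and the `rows`/`row_partitions` tables this module creates; host and keyspace names are placeholders:

```python
from cassandra.cluster import Cluster

session = Cluster(["localhost"]).connect("demo_keyspace")  # placeholder endpoint

def delete_collection(session, collection):
    # 1. The registry tells us exactly which (schema, index) partitions exist
    parts = session.execute(
        "SELECT schema_name, index_name FROM row_partitions WHERE collection = %s",
        (collection,),
    )
    # 2. One partition-level delete per entry - each is a single range tombstone
    for p in parts:
        session.execute(
            "DELETE FROM rows WHERE collection = %s AND schema_name = %s "
            "AND index_name = %s",
            (collection, p.schema_name, p.index_name),
        )
    # 3. Finally drop the registry entries themselves
    session.execute("DELETE FROM row_partitions WHERE collection = %s", (collection,))
```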
a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index b9b42375..5bc842de 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -10,11 +10,14 @@ import argparse import time import logging -from .... direct.cassandra_kg import KnowledgeGraph +from .... direct.cassandra_kg import ( + EntityCentricKnowledgeGraph, DEFAULT_GRAPH +) from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... schema import IRI, LITERAL, BLANK, TRIPLE # Module logger logger = logging.getLogger(__name__) @@ -22,6 +25,59 @@ logger = logging.getLogger(__name__) default_ident = "triples-write" +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + +def get_term_otype(term): + """ + Get object type code from a Term for entity-centric storage. + + Maps Term.type to otype: + - IRI ("i") → "u" (URI) + - BLANK ("b") → "u" (treated as URI) + - LITERAL ("l") → "l" (Literal) + - TRIPLE ("t") → "t" (Triple/reification) + """ + if term is None: + return "u" + if term.type == IRI or term.type == BLANK: + return "u" + elif term.type == LITERAL: + return "l" + elif term.type == TRIPLE: + return "t" + else: + return "u" + + +def get_term_dtype(term): + """Extract datatype from a Term (for literals)""" + if term is None: + return "" + if term.type == LITERAL: + return term.datatype or "" + return "" + + +def get_term_lang(term): + """Extract language tag from a Term (for literals)""" + if term is None: + return "" + if term.type == LITERAL: + return term.language or "" + return "" + + class Processor(CollectionConfigHandler, TriplesStoreService): def __init__(self, **params): @@ -64,15 +120,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.tg = None + # Use factory function to select implementation + KGClass = EntityCentricKnowledgeGraph + try: if self.cassandra_username and self.cassandra_password: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=message.metadata.user, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=message.metadata.user, ) @@ -84,11 +143,27 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.table = user for t in message.triples: + # Extract values from Term objects + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + # t.g is None for default graph, or a graph IRI + g_val = t.g if t.g is not None else DEFAULT_GRAPH + + # Extract object type metadata for entity-centric storage + otype = get_term_otype(t.o) + dtype = get_term_dtype(t.o) + lang = get_term_lang(t.o) + self.tg.insert( message.metadata.collection, - t.s.value, - t.p.value, - t.o.value + s_val, + p_val, + o_val, + g=g_val, + otype=otype, + dtype=dtype, + lang=lang ) async def create_collection(self, user: str, collection: str, metadata: dict): @@ -98,16 +173,19 @@ class Processor(CollectionConfigHandler, 
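A quick behaviour check for the Term-extraction helpers introduced above. The real `Term` and type constants live in `trustgraph.schema`; this stand-in mirrors only the fields the helpers read, using the type codes documented in `get_term_otype`'s docstring:

```python
from dataclasses import dataclass

IRI, BLANK, LITERAL, TRIPLE = "i", "b", "l", "t"  # codes per the docstring

@dataclass
class Term:  # stand-in for trustgraph.schema.Term
    type: str
    iri: str = None
    value: str = None
    id: str = None
    datatype: str = None
    language: str = None

def get_term_value(term):
    if term is None:
        return None
    if term.type == IRI:
        return term.iri
    if term.type == LITERAL:
        return term.value
    return term.id or term.value  # blank nodes and other types

alice = Term(type=IRI, iri="http://example.org/alice")
name = Term(type=LITERAL, value="Alice", datatype="xsd:string", language="en")
assert get_term_value(alice) == "http://example.org/alice"
assert get_term_value(name) == "Alice"
```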
TriplesStoreService): if self.table is None or self.table != user: self.tg = None + # Use factory function to select implementation + KGClass = EntityCentricKnowledgeGraph + try: if self.cassandra_username and self.cassandra_password: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=user, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=user, ) @@ -137,16 +215,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.table is None or self.table != user: self.tg = None + # Use factory function to select implementation + KGClass = EntityCentricKnowledgeGraph + try: if self.cassandra_username and self.cassandra_password: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=user, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = KnowledgeGraph( + self.tg = KGClass( hosts=self.cassandra_host, keyspace=user, ) diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index f08eeb91..210ea53d 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -15,12 +15,27 @@ from falkordb import FalkorDB from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) default_ident = "triples-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + default_graph_url = 'falkor://falkordb:6379' default_database = 'falkordb' @@ -164,14 +179,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): for t in message.triples: - self.create_node(t.s.value, user, collection) + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) - if t.o.is_uri: - self.create_node(t.o.value, user, collection) - self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) + self.create_node(s_val, user, collection) + + if t.o.type == IRI: + self.create_node(o_val, user, collection) + self.relate_node(s_val, p_val, o_val, user, collection) else: - self.create_literal(t.o.value, user, collection) - self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) + self.create_literal(o_val, user, collection) + self.relate_literal(s_val, p_val, o_val, user, collection) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 8105b14e..55d4dee1 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -15,12 +15,27 @@ from neo4j import GraphDatabase from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... 
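Each handler above repeats the same per-keyspace client swap: one live graph client is kept, and rebuilt whenever the target user (keyspace) changes. The pattern in isolation — `GraphClient` is a stand-in for `EntityCentricKnowledgeGraph`, not the real class:

```python
class GraphClient:  # stand-in; the patch constructs EntityCentricKnowledgeGraph
    def __init__(self, hosts, keyspace):
        self.hosts, self.keyspace = hosts, keyspace

class Writer:
    def __init__(self, hosts):
        self.hosts = hosts
        self.table = None   # same role as self.table in the patch
        self.tg = None

    def client_for(self, user):
        # Rebuild only when there is no client or the keyspace changed
        if self.tg is None or self.table != user:
            self.tg = GraphClient(hosts=self.hosts, keyspace=user)
            self.table = user
        return self.tg

w = Writer(["cassandra"])
assert w.client_for("alice") is w.client_for("alice")   # reused
assert w.client_for("bob").keyspace == "bob"            # rebuilt on user change
```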
schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) default_ident = "triples-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + default_graph_host = 'bolt://memgraph:7687' default_username = 'memgraph' default_password = 'password' @@ -204,40 +219,44 @@ class Processor(CollectionConfigHandler, TriplesStoreService): def create_triple(self, tx, t, user, collection): + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + # Create new s node with given uri, if not exists result = tx.run( "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=t.s.value, user=user, collection=collection + uri=s_val, user=user, collection=collection ) - if t.o.is_uri: + if t.o.type == IRI: # Create new o node with given uri, if not exists result = tx.run( "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=t.o.value, user=user, collection=collection + uri=o_val, user=user, collection=collection ) result = tx.run( "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection, + src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, ) else: - + # Create new o literal with given uri, if not exists result = tx.run( "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=t.o.value, user=user, collection=collection + value=o_val, user=user, collection=collection ) result = tx.run( "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection, + src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, ) async def store_triples(self, message): @@ -257,14 +276,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): for t in message.triples: - self.create_node(t.s.value, user, collection) + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) - if t.o.is_uri: - self.create_node(t.o.value, user, collection) - self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) + self.create_node(s_val, user, collection) + + if t.o.type == IRI: + self.create_node(o_val, user, collection) + self.relate_node(s_val, p_val, o_val, user, collection) else: - self.create_literal(t.o.value, user, collection) - self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) + self.create_literal(o_val, user, collection) + self.relate_literal(s_val, p_val, o_val, user, collection) # Alternative implementation using transactions # with self.io.session(database=self.db) as session: diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index e33b26ca..4a85a273 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -14,12 +14,27 @@ from neo4j import 
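The `create_triple` transaction function above relies on `MERGE`, so replaying the same triple is idempotent. A usage sketch for driving it, assuming a reachable Memgraph and a transaction function with `create_triple`'s signature (here passed in unbound for brevity); `execute_write` is the neo4j 5.x driver API, older drivers spell it `write_transaction`:

```python
from neo4j import GraphDatabase

def write_batch(triples, user, collection, create_triple_fn):
    driver = GraphDatabase.driver(
        "bolt://memgraph:7687", auth=("memgraph", "password")  # placeholder creds
    )
    with driver.session() as session:
        for t in triples:
            # MERGE semantics inside the function make replays harmless
            session.execute_write(create_triple_fn, t, user, collection)
    driver.close()
```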
GraphDatabase from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import IRI, LITERAL # Module logger logger = logging.getLogger(__name__) default_ident = "triples-write" + +def get_term_value(term): + """Extract the string value from a Term""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + else: + # For blank nodes or other types, use id or value + return term.id or term.value + + default_graph_host = 'bolt://neo4j:7687' default_username = 'neo4j' default_password = 'password' @@ -212,14 +227,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): for t in message.triples: - self.create_node(t.s.value, user, collection) + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) - if t.o.is_uri: - self.create_node(t.o.value, user, collection) - self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) + self.create_node(s_val, user, collection) + + if t.o.type == IRI: + self.create_node(o_val, user, collection) + self.relate_node(s_val, p_val, o_val, user, collection) else: - self.create_literal(t.o.value, user, collection) - self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) + self.create_literal(o_val, user, collection) + self.relate_literal(s_val, p_val, o_val, user, collection) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py index f98929e1..fb9ea0a7 100644 --- a/trustgraph-flow/trustgraph/tables/config.py +++ b/trustgraph-flow/trustgraph/tables/config.py @@ -1,6 +1,6 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings -from .. schema import Metadata, Value, GraphEmbeddings +from .. schema import Metadata, GraphEmbeddings from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 1ee61088..6ea16499 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -1,8 +1,24 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings -from .. schema import Metadata, Value, GraphEmbeddings +from .. 
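The falkordb, memgraph, and neo4j writers now all share one object-type dispatch: URI objects become nodes with node-to-node edges, literals become literal vertices with node-to-literal edges. Factored out with a recording stub so the branch is easy to see; writer method names mirror the patch's, the stub is illustrative:

```python
IRI = "i"  # type code, per the convention noted earlier

class RecordingWriter:
    def __init__(self):
        self.calls = []
    def __getattr__(self, name):
        # Record any writer call as (method_name, args)
        return lambda *args: self.calls.append((name, args))

def store_triple(s_val, p_val, o_val, o_type, w):
    w.create_node(s_val)
    if o_type == IRI:            # URI object: node plus node-to-node edge
        w.create_node(o_val)
        w.relate_node(s_val, p_val, o_val)
    else:                        # literal object: literal plus node-to-literal edge
        w.create_literal(o_val)
        w.relate_literal(s_val, p_val, o_val)

w = RecordingWriter()
store_triple("ex:alice", "ex:name", "Alice", "l", w)
assert [c[0] for c in w.calls] == ["create_node", "create_literal", "relate_literal"]
```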
schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings from cassandra.cluster import Cluster + + +def term_to_tuple(term): + """Convert Term to (value, is_uri) tuple for database storage.""" + if term.type == IRI: + return (term.iri, True) + else: # LITERAL + return (term.value, False) + + +def tuple_to_term(value, is_uri): + """Convert (value, is_uri) tuple from database to Term.""" + if is_uri: + return Term(type=IRI, iri=value) + else: + return Term(type=LITERAL, value=value) from cassandra.auth import PlainTextAuthProvider from ssl import SSLContext, PROTOCOL_TLSv1_2 @@ -205,8 +221,7 @@ class KnowledgeTableStore: if m.metadata.metadata: metadata = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in m.metadata.metadata ] @@ -215,8 +230,7 @@ class KnowledgeTableStore: triples = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in m.triples ] @@ -248,8 +262,7 @@ class KnowledgeTableStore: if m.metadata.metadata: metadata = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in m.metadata.metadata ] @@ -258,7 +271,7 @@ class KnowledgeTableStore: entities = [ ( - (v.entity.value, v.entity.is_uri), + term_to_tuple(v.entity), v.vectors ) for v in m.entities @@ -291,8 +304,7 @@ class KnowledgeTableStore: if m.metadata.metadata: metadata = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in m.metadata.metadata ] @@ -414,23 +426,26 @@ class KnowledgeTableStore: if row[2]: metadata = [ Triple( - s = Value(value = elt[0], is_uri = elt[1]), - p = Value(value = elt[2], is_uri = elt[3]), - o = Value(value = elt[4], is_uri = elt[5]), + s = tuple_to_term(elt[0], elt[1]), + p = tuple_to_term(elt[2], elt[3]), + o = tuple_to_term(elt[4], elt[5]), ) for elt in row[2] ] else: metadata = [] - triples = [ - Triple( - s = Value(value = elt[0], is_uri = elt[1]), - p = Value(value = elt[2], is_uri = elt[3]), - o = Value(value = elt[4], is_uri = elt[5]), - ) - for elt in row[3] - ] + if row[3]: + triples = [ + Triple( + s = tuple_to_term(elt[0], elt[1]), + p = tuple_to_term(elt[2], elt[3]), + o = tuple_to_term(elt[4], elt[5]), + ) + for elt in row[3] + ] + else: + triples = [] await receiver( Triples( @@ -470,22 +485,25 @@ class KnowledgeTableStore: if row[2]: metadata = [ Triple( - s = Value(value = elt[0], is_uri = elt[1]), - p = Value(value = elt[2], is_uri = elt[3]), - o = Value(value = elt[4], is_uri = elt[5]), + s = tuple_to_term(elt[0], elt[1]), + p = tuple_to_term(elt[2], elt[3]), + o = tuple_to_term(elt[4], elt[5]), ) for elt in row[2] ] else: metadata = [] - entities = [ - EntityEmbeddings( - entity = Value(value = ent[0][0], is_uri = ent[0][1]), - vectors = ent[1] - ) - for ent in row[3] - ] + if row[3]: + entities = [ + EntityEmbeddings( + entity = tuple_to_term(ent[0][0], ent[0][1]), + vectors = ent[1] + ) + for ent in row[3] + ] + else: + entities = [] await receiver( GraphEmbeddings( diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index 0a7c6081..8bbe2bad 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -1,8 +1,24 @@ from .. schema import LibrarianRequest, LibrarianResponse from .. 
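A round-trip check for the `(value, is_uri)` encoding above, worth spelling out: only the lexical value and the IRI-vs-literal distinction survive, so a literal's datatype and language tag come back empty. Minimal stand-in types, illustrative only:

```python
IRI, LITERAL = "i", "l"

class Term:  # stand-in with just the fields the helpers touch
    def __init__(self, type, iri=None, value=None, datatype=None, language=None):
        self.type, self.iri, self.value = type, iri, value
        self.datatype, self.language = datatype, language

def term_to_tuple(term):
    if term.type == IRI:
        return (term.iri, True)
    return (term.value, False)  # LITERAL

def tuple_to_term(value, is_uri):
    if is_uri:
        return Term(type=IRI, iri=value)
    return Term(type=LITERAL, value=value)

t = Term(type=LITERAL, value="42", datatype="xsd:integer")
back = tuple_to_term(*term_to_tuple(t))
assert back.value == "42" and back.type == LITERAL
assert back.datatype is None   # datatype does not survive the round trip
```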
schema import DocumentMetadata, ProcessingMetadata -from .. schema import Error, Triple, Value +from .. schema import Error, Triple, Term, IRI, LITERAL from .. knowledge import hash + + +def term_to_tuple(term): + """Convert Term to (value, is_uri) tuple for database storage.""" + if term.type == IRI: + return (term.iri, True) + else: # LITERAL + return (term.value, False) + + +def tuple_to_term(value, is_uri): + """Convert (value, is_uri) tuple from database to Term.""" + if is_uri: + return Term(type=IRI, iri=value) + else: + return Term(type=LITERAL, value=value) from .. exceptions import RequestError from cassandra.cluster import Cluster @@ -215,8 +231,7 @@ class LibraryTableStore: metadata = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in document.metadata ] @@ -249,8 +264,7 @@ class LibraryTableStore: metadata = [ ( - v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, - v.o.value, v.o.is_uri + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) ) for v in document.metadata ] @@ -331,9 +345,9 @@ class LibraryTableStore: comments = row[4], metadata = [ Triple( - s=Value(value=m[0], is_uri=m[1]), - p=Value(value=m[2], is_uri=m[3]), - o=Value(value=m[4], is_uri=m[5]) + s=tuple_to_term(m[0], m[1]), + p=tuple_to_term(m[2], m[3]), + o=tuple_to_term(m[4], m[5]) ) for m in row[5] ], @@ -376,9 +390,9 @@ class LibraryTableStore: comments = row[3], metadata = [ Triple( - s=Value(value=m[0], is_uri=m[1]), - p=Value(value=m[2], is_uri=m[3]), - o=Value(value=m[4], is_uri=m[5]) + s=tuple_to_term(m[0], m[1]), + p=tuple_to_term(m[2], m[3]), + o=tuple_to_term(m[4], m[5]) ) for m in row[4] ], diff --git a/trustgraph-flow/trustgraph/template/prompt_manager.py b/trustgraph-flow/trustgraph/template/prompt_manager.py index 9364cf21..546a7faf 100644 --- a/trustgraph-flow/trustgraph/template/prompt_manager.py +++ b/trustgraph-flow/trustgraph/template/prompt_manager.py @@ -83,7 +83,7 @@ class PromptManager: def parse_json(self, text): json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL) - + if json_match: json_str = json_match.group(1).strip() else: @@ -92,6 +92,43 @@ class PromptManager: return json.loads(json_str) + def parse_jsonl(self, text): + """ + Parse JSONL response, returning list of valid objects. + + Invalid lines (malformed JSON, empty lines) are skipped with warnings. + This provides truncation resilience - partial output yields partial results. 
+ """ + results = [] + + # Strip markdown code fences if present + text = text.strip() + if text.startswith('```'): + # Remove opening fence (possibly with language hint) + text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text) + if text.endswith('```'): + text = text[:-3] + + for line_num, line in enumerate(text.strip().split('\n'), 1): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Skip any remaining fence markers + if line.startswith('```'): + continue + + try: + obj = json.loads(line) + results.append(obj) + except json.JSONDecodeError as e: + # Log warning but continue - this provides truncation resilience + logger.warning(f"JSONL parse error on line {line_num}: {e}") + + return results + def render(self, id, input): if id not in self.prompts: @@ -121,21 +158,41 @@ class PromptManager: if resp_type == "text": return resp - if resp_type != "json": - raise RuntimeError(f"Response type {resp_type} not known") - - try: - obj = self.parse_json(resp) - except: - logger.error(f"JSON parse failed: {resp}") - raise RuntimeError("JSON parse fail") - - if self.prompts[id].schema: + if resp_type == "json": try: - validate(instance=obj, schema=self.prompts[id].schema) - logger.debug("Schema validation successful") - except Exception as e: - raise RuntimeError(f"Schema validation fail: {e}") + obj = self.parse_json(resp) + except: + logger.error(f"JSON parse failed: {resp}") + raise RuntimeError("JSON parse fail") - return obj + if self.prompts[id].schema: + try: + validate(instance=obj, schema=self.prompts[id].schema) + logger.debug("Schema validation successful") + except Exception as e: + raise RuntimeError(f"Schema validation fail: {e}") + + return obj + + if resp_type == "jsonl": + objects = self.parse_jsonl(resp) + + if not objects: + logger.warning("JSONL parse returned no valid objects") + return [] + + # Validate each object against schema if provided + if self.prompts[id].schema: + validated = [] + for i, obj in enumerate(objects): + try: + validate(instance=obj, schema=self.prompts[id].schema) + validated.append(obj) + except Exception as e: + logger.warning(f"Object {i} failed schema validation: {e}") + return validated + + return objects + + raise RuntimeError(f"Response type {resp_type} not known") diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index 1068d91e..d089180a 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index a96f8338..48f92207 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,9 +10,10 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.8,<1.9", + "trustgraph-base>=2.0,<2.1", "pulsar-client", - "google-cloud-aiplatform", + "google-genai", + "google-api-core", "prometheus-client", "anthropic", ] @@ -25,6 +26,7 @@ classifiers = [ Homepage = "https://github.com/trustgraph-ai/trustgraph" [project.scripts] +text-completion-googleaistudio = "trustgraph.model.text_completion.googleaistudio:run" text-completion-vertexai = "trustgraph.model.text_completion.vertexai:run" 
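The truncation-resilience claim in `parse_jsonl`'s docstring is easy to demonstrate: a response cut off mid-line still yields every complete object. A condensed copy of the parse loop, for illustration:

```python
import json

def parse_jsonl_min(text):
    results = []
    for line in text.strip().splitlines():
        line = line.strip()
        if not line or line.startswith("```"):
            continue  # empty lines and stray fences are skipped
        try:
            results.append(json.loads(line))
        except json.JSONDecodeError:
            pass  # the real implementation logs a warning here
    return results

truncated = '{"name": "Alice"}\n{"name": "Bob"}\n{"name": "Car'
assert parse_jsonl_min(truncated) == [{"name": "Alice"}, {"name": "Bob"}]
```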
[tool.setuptools.packages.find] diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__init__.py b/trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__init__.py rename to trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/__init__.py diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__main__.py b/trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/__main__.py similarity index 100% rename from trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__main__.py rename to trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/__main__.py diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/llm.py similarity index 90% rename from trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py rename to trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/llm.py index 1e9160ed..d241d6f2 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py +++ b/trustgraph-vertexai/trustgraph/model/text_completion/googleaistudio/llm.py @@ -15,6 +15,7 @@ Input is prompt, output is response. from google import genai from google.genai import types from google.genai.types import HarmCategory, HarmBlockThreshold +from google.genai.errors import ClientError from google.api_core.exceptions import ResourceExhausted import os import logging @@ -52,7 +53,7 @@ class Processor(LlmService): } ) - self.client = genai.Client(api_key=api_key) + self.client = genai.Client(api_key=api_key, vertexai=False) self.default_model = model self.temperature = temperature self.max_output = max_output @@ -152,6 +153,15 @@ class Processor(LlmService): # Leave rate limit retries to the default handler raise TooManyRequests() + except ClientError as e: + # google-genai SDK throws ClientError for 4xx errors + if e.code == 429: + logger.warning(f"Rate limit exceeded (ClientError 429): {e}") + raise TooManyRequests() + # Other client errors are unrecoverable + logger.error(f"GoogleAIStudio ClientError: {e}", exc_info=True) + raise e + except Exception as e: # Apart from rate limits, treat all exceptions as unrecoverable @@ -216,6 +226,15 @@ class Processor(LlmService): logger.warning("Rate limit exceeded during streaming") raise TooManyRequests() + except ClientError as e: + # google-genai SDK throws ClientError for 4xx errors + if e.code == 429: + logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}") + raise TooManyRequests() + # Other client errors are unrecoverable + logger.error(f"GoogleAIStudio streaming ClientError: {e}", exc_info=True) + raise e + except Exception as e: logger.error(f"GoogleAIStudio streaming exception ({type(e).__name__}): {e}", exc_info=True) raise e diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py index 5cf17b4d..d7a7dd2a 100755 --- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py +++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py @@ -4,29 +4,20 @@ Google Cloud. Input is prompt, output is response. Supports both Google's Gemini models and Anthropic's Claude models. """ -# -# Somewhat perplexed by the Google Cloud SDK choices. 
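The `ClientError` handlers added above hinge on one classification rule: google-genai raises `ClientError` for every 4xx response and exposes the HTTP status as `.code`, and only 429 should be converted into the retryable `TooManyRequests`. The rule in isolation, with stand-in exception types so it runs without the SDK:

```python
class TooManyRequests(Exception):
    """Stand-in for trustgraph's retryable rate-limit signal."""

def reraise_client_error(e):
    if getattr(e, "code", None) == 429:
        raise TooManyRequests() from e
    raise e   # any other 4xx is unrecoverable

class FakeClientError(Exception):  # stand-in for google.genai.errors.ClientError
    def __init__(self, code):
        self.code = code

try:
    reraise_client_error(FakeClientError(429))
except TooManyRequests:
    print("would be retried by the base handler")
```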
We're going off this -# one, which uses the google-cloud-aiplatform library: -# https://cloud.google.com/python/docs/reference/vertexai/1.94.0 -# It seems it is possible to invoke VertexAI from the google-genai -# SDK too: -# https://googleapis.github.io/python-genai/genai.html#module-genai.client -# That would make this code look very much like the GoogleAIStudio -# code. And maybe not reliant on the google-cloud-aiplatform library? # -# This module's imports bring in a lot of libraries. +# Uses the google-genai SDK for Gemini models on Vertex AI: +# https://googleapis.github.io/python-genai/genai.html#module-genai.client +# from google.oauth2 import service_account import google.auth -import google.api_core.exceptions -import vertexai import logging -# Why is preview here? -from vertexai.generative_models import ( - Content, FunctionDeclaration, GenerativeModel, GenerationConfig, - HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting, -) +from google import genai +from google.genai import types +from google.genai.types import HarmCategory, HarmBlockThreshold +from google.genai.errors import ClientError +from google.api_core.exceptions import ResourceExhausted # Added for Anthropic model support from anthropic import AnthropicVertex, RateLimitError @@ -67,12 +58,10 @@ class Processor(LlmService): self.max_output = max_output self.private_key = private_key - # Model client caches - self.model_clients = {} # Cache for model instances - self.generation_configs = {} # Cache for generation configs (Gemini only) - self.anthropic_client = None # Single Anthropic client (handles multiple models) + # Anthropic client (handles Claude models) + self.anthropic_client = None - # Shared parameters for both model types + # Shared parameters for Anthropic models self.api_params = { "temperature": temperature, "top_p": 1.0, @@ -84,10 +73,10 @@ class Processor(LlmService): # Unified credential and project ID loading if private_key: - credentials = ( - service_account.Credentials.from_service_account_file( - private_key - ) + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + credentials = service_account.Credentials.from_service_account_file( + private_key, + scopes=scopes ) project_id = credentials.project_id else: @@ -103,12 +92,13 @@ class Processor(LlmService): self.credentials = credentials self.project_id = project_id - # Initialize Vertex AI SDK for Gemini models - init_kwargs = {'location': region, 'project': project_id} - if credentials and private_key: # Pass credentials only if from a file - init_kwargs['credentials'] = credentials - - vertexai.init(**init_kwargs) + # Initialize Google GenAI client for Gemini models + self.client = genai.Client( + vertexai=True, + project=project_id, + location=region, + credentials=credentials + ) # Pre-initialize Anthropic client if needed (single client handles all Claude models) if 'claude' in self.default_model.lower(): @@ -117,24 +107,27 @@ class Processor(LlmService): # Safety settings for Gemini models block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH self.safety_settings = [ - SafetySetting( - category = HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold = block_level, + types.SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=block_level, ), - SafetySetting( - category = HarmCategory.HARM_CATEGORY_HATE_SPEECH, - threshold = block_level, + types.SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=block_level, ), - SafetySetting( - category = 
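With this migration, the Vertex AI and AI Studio modules differ mainly in how the shared `genai.Client` is constructed. A minimal construction sketch for the two modes the patch ends up with, assuming google-genai is installed; the key-file path, location, and API key are placeholders:

```python
from google import genai
from google.oauth2 import service_account

# Vertex AI mode: scoped service-account credentials, project taken from the key
creds = service_account.Credentials.from_service_account_file(
    "sa.json",  # placeholder path
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
vertex_client = genai.Client(
    vertexai=True,
    project=creds.project_id,
    location="us-central1",
    credentials=creds,
)

# AI Studio mode: plain API key, vertexai explicitly off
studio_client = genai.Client(api_key="YOUR_KEY", vertexai=False)
```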
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold = block_level, + types.SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=block_level, ), - SafetySetting( - category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold = block_level, + types.SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=block_level, ), ] + # Cache for generation configs + self.generation_configs = {} + logger.info("VertexAI initialization complete") def _get_anthropic_client(self): @@ -152,25 +145,26 @@ class Processor(LlmService): return self.anthropic_client - def _get_gemini_model(self, model_name, temperature=None): - """Get or create a Gemini model instance""" - if model_name not in self.model_clients: - logger.info(f"Creating GenerativeModel instance for '{model_name}'") - self.model_clients[model_name] = GenerativeModel(model_name) - + def _get_or_create_config(self, model_name, temperature=None): + """Get or create generation config with dynamic temperature""" # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature - # Create generation config with the effective temperature - generation_config = GenerationConfig( - temperature=effective_temperature, - top_p=1.0, - top_k=10, - candidate_count=1, - max_output_tokens=self.max_output, - ) + # Create cache key that includes temperature to avoid conflicts + cache_key = f"{model_name}:{effective_temperature}" - return self.model_clients[model_name], generation_config + if cache_key not in self.generation_configs: + logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}") + self.generation_configs[cache_key] = types.GenerateContentConfig( + temperature=effective_temperature, + top_p=1.0, + top_k=40, + max_output_tokens=self.max_output, + response_mime_type="text/plain", + safety_settings=self.safety_settings, + ) + + return self.generation_configs[cache_key] async def generate_content(self, system, prompt, model=None, temperature=None): @@ -205,22 +199,24 @@ class Processor(LlmService): model=model_name ) else: - # Gemini API combines system and user prompts + # Gemini API using google-genai SDK logger.debug(f"Sending request to Gemini model '{model_name}'...") - full_prompt = system + "\n\n" + prompt - llm, generation_config = self._get_gemini_model(model_name, effective_temperature) + generation_config = self._get_or_create_config(model_name, effective_temperature) + # Set system instruction per request (can't be cached) + generation_config.system_instruction = system - response = llm.generate_content( - full_prompt, generation_config = generation_config, - safety_settings = self.safety_settings, + response = self.client.models.generate_content( + model=model_name, + config=generation_config, + contents=prompt, ) resp = LlmResult( - text = response.text, - in_token = response.usage_metadata.prompt_token_count, - out_token = response.usage_metadata.candidates_token_count, - model = model_name + text=response.text, + in_token=int(response.usage_metadata.prompt_token_count), + out_token=int(response.usage_metadata.candidates_token_count), + model=model_name ) logger.info(f"Input Tokens: {resp.in_token}") @@ -229,11 +225,20 @@ class Processor(LlmService): return resp - except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e: + except (ResourceExhausted, RateLimitError) as e: logger.warning(f"Hit rate limit: {e}") # Leave rate limit retries to the 
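The cache key in `_get_or_create_config` carries everything that varies between requests except the system instruction. The idea reduced to essentials, with plain dicts standing in for `types.GenerateContentConfig`:

```python
_configs = {}

def config_for(model, temperature, default_temp=0.0):
    # Temperature varies per request, so it must be part of the cache key
    t = temperature if temperature is not None else default_temp
    key = (model, t)
    if key not in _configs:
        _configs[key] = {"temperature": t, "top_p": 1.0, "top_k": 40}
    return _configs[key]

assert config_for("gemini-2.0-flash", 0.7) is config_for("gemini-2.0-flash", 0.7)
assert config_for("gemini-2.0-flash", 0.7) is not config_for("gemini-2.0-flash", 0.2)
```

By contrast, `system_instruction` is written onto the cached config object on each call, which is fine for a single-threaded worker but worth keeping in mind if requests ever interleave on one processor.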
base handler raise TooManyRequests() + except ClientError as e: + # google-genai SDK throws ClientError for 4xx errors + if e.code == 429: + logger.warning(f"Hit rate limit (ClientError 429): {e}") + raise TooManyRequests() + # Other client errors are unrecoverable + logger.error(f"VertexAI ClientError: {e}", exc_info=True) + raise e + except Exception as e: # Apart from rate limits, treat all exceptions as unrecoverable logger.error(f"VertexAI LLM exception: {e}", exc_info=True) @@ -302,17 +307,16 @@ class Processor(LlmService): logger.info(f"Output Tokens: {total_out_tokens}") else: - # Gemini streaming + # Gemini streaming using google-genai SDK logger.debug(f"Streaming request to Gemini model '{model_name}'...") - full_prompt = system + "\n\n" + prompt - llm, generation_config = self._get_gemini_model(model_name, effective_temperature) + generation_config = self._get_or_create_config(model_name, effective_temperature) + generation_config.system_instruction = system - response = llm.generate_content( - full_prompt, - generation_config=generation_config, - safety_settings=self.safety_settings, - stream=True # Enable streaming + response = self.client.models.generate_content_stream( + model=model_name, + config=generation_config, + contents=prompt, ) total_in_tokens = 0 @@ -348,10 +352,19 @@ class Processor(LlmService): logger.info(f"Input Tokens: {total_in_tokens}") logger.info(f"Output Tokens: {total_out_tokens}") - except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e: + except (ResourceExhausted, RateLimitError) as e: logger.warning(f"Hit rate limit during streaming: {e}") raise TooManyRequests() + except ClientError as e: + # google-genai SDK throws ClientError for 4xx errors + if e.code == 429: + logger.warning(f"Hit rate limit during streaming (ClientError 429): {e}") + raise TooManyRequests() + # Other client errors are unrecoverable + logger.error(f"VertexAI streaming ClientError: {e}", exc_info=True) + raise e + except Exception as e: logger.error(f"VertexAI streaming exception: {e}", exc_info=True) raise e
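Finally, the streaming path now goes through `generate_content_stream` rather than a `stream=True` flag. A usage sketch reusing `vertex_client` from the construction example above, assuming google-genai; the model name is a placeholder:

```python
from google.genai import types

config = types.GenerateContentConfig(
    temperature=0.2,
    system_instruction="Be terse.",  # set per request, as in the patch
)
stream = vertex_client.models.generate_content_stream(
    model="gemini-2.0-flash",  # placeholder model name
    config=config,
    contents="Say hello.",
)
for chunk in stream:
    if chunk.text:  # some chunks carry only metadata
        print(chunk.text, end="", flush=True)
```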