From e81418c58f4e430c17226853257501d9a04299ac Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 8 Apr 2026 13:37:02 +0100 Subject: [PATCH] fix: preserve literal types in focus quoted triples and document tracing (#769) The triples client returns Uri/Literal (str subclasses), not Term objects. _quoted_triple() treated all values as IRIs, so literal objects like skos:definition values were mistyped in focus provenance events, and trace_source_documents could not match them in the store. Added to_term() to convert Uri/Literal back to Term, threaded a term_map from follow_edges_batch through get_subgraph/get_labelgraph into uri_map, and updated _quoted_triple to accept Term objects directly. --- tests/unit/test_retrieval/test_graph_rag.py | 15 +++-- .../trustgraph/provenance/triples.py | 13 +++- .../retrieval/graph_rag/graph_rag.py | 65 ++++++++++++++----- 3 files changed, 68 insertions(+), 25 deletions(-) diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 597d3366..00d8b72a 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -465,12 +465,15 @@ class TestQuery: return_value=(["entity1", "entity2"], ["concept1"]) ) - query.follow_edges_batch = AsyncMock(return_value={ - ("entity1", "predicate1", "object1"), - ("entity2", "predicate2", "object2") - }) + query.follow_edges_batch = AsyncMock(return_value=( + { + ("entity1", "predicate1", "object1"), + ("entity2", "predicate2", "object2") + }, + {} + )) - subgraph, entities, concepts = await query.get_subgraph("test query") + subgraph, term_map, entities, concepts = await query.get_subgraph("test query") query.get_entities.assert_called_once_with("test query") query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) @@ -503,7 +506,7 @@ class TestQuery: test_entities = ["entity1", "entity3"] test_concepts = ["concept1"] query.get_subgraph = AsyncMock( - return_value=(test_subgraph, test_entities, test_concepts) + return_value=(test_subgraph, {}, test_entities, test_concepts) ) async def mock_maybe_label(entity): diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index f2e85eff..920a3482 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -465,11 +465,18 @@ def exploration_triples( return triples -def _quoted_triple(s: str, p: str, o: str) -> Term: - """Create a quoted triple term (RDF-star) from string values.""" +def _quoted_triple(s, p, o) -> Term: + """Create a quoted triple term (RDF-star). + + Accepts either Term objects (preserving original types) or plain + strings (treated as IRIs for backward compatibility). + """ + s_term = s if isinstance(s, Term) else _iri(s) + p_term = p if isinstance(p, Term) else _iri(p) + o_term = o if isinstance(o, Term) else _iri(o) return Term( type=TRIPLE, - triple=Triple(s=_iri(s), p=_iri(p), o=_iri(o)) + triple=Triple(s=s_term, p=p_term, o=o_term) ) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 704613c6..5cf7b991 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -10,6 +10,7 @@ from collections import OrderedDict from datetime import datetime from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE +from ... knowledge import Uri, Literal # Provenance imports from trustgraph.provenance import ( @@ -46,6 +47,26 @@ def term_to_string(term): return term.iri or term.value or str(term) +def to_term(val): + """Convert a Uri, Literal, or string to a schema Term. + + The triples client returns Uri/Literal (str subclasses) rather than + Term objects. This converts them back so provenance quoted triples + preserve the correct type. + """ + if isinstance(val, Term): + return val + if isinstance(val, Uri): + return Term(type=IRI, iri=str(val)) + if isinstance(val, Literal): + return Term(type=LITERAL, value=str(val)) + # Fallback: treat as IRI if it looks like one, otherwise literal + s = str(val) + if s.startswith(("http://", "https://", "urn:")): + return Term(type=IRI, iri=s) + return Term(type=LITERAL, value=s) + + def edge_id(s, p, o): """Generate an 8-character hash ID for an edge (s, p, o).""" edge_str = f"{s}|{p}|{o}" @@ -258,10 +279,18 @@ class Query: return all_triples async def follow_edges_batch(self, entities, max_depth): - """Optimized iterative graph traversal with batching""" + """Optimized iterative graph traversal with batching. + + Returns: + tuple: (subgraph, term_map) where subgraph is a set of + (str, str, str) tuples and term_map maps each string tuple + to its original (Term, Term, Term) for type-preserving + provenance. + """ visited = set() current_level = set(entities) subgraph = set() + term_map = {} # (str, str, str) -> (Term, Term, Term) for depth in range(max_depth): if not current_level or len(subgraph) >= self.max_subgraph_size: @@ -282,6 +311,7 @@ class Query: for triple in triples: triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) subgraph.add(triple_tuple) + term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o)) # Collect entities for next level (only from s and o positions) if depth < max_depth - 1: # Don't collect for final depth @@ -293,13 +323,13 @@ class Query: # Stop if subgraph size limit reached if len(subgraph) >= self.max_subgraph_size: - return subgraph + return subgraph, term_map # Update for next iteration visited.update(current_level) current_level = next_level - return subgraph + return subgraph, term_map async def follow_edges(self, ent, subgraph, path_length): """Legacy method - replaced by follow_edges_batch""" @@ -311,7 +341,7 @@ class Query: return # For backward compatibility, convert to new approach - batch_result = await self.follow_edges_batch([ent], path_length) + batch_result, _ = await self.follow_edges_batch([ent], path_length) subgraph.update(batch_result) async def get_subgraph(self, query): @@ -319,9 +349,10 @@ class Query: Get subgraph by extracting concepts, finding entities, and traversing. Returns: - tuple: (subgraph, entities, concepts) where subgraph is a list of - (s, p, o) tuples, entities is the seed entity list, and concepts - is the extracted concept list. + tuple: (subgraph, term_map, entities, concepts) where subgraph is + a list of (s, p, o) string tuples, term_map maps each string + tuple to its original (Term, Term, Term), entities is the seed + entity list, and concepts is the extracted concept list. """ entities, concepts = await self.get_entities(query) @@ -330,9 +361,9 @@ class Query: logger.debug("Getting subgraph...") # Use optimized batch traversal instead of sequential processing - subgraph = await self.follow_edges_batch(entities, self.max_path_length) + subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length) - return list(subgraph), entities, concepts + return list(subgraph), term_map, entities, concepts async def resolve_labels_batch(self, entities): """Resolve labels for multiple entities in parallel""" @@ -353,7 +384,7 @@ class Query: - entities: list of seed entity URI strings - concepts: list of concept strings extracted from query """ - subgraph, entities, concepts = await self.get_subgraph(query) + subgraph, term_map, entities, concepts = await self.get_subgraph(query) # Filter out label triples filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] @@ -377,7 +408,7 @@ class Query: # Apply labels to subgraph and build URI mapping labeled_edges = [] - uri_map = {} # Maps edge_id of labeled edge -> original URI triple + uri_map = {} # Maps edge_id of labeled edge -> original Term triple for s, p, o in filtered_subgraph: labeled_triple = ( @@ -387,9 +418,9 @@ class Query: ) labeled_edges.append(labeled_triple) - # Map from labeled edge ID to original URIs + # Map from labeled edge ID to original Terms (preserving types) labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) - uri_map[labeled_eid] = (s, p, o) + uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o)) labeled_edges = labeled_edges[0:self.max_subgraph_size] @@ -419,12 +450,14 @@ class Query: # Step 1: Find subgraphs containing these edges via tg:contains subgraph_tasks = [] for s, p, o in edge_uris: + # s, p, o may be Term objects (preserving types) or strings + s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s) + p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p) + o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o) quoted = Term( type=TRIPLE, triple=SchemaTriple( - s=Term(type=IRI, iri=s), - p=Term(type=IRI, iri=p), - o=Term(type=IRI, iri=o), + s=s_term, p=p_term, o=o_term, ) ) subgraph_tasks.append(