Merge pull request #712 from trustgraph-ai/release/v2.2

release/v2.2 -> master
This commit is contained in:
cybermaggedon 2026-03-25 17:49:19 +00:00 committed by GitHub
commit 3ccff800c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 3110 additions and 400 deletions

View file

@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
run: make update-package-versions VERSION=2.1.999
run: make update-package-versions VERSION=2.2.999
- name: Setup environment
run: python3 -m venv env
@ -39,6 +39,9 @@ jobs:
- name: Install trustgraph-flow
run: (cd trustgraph-flow; pip install .)
- name: Install trustgraph-unstructured
run: (cd trustgraph-unstructured; pip install .)
- name: Install trustgraph-vertexai
run: (cd trustgraph-vertexai; pip install .)

View file

@ -58,6 +58,7 @@ jobs:
- trustgraph-vertexai
- trustgraph-hf
- trustgraph-ocr
- trustgraph-unstructured
- trustgraph-mcp
steps:

1
.gitignore vendored
View file

@ -13,5 +13,6 @@ trustgraph-flow/trustgraph/flow_version.py
trustgraph-ocr/trustgraph/ocr_version.py
trustgraph-parquet/trustgraph/parquet_version.py
trustgraph-vertexai/trustgraph/vertexai_version.py
trustgraph-unstructured/trustgraph/unstructured_version.py
trustgraph-mcp/trustgraph/mcp_version.py
vertexai/

View file

@ -17,6 +17,7 @@ wheels:
pip3 wheel --no-deps --wheel-dir dist trustgraph-embeddings-hf/
pip3 wheel --no-deps --wheel-dir dist trustgraph-cli/
pip3 wheel --no-deps --wheel-dir dist trustgraph-ocr/
pip3 wheel --no-deps --wheel-dir dist trustgraph-unstructured/
pip3 wheel --no-deps --wheel-dir dist trustgraph-mcp/
packages: update-package-versions
@ -29,6 +30,7 @@ packages: update-package-versions
cd trustgraph-embeddings-hf && python -m build --sdist --outdir ../dist/
cd trustgraph-cli && python -m build --sdist --outdir ../dist/
cd trustgraph-ocr && python -m build --sdist --outdir ../dist/
cd trustgraph-unstructured && python -m build --sdist --outdir ../dist/
cd trustgraph-mcp && python -m build --sdist --outdir ../dist/
pypi-upload:
@ -46,6 +48,7 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
echo __version__ = \"${VERSION}\" > trustgraph-cli/trustgraph/cli_version.py
echo __version__ = \"${VERSION}\" > trustgraph-ocr/trustgraph/ocr_version.py
echo __version__ = \"${VERSION}\" > trustgraph-unstructured/trustgraph/unstructured_version.py
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
@ -64,6 +67,8 @@ containers: FORCE
-t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
${DOCKER} build -f containers/Containerfile.unstructured \
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
${DOCKER} build -f containers/Containerfile.mcp \
-t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
@ -72,6 +77,8 @@ some-containers:
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.unstructured \
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.mcp \
@ -98,6 +105,7 @@ push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION}
${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION}
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
# Individual container build targets
@ -119,6 +127,9 @@ container-trustgraph-hf: update-package-versions
container-trustgraph-ocr: update-package-versions
${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
container-trustgraph-unstructured: update-package-versions
${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
container-trustgraph-mcp: update-package-versions
${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
@ -141,6 +152,9 @@ push-trustgraph-hf:
push-trustgraph-ocr:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
push-trustgraph-unstructured:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
push-trustgraph-mcp:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}

View file

@ -0,0 +1,48 @@
# ----------------------------------------------------------------------------
# Base container with system dependencies
# ----------------------------------------------------------------------------
FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
dnf clean all
# ----------------------------------------------------------------------------
# Build a container which contains the built Python packages. The build
# creates a bunch of left-over cruft, a separate phase means this is only
# needed to support package build
# ----------------------------------------------------------------------------
FROM base AS build
COPY trustgraph-base/ /root/build/trustgraph-base/
COPY trustgraph-unstructured/ /root/build/trustgraph-unstructured/
WORKDIR /root/build/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-unstructured/
RUN ls /root/wheels
# ----------------------------------------------------------------------------
# Finally, the target container. Start with base and add the package.
# ----------------------------------------------------------------------------
FROM base
COPY --from=build /root/wheels /root/wheels
RUN \
pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \
pip3 install --no-cache-dir /root/wheels/trustgraph_unstructured-* && \
rm -rf /root/wheels
WORKDIR /

View file

@ -0,0 +1,396 @@
# Universal Document Decoder
## Headline
Universal document decoder powered by `unstructured` — ingest any common
document format through a single service with full provenance and librarian
integration, recording source positions as knowledge graph metadata for
end-to-end traceability.
## Problem
TrustGraph currently has a PDF-specific decoder. Supporting additional
formats (DOCX, XLSX, HTML, Markdown, plain text, PPTX, etc.) requires
either writing a new decoder per format or adopting a universal extraction
library. Each format has different structure — some are page-based, some
aren't — and the provenance chain must record where in the original
document each piece of extracted text originated.
## Approach
### Library: `unstructured`
Use `unstructured.partition.auto.partition()` which auto-detects format
from mime type or file extension and extracts structured elements
(Title, NarrativeText, Table, ListItem, etc.). Each element carries
metadata including:
- `page_number` (for page-based formats like PDF, PPTX)
- `element_id` (unique per element)
- `coordinates` (bounding box for PDFs)
- `text` (the extracted text content)
- `category` (element type: Title, NarrativeText, Table, etc.)
### Element Types
`unstructured` extracts typed elements from documents. Each element has
a category and associated metadata:
**Text elements:**
- `Title` — section headings
- `NarrativeText` — body paragraphs
- `ListItem` — bullet/numbered list items
- `Header`, `Footer` — page headers/footers
- `FigureCaption` — captions for figures/images
- `Formula` — mathematical expressions
- `Address`, `EmailAddress` — contact information
- `CodeSnippet` — code blocks (from markdown)
**Tables:**
- `Table` — structured tabular data. `unstructured` provides both
`element.text` (plain text) and `element.metadata.text_as_html`
(full HTML `<table>` with rows, columns, and headers preserved).
For formats with explicit table structure (DOCX, XLSX, HTML), the
extraction is highly reliable. For PDFs, table detection depends on
the `hi_res` strategy with layout analysis.
**Images:**
- `Image` — embedded images detected via layout analysis (requires
`hi_res` strategy). With `extract_image_block_to_payload=True`,
returns the image data as base64 in `element.metadata.image_base64`.
OCR text from the image is available in `element.text`.
### Table Handling
Tables are a first-class output. When the decoder encounters a `Table`
element, it preserves the HTML structure rather than flattening to
plain text. This gives the downstream LLM extractor much better input
for pulling structured knowledge out of tabular data.
The page/section text is assembled as follows:
- Text elements: plain text, joined with newlines
- Table elements: HTML table markup from `text_as_html`, wrapped in a
`<table>` marker so the LLM can distinguish tables from narrative
For example, a page with a heading, paragraph, and table produces:
```
Financial Overview
Revenue grew 15% year-over-year driven by enterprise adoption.
<table>
<tr><th>Quarter</th><th>Revenue</th><th>Growth</th></tr>
<tr><td>Q1</td><td>$12M</td><td>12%</td></tr>
<tr><td>Q2</td><td>$14M</td><td>17%</td></tr>
</table>
```
This preserves table structure through chunking and into the extraction
pipeline, where the LLM can extract relationships directly from
structured cells rather than guessing at column alignment from
whitespace.
### Image Handling
Images are extracted and stored in the librarian as child documents
with `document_type="image"` and a `urn:image:{uuid}` ID. They get
provenance triples with type `tg:Image`, linked to their parent
page/section via `prov:wasDerivedFrom`. Image metadata (coordinates,
dimensions, element_id) is recorded in provenance.
**Crucially, images are NOT emitted as TextDocument outputs.** They are
stored only — not sent downstream to the chunker or any text processing
pipeline. This is intentional:
1. There is no image processing pipeline yet (vision model integration
is future work)
2. Feeding base64 image data or OCR fragments into the text extraction
pipeline would produce garbage KG triples
Images are also excluded from the assembled page text — any `Image`
elements are silently skipped when concatenating element texts for a
page/section. The provenance chain records that images exist and where
they appeared in the document, so they can be picked up by a future
image processing pipeline without re-ingesting the document.
#### Future work
- Route `tg:Image` entities to a vision model for description,
diagram interpretation, or chart data extraction
- Store image descriptions as text child documents that feed into
the standard chunking/extraction pipeline
- Link extracted knowledge back to source images via provenance
### Section Strategies
For page-based formats (PDF, PPTX, XLSX), elements are always grouped
by page/slide/sheet first. For non-page formats (DOCX, HTML, Markdown,
etc.), the decoder needs a strategy for splitting the document into
sections. This is runtime-configurable via `--section-strategy`.
Each strategy is a grouping function over the list of `unstructured`
elements. The output is a list of element groups; the rest of the
pipeline (text assembly, librarian storage, provenance, TextDocument
emission) is identical regardless of strategy.
#### `whole-document` (default)
Emit the entire document as a single section. Let the downstream
chunker handle all splitting.
- Simplest approach, good baseline
- May produce very large TextDocument for big files, but the chunker
handles that
- Best when you want maximum context per section
#### `heading`
Split at heading elements (`Title`). Each section is a heading plus
all content until the next heading of equal or higher level. Nested
headings create nested sections.
- Produces topically coherent units
- Works well for structured documents (reports, manuals, specs)
- Gives the extraction LLM heading context alongside content
- Falls back to `whole-document` if no headings are found
#### `element-type`
Split when the element type changes significantly — specifically,
start a new section at transitions between narrative text and tables.
Consecutive elements of the same broad category (text, text, text or
table, table) stay grouped.
- Keeps tables as standalone sections
- Good for documents with mixed content (reports with data tables)
- Tables get dedicated extraction attention
#### `count`
Group a fixed number of elements per section. Configurable via
`--section-element-count` (default: 20).
- Simple and predictable
- Doesn't respect document structure
- Useful as a fallback or for experimentation
#### `size`
Accumulate elements until a character limit is reached, then start a
new section. Respects element boundaries — never splits mid-element.
Configurable via `--section-max-size` (default: 4000 characters).
- Produces roughly uniform section sizes
- Respects element boundaries (unlike the downstream chunker)
- Good compromise between structure and size control
- If a single element exceeds the limit, it becomes its own section
#### Page-based format interaction
For page-based formats, the page grouping always takes priority.
Section strategies can optionally apply *within* a page if it's very
large (e.g. a PDF page with an enormous table), controlled by
`--section-within-pages` (default: false). When false, each page is
always one section regardless of size.
### Format Detection
The decoder needs to know the document's mime type to pass to
`unstructured`'s `partition()`. Two paths:
- **Librarian path** (`document_id` set): fetch document metadata
from the librarian first — this gives us the `kind` (mime type)
that was recorded at upload time. Then fetch document content.
Two librarian calls, but the metadata fetch is lightweight.
- **Inline path** (backward compat, `data` set): no metadata
available on the message. Use `python-magic` to detect format
from content bytes as a fallback.
No changes to the `Document` schema are needed — the librarian
already stores the mime type.
### Architecture
A single `universal-decoder` service that:
1. Receives a `Document` message (inline or via librarian reference)
2. If librarian path: fetch document metadata (get mime type), then
fetch content. If inline path: detect format from content bytes.
3. Calls `partition()` to extract elements
4. Groups elements: by page for page-based formats, by configured
section strategy for non-page formats
5. For each page/section:
- Generates a `urn:page:{uuid}` or `urn:section:{uuid}` ID
- Assembles page text: narrative as plain text, tables as HTML,
images skipped
- Computes character offsets for each element within the page text
- Saves to librarian as child document
- Emits provenance triples with positional metadata
- Sends `TextDocument` downstream for chunking
6. For each image element:
- Generates a `urn:image:{uuid}` ID
- Saves image data to librarian as child document
- Emits provenance triples (stored only, not sent downstream)
### Format Handling
| Format | Mime Type | Page-based | Notes |
|----------|------------------------------------|------------|--------------------------------|
| PDF | application/pdf | Yes | Per-page grouping |
| DOCX | application/vnd.openxmlformats... | No | Uses section strategy |
| PPTX | application/vnd.openxmlformats... | Yes | Per-slide grouping |
| XLSX/XLS | application/vnd.openxmlformats... | Yes | Per-sheet grouping |
| HTML | text/html | No | Uses section strategy |
| Markdown | text/markdown | No | Uses section strategy |
| Plain | text/plain | No | Uses section strategy |
| CSV | text/csv | No | Uses section strategy |
| RST | text/x-rst | No | Uses section strategy |
| RTF | application/rtf | No | Uses section strategy |
| ODT | application/vnd.oasis... | No | Uses section strategy |
| TSV | text/tab-separated-values | No | Uses section strategy |
### Provenance Metadata
Each page/section entity records positional metadata as provenance
triples in `GRAPH_SOURCE`, enabling full traceability from KG triples
back to source document positions.
#### Existing fields (already in `derived_entity_triples`)
- `page_number` — page/sheet/slide number (1-indexed, page-based only)
- `char_offset` — character offset of this page/section within the
full document text
- `char_length` — character length of this page/section's text
#### New fields (extend `derived_entity_triples`)
- `mime_type` — original document format (e.g. `application/pdf`)
- `element_types` — comma-separated list of `unstructured` element
categories found in this page/section (e.g. "Title,NarrativeText,Table")
- `table_count` — number of tables in this page/section
- `image_count` — number of images in this page/section
These require new TG namespace predicates:
```
TG_SECTION_TYPE = "https://trustgraph.ai/ns/Section"
TG_IMAGE_TYPE = "https://trustgraph.ai/ns/Image"
TG_ELEMENT_TYPES = "https://trustgraph.ai/ns/elementTypes"
TG_TABLE_COUNT = "https://trustgraph.ai/ns/tableCount"
TG_IMAGE_COUNT = "https://trustgraph.ai/ns/imageCount"
```
Image URN scheme: `urn:image:{uuid}`
(`TG_MIME_TYPE` already exists.)
#### New entity type
For non-page formats (DOCX, HTML, Markdown, etc.) where the decoder
emits the whole document as a single unit rather than splitting by
page, the entity gets a new type:
```
TG_SECTION_TYPE = "https://trustgraph.ai/ns/Section"
```
This distinguishes sections from pages when querying provenance:
| Entity | Type | When used |
|----------|-----------------------------|----------------------------------------|
| Document | `tg:Document` | Original uploaded file |
| Page | `tg:Page` | Page-based formats (PDF, PPTX, XLSX) |
| Section | `tg:Section` | Non-page formats (DOCX, HTML, MD, etc) |
| Image | `tg:Image` | Embedded images (stored, not processed)|
| Chunk | `tg:Chunk` | Output of chunker |
| Subgraph | `tg:Subgraph` | KG extraction output |
The type is set by the decoder based on whether it's grouping by page
or emitting a whole-document section. `derived_entity_triples` gains
an optional `section` boolean parameter — when true, the entity is
typed as `tg:Section` instead of `tg:Page`.
#### Full provenance chain
```
KG triple
→ subgraph (extraction provenance)
→ chunk (char_offset, char_length within page)
→ page/section (page_number, char_offset, char_length within doc, mime_type, element_types)
→ document (original file in librarian)
```
Every link is a set of triples in the `GRAPH_SOURCE` named graph.
### Service Configuration
Command-line arguments:
```
--strategy Partitioning strategy: auto, hi_res, fast (default: auto)
--languages Comma-separated OCR language codes (default: eng)
--section-strategy Section grouping: whole-document, heading, element-type,
count, size (default: whole-document)
--section-element-count Elements per section for 'count' strategy (default: 20)
--section-max-size Max chars per section for 'size' strategy (default: 4000)
--section-within-pages Apply section strategy within pages too (default: false)
```
Plus the standard `FlowProcessor` and librarian queue arguments.
### Flow Integration
The universal decoder occupies the same position in the processing flow
as the current PDF decoder:
```
Document → [universal-decoder] → TextDocument → [chunker] → Chunk → ...
```
It registers:
- `input` consumer (Document schema)
- `output` producer (TextDocument schema)
- `triples` producer (Triples schema)
- Librarian request/response (for fetch and child document storage)
### Deployment
- New container: `trustgraph-flow-universal-decoder`
- Dependency: `unstructured[all-docs]` (includes PDF, DOCX, PPTX, etc.)
- Can run alongside or replace the existing PDF decoder depending on
flow configuration
- The existing PDF decoder remains available for environments where
`unstructured` dependencies are too heavy
### What Changes
| Component | Change |
|------------------------------|-------------------------------------------------|
| `provenance/namespaces.py` | Add `TG_SECTION_TYPE`, `TG_IMAGE_TYPE`, `TG_ELEMENT_TYPES`, `TG_TABLE_COUNT`, `TG_IMAGE_COUNT` |
| `provenance/triples.py` | Add `mime_type`, `element_types`, `table_count`, `image_count` kwargs |
| `provenance/__init__.py` | Export new constants |
| New: `decoding/universal/` | New decoder service module |
| `setup.cfg` / `pyproject` | Add `unstructured[all-docs]` dependency |
| Docker | New container image |
| Flow definitions | Wire universal-decoder as document input |
### What Doesn't Change
- Chunker (receives TextDocument, works as before)
- Downstream extractors (receive Chunk, unchanged)
- Librarian (stores child documents, unchanged)
- Schema (Document, TextDocument, Chunk unchanged)
- Query-time provenance (unchanged)
## Risks
- `unstructured[all-docs]` has heavy dependencies (poppler, tesseract,
libreoffice for some formats). Container image will be larger.
Mitigation: offer a `[light]` variant without OCR/office deps.
- Some formats may produce poor text extraction (scanned PDFs without
OCR, complex XLSX layouts). Mitigation: configurable `strategy`
parameter, and the existing Mistral OCR decoder remains available
for high-quality PDF OCR.
- `unstructured` version updates may change element metadata.
Mitigation: pin version, test extraction quality per format.

View file

@ -10,159 +10,137 @@ from unittest import IsolatedAsyncioTestCase
from io import BytesIO
from trustgraph.decoding.mistral_ocr.processor import Processor
from trustgraph.schema import Document, TextDocument, Metadata
from trustgraph.schema import Document, TextDocument, Metadata, Triples
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
self.pubsub = MagicMock()
self.taskgroup = params.get('taskgroup', MagicMock())
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
"""Test Mistral OCR processor functionality"""
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_processor_initialization_with_api_key(self, mock_flow_init, mock_mistral_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_with_api_key(
self, mock_mistral_class, mock_producer, mock_consumer
):
"""Test Mistral OCR processor initialization with API key"""
# Arrange
mock_flow_init.return_value = None
mock_mistral = MagicMock()
mock_mistral_class.return_value = mock_mistral
mock_mistral_class.return_value = MagicMock()
config = {
'id': 'test-mistral-ocr',
'api_key': 'test-api-key',
'taskgroup': AsyncMock()
}
# Act
with patch.object(Processor, 'register_specification') as mock_register:
processor = Processor(**config)
processor = Processor(**config)
# Assert
mock_flow_init.assert_called_once()
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
# Verify register_specification was called twice (consumer and producer)
assert mock_register.call_count == 2
# Check consumer spec
consumer_call = mock_register.call_args_list[0]
consumer_spec = consumer_call[0][0]
assert consumer_spec.name == "input"
assert consumer_spec.schema == Document
assert consumer_spec.handler == processor.on_message
# Check producer spec
producer_call = mock_register.call_args_list[1]
producer_spec = producer_call[0][0]
assert producer_spec.name == "output"
assert producer_spec.schema == TextDocument
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_processor_initialization_without_api_key(self, mock_flow_init):
# Check specs registered: input consumer, output producer, triples producer
consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')]
assert len(consumer_specs) >= 1
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_without_api_key(
self, mock_producer, mock_consumer
):
"""Test Mistral OCR processor initialization without API key raises error"""
# Arrange
mock_flow_init.return_value = None
config = {
'id': 'test-mistral-ocr',
'taskgroup': AsyncMock()
}
# Act & Assert
with patch.object(Processor, 'register_specification'):
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
processor = Processor(**config)
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
Processor(**config)
@patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4')
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_ocr_single_chunk(self, mock_flow_init, mock_mistral_class, mock_uuid):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_ocr_single_chunk(
self, mock_mistral_class, mock_producer, mock_consumer
):
"""Test OCR processing with a single chunk (less than 5 pages)"""
# Arrange
mock_flow_init.return_value = None
mock_uuid.return_value = "test-uuid-1234"
# Mock Mistral client
mock_mistral = MagicMock()
mock_mistral_class.return_value = mock_mistral
# Mock file upload
mock_uploaded_file = MagicMock(id="file-123")
mock_mistral.files.upload.return_value = mock_uploaded_file
# Mock signed URL
mock_signed_url = MagicMock(url="https://example.com/signed-url")
mock_mistral.files.get_signed_url.return_value = mock_signed_url
# Mock OCR response
mock_page = MagicMock(
# Mock OCR response with 2 pages
mock_page1 = MagicMock(
markdown="# Page 1\nContent ![img1](img1)",
images=[MagicMock(id="img1", image_base64="data:image/png;base64,abc123")]
)
mock_ocr_response = MagicMock(pages=[mock_page])
mock_page2 = MagicMock(
markdown="# Page 2\nMore content",
images=[]
)
mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2])
mock_mistral.ocr.process.return_value = mock_ocr_response
# Mock PyPDF
mock_pdf_reader = MagicMock()
mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()] # 3 pages
mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()]
config = {
'id': 'test-mistral-ocr',
'api_key': 'test-api-key',
'taskgroup': AsyncMock()
}
with patch.object(Processor, 'register_specification'):
with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader):
with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class:
mock_pdf_writer = MagicMock()
mock_pdf_writer_class.return_value = mock_pdf_writer
processor = Processor(**config)
# Act
result = processor.ocr(b"fake pdf content")
with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader):
with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class:
mock_pdf_writer = MagicMock()
mock_pdf_writer_class.return_value = mock_pdf_writer
# Assert
assert result == "# Page 1\nContent ![img1](data:image/png;base64,abc123)"
# Verify PDF writer was used to create chunk
processor = Processor(**config)
result = processor.ocr(b"fake pdf content")
# Returns list of (markdown, page_num) tuples
assert len(result) == 2
assert result[0] == ("# Page 1\nContent ![img1](data:image/png;base64,abc123)", 1)
assert result[1] == ("# Page 2\nMore content", 2)
# Verify PDF writer was used
assert mock_pdf_writer.add_page.call_count == 3
mock_pdf_writer.write_stream.assert_called_once()
# Verify Mistral API calls
mock_mistral.files.upload.assert_called_once()
upload_call = mock_mistral.files.upload.call_args[1]
assert upload_call['file']['file_name'] == "test-uuid-1234"
assert upload_call['purpose'] == 'ocr'
mock_mistral.files.get_signed_url.assert_called_once_with(
file_id="file-123", expiry=1
)
mock_mistral.ocr.process.assert_called_once_with(
model="mistral-ocr-latest",
include_image_base64=True,
document={
"type": "document_url",
"document_url": "https://example.com/signed-url",
}
)
mock_mistral.ocr.process.assert_called_once()
@patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4')
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_on_message_success(self, mock_flow_init, mock_mistral_class, mock_uuid):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(
self, mock_mistral_class, mock_producer, mock_consumer
):
"""Test successful message processing"""
# Arrange
mock_flow_init.return_value = None
mock_uuid.return_value = "test-uuid-5678"
# Mock Mistral client with simple OCR response
mock_mistral = MagicMock()
mock_mistral_class.return_value = mock_mistral
# Mock the ocr method to return simple markdown
ocr_result = "# Document Title\nThis is the OCR content"
mock_mistral_class.return_value = MagicMock()
# Mock message
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -170,126 +148,100 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
# Mock flow - needs to be a callable that returns an object with send method
# Mock flow
mock_output_flow = AsyncMock()
mock_flow = MagicMock(return_value=mock_output_flow)
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
config = {
'id': 'test-mistral-ocr',
'api_key': 'test-api-key',
'taskgroup': AsyncMock()
}
with patch.object(Processor, 'register_specification'):
processor = Processor(**config)
# Mock the ocr method
with patch.object(processor, 'ocr', return_value=ocr_result):
# Act
await processor.on_message(mock_msg, None, mock_flow)
processor = Processor(**config)
# Assert
# Verify output was sent
mock_output_flow.send.assert_called_once()
# Check output
call_args = mock_output_flow.send.call_args[0][0]
# Mock ocr to return per-page results
ocr_result = [
("# Page 1\nContent", 1),
("# Page 2\nMore content", 2),
]
# Mock save_child_document
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
with patch.object(processor, 'ocr', return_value=ocr_result):
await processor.on_message(mock_msg, None, mock_flow)
# Verify output was sent for each page
assert mock_output_flow.send.call_count == 2
# Verify triples were sent for each page
assert mock_triples_flow.send.call_count == 2
# Check output uses UUID-based page URNs
call_args = mock_output_flow.send.call_args_list[0][0][0]
assert isinstance(call_args, TextDocument)
assert call_args.metadata == mock_metadata
assert call_args.text == ocr_result.encode('utf-8')
assert call_args.document_id.startswith("urn:page:")
assert call_args.text == b"" # Content stored in librarian
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_chunks_function(self, mock_flow_init, mock_mistral_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunks_function(self, mock_mistral_class):
"""Test the chunks utility function"""
# Arrange
from trustgraph.decoding.mistral_ocr.processor import chunks
test_list = list(range(12))
# Act
result = list(chunks(test_list, 5))
# Assert
assert len(result) == 3
assert result[0] == [0, 1, 2, 3, 4]
assert result[1] == [5, 6, 7, 8, 9]
assert result[2] == [10, 11]
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_replace_images_in_markdown(self, mock_flow_init, mock_mistral_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_replace_images_in_markdown(self, mock_mistral_class):
"""Test the replace_images_in_markdown function"""
# Arrange
from trustgraph.decoding.mistral_ocr.processor import replace_images_in_markdown
markdown = "# Title\n![image1](image1)\nSome text\n![image2](image2)"
images_dict = {
"image1": "data:image/png;base64,abc123",
"image2": "data:image/png;base64,def456"
}
# Act
result = replace_images_in_markdown(markdown, images_dict)
# Assert
expected = "# Title\n![image1](data:image/png;base64,abc123)\nSome text\n![image2](data:image/png;base64,def456)"
assert result == expected
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_get_combined_markdown(self, mock_flow_init, mock_mistral_class):
"""Test the get_combined_markdown function"""
# Arrange
from trustgraph.decoding.mistral_ocr.processor import get_combined_markdown
from mistralai.models import OCRResponse
# Mock OCR response with multiple pages
mock_page1 = MagicMock(
markdown="# Page 1\n![img1](img1)",
images=[MagicMock(id="img1", image_base64="base64_img1")]
)
mock_page2 = MagicMock(
markdown="# Page 2\n![img2](img2)",
images=[MagicMock(id="img2", image_base64="base64_img2")]
)
mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2])
# Act
result = get_combined_markdown(mock_ocr_response)
# Assert
expected = "# Page 1\n![img1](base64_img1)\n\n# Page 2\n![img2](base64_img2)"
result = replace_images_in_markdown(markdown, images_dict)
expected = "# Title\n![image1](data:image/png;base64,abc123)\nSome text\n![image2](data:image/png;base64,def456)"
assert result == expected
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args):
"""Test add_args adds API key argument"""
# Arrange
"""Test add_args adds expected arguments"""
mock_parser = MagicMock()
# Act
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
mock_parser.add_argument.assert_called_once_with(
'-k', '--api-key',
default=None, # default_api_key is None in test environment
help='Mistral API Key'
)
assert mock_parser.add_argument.call_count == 3
# Check the API key arg is among them
call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list]
assert ('-k', '--api-key') in call_args_list
@patch('trustgraph.decoding.mistral_ocr.processor.Processor.launch')
def test_run(self, mock_launch):
"""Test run function"""
# Act
from trustgraph.decoding.mistral_ocr.processor import run
run()
# Assert
mock_launch.assert_called_once_with("pdf-decoder",
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n")
mock_launch.assert_called_once()
args = mock_launch.call_args[0]
assert args[0] == "document-decoder"
assert "Mistral OCR decoder" in args[1]
if __name__ == '__main__':

View file

@ -171,8 +171,8 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
mock_output_flow.send.assert_called_once()
call_args = mock_output_flow.send.call_args[0][0]
# PDF decoder now forwards document_id, chunker fetches content from librarian
assert call_args.document_id == "test-doc/p1"
# PDF decoder now forwards document_id with UUID-based URN
assert call_args.document_id.startswith("urn:page:")
assert call_args.text == b"" # Content stored in librarian, not inline
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
@ -187,7 +187,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"""Test run function"""
from trustgraph.decoding.pdf.pdf_decoder import run
run()
mock_launch.assert_called_once_with("pdf-decoder",
mock_launch.assert_called_once_with("document-decoder",
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n\nSupports both inline document data and fetching from librarian via Pulsar\nfor large documents.\n")

View file

@ -0,0 +1,412 @@
"""
Unit tests for trustgraph.decoding.universal.processor
"""
import pytest
import base64
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.decoding.universal.processor import (
Processor, assemble_section_text, MIME_EXTENSIONS, PAGE_BASED_FORMATS,
)
from trustgraph.schema import Document, TextDocument, Metadata, Triples
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
self.pubsub = MagicMock()
self.taskgroup = params.get('taskgroup', MagicMock())
def make_element(category="NarrativeText", text="Some text",
page_number=None, text_as_html=None, image_base64=None):
"""Create a mock unstructured element."""
el = MagicMock()
el.category = category
el.text = text
el.metadata = MagicMock()
el.metadata.page_number = page_number
el.metadata.text_as_html = text_as_html
el.metadata.image_base64 = image_base64
return el
class TestAssembleSectionText:
"""Test the text assembly function."""
def test_narrative_text(self):
elements = [
make_element("NarrativeText", "Paragraph one."),
make_element("NarrativeText", "Paragraph two."),
]
text, types, tables, images = assemble_section_text(elements)
assert text == "Paragraph one.\n\nParagraph two."
assert "NarrativeText" in types
assert tables == 0
assert images == 0
def test_table_with_html(self):
elements = [
make_element("NarrativeText", "Before table."),
make_element(
"Table", "Col1 Col2",
text_as_html="<table><tr><td>Col1</td><td>Col2</td></tr></table>"
),
]
text, types, tables, images = assemble_section_text(elements)
assert "<table>" in text
assert "Before table." in text
assert tables == 1
assert "Table" in types
def test_table_without_html_fallback(self):
el = make_element("Table", "plain table text")
el.metadata.text_as_html = None
elements = [el]
text, types, tables, images = assemble_section_text(elements)
assert text == "plain table text"
assert tables == 1
def test_images_skipped(self):
elements = [
make_element("NarrativeText", "Text content"),
make_element("Image", "OCR text from image"),
]
text, types, tables, images = assemble_section_text(elements)
assert "OCR text" not in text
assert "Text content" in text
assert images == 1
assert "Image" in types
def test_empty_elements(self):
text, types, tables, images = assemble_section_text([])
assert text == ""
assert len(types) == 0
assert tables == 0
assert images == 0
def test_mixed_elements(self):
elements = [
make_element("Title", "Section Heading"),
make_element("NarrativeText", "Body text."),
make_element(
"Table", "data",
text_as_html="<table><tr><td>data</td></tr></table>"
),
make_element("Image", "img text"),
make_element("ListItem", "- item one"),
]
text, types, tables, images = assemble_section_text(elements)
assert "Section Heading" in text
assert "Body text." in text
assert "<table>" in text
assert "img text" not in text
assert "- item one" in text
assert tables == 1
assert images == 1
assert {"Title", "NarrativeText", "Table", "Image", "ListItem"} == types
class TestMimeExtensions:
"""Test the mime type to extension mapping."""
def test_pdf_extension(self):
assert MIME_EXTENSIONS["application/pdf"] == ".pdf"
def test_docx_extension(self):
key = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
assert MIME_EXTENSIONS[key] == ".docx"
def test_html_extension(self):
assert MIME_EXTENSIONS["text/html"] == ".html"
class TestPageBasedFormats:
"""Test page-based format detection."""
def test_pdf_is_page_based(self):
assert "application/pdf" in PAGE_BASED_FORMATS
def test_html_is_not_page_based(self):
assert "text/html" not in PAGE_BASED_FORMATS
def test_pptx_is_page_based(self):
pptx = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
assert pptx in PAGE_BASED_FORMATS
class TestUniversalProcessor(IsolatedAsyncioTestCase):
"""Test universal decoder processor."""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(
self, mock_producer, mock_consumer
):
"""Test processor initialization with defaults."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
}
processor = Processor(**config)
assert processor.partition_strategy == "auto"
assert processor.section_strategy_name == "whole-document"
assert processor.section_element_count == 20
assert processor.section_max_size == 4000
# Check specs: input consumer, output producer, triples producer
consumer_specs = [
s for s in processor.specifications if hasattr(s, 'handler')
]
assert len(consumer_specs) >= 1
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_custom_strategy(
self, mock_producer, mock_consumer
):
"""Test processor initialization with custom section strategy."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
'section_strategy': 'heading',
'strategy': 'hi_res',
}
processor = Processor(**config)
assert processor.partition_strategy == "hi_res"
assert processor.section_strategy_name == "heading"
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_group_by_page(self, mock_producer, mock_consumer):
"""Test page grouping of elements."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
}
processor = Processor(**config)
elements = [
make_element("NarrativeText", "Page 1 text", page_number=1),
make_element("NarrativeText", "More page 1", page_number=1),
make_element("NarrativeText", "Page 2 text", page_number=2),
]
result = processor.group_by_page(elements)
assert len(result) == 2
assert result[0][0] == 1 # page number
assert len(result[0][1]) == 2 # 2 elements on page 1
assert result[1][0] == 2
assert len(result[1][1]) == 1
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_inline_non_page(
self, mock_partition, mock_producer, mock_consumer
):
"""Test processing an inline non-page document."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
}
processor = Processor(**config)
# Mock partition to return elements without page numbers
mock_partition.return_value = [
make_element("Title", "Document Title"),
make_element("NarrativeText", "Body text content."),
]
# Mock message with inline data
content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser",
collection="default")
mock_document = Document(
metadata=mock_metadata,
data=base64.b64encode(content).decode('utf-8'),
)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
# Mock flow
mock_output_flow = AsyncMock()
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
# Mock save_child_document and magic
processor.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "text/markdown"
await processor.on_message(mock_msg, None, mock_flow)
# Should emit one section (whole-document strategy)
assert mock_output_flow.send.call_count == 1
assert mock_triples_flow.send.call_count == 1
# Check output
call_args = mock_output_flow.send.call_args[0][0]
assert isinstance(call_args, TextDocument)
assert call_args.document_id.startswith("urn:section:")
assert call_args.text == b""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_page_based(
self, mock_partition, mock_producer, mock_consumer
):
"""Test processing a page-based document."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
}
processor = Processor(**config)
# Mock partition to return elements with page numbers
mock_partition.return_value = [
make_element("NarrativeText", "Page 1 content", page_number=1),
make_element("NarrativeText", "Page 2 content", page_number=2),
]
# Mock message
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
collection="default")
mock_document = Document(
metadata=mock_metadata,
data=base64.b64encode(content).decode('utf-8'),
)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
mock_output_flow = AsyncMock()
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
await processor.on_message(mock_msg, None, mock_flow)
# Should emit two pages
assert mock_output_flow.send.call_count == 2
# Check first output uses page URI
call_args = mock_output_flow.send.call_args_list[0][0][0]
assert call_args.document_id.startswith("urn:page:")
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_images_stored_not_emitted(
self, mock_partition, mock_producer, mock_consumer
):
"""Test that images are stored but not sent to text pipeline."""
config = {
'id': 'test-universal',
'taskgroup': AsyncMock(),
}
processor = Processor(**config)
mock_partition.return_value = [
make_element("NarrativeText", "Some text", page_number=1),
make_element("Image", "img ocr", page_number=1,
image_base64="aW1hZ2VkYXRh"),
]
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
collection="default")
mock_document = Document(
metadata=mock_metadata,
data=base64.b64encode(content).decode('utf-8'),
)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
mock_output_flow = AsyncMock()
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
await processor.on_message(mock_msg, None, mock_flow)
# Only 1 TextDocument output (the page text, not the image)
assert mock_output_flow.send.call_count == 1
# But 2 triples outputs (page provenance + image provenance)
assert mock_triples_flow.send.call_count == 2
# save_child_document called twice (page + image)
assert processor.save_child_document.call_count == 2
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args):
"""Test add_args registers all expected arguments."""
mock_parser = MagicMock()
Processor.add_args(mock_parser)
mock_parent_add_args.assert_called_once_with(mock_parser)
# Check key arguments are registered
arg_names = [
c[0] for c in mock_parser.add_argument.call_args_list
]
assert ('--strategy',) in arg_names
assert ('--languages',) in arg_names
assert ('--section-strategy',) in arg_names
assert ('--section-element-count',) in arg_names
assert ('--section-max-size',) in arg_names
assert ('--section-within-pages',) in arg_names
@patch('trustgraph.decoding.universal.processor.Processor.launch')
def test_run(self, mock_launch):
"""Test run function."""
from trustgraph.decoding.universal.processor import run
run()
mock_launch.assert_called_once()
args = mock_launch.call_args[0]
assert args[0] == "document-decoder"
assert "Universal document decoder" in args[1]
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,204 @@
"""
Unit tests for universal decoder section grouping strategies.
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.decoding.universal.strategies import (
group_whole_document,
group_by_heading,
group_by_element_type,
group_by_count,
group_by_size,
get_strategy,
STRATEGIES,
)
def make_element(category="NarrativeText", text="Some text"):
"""Create a mock unstructured element."""
el = MagicMock()
el.category = category
el.text = text
return el
class TestGroupWholeDocument:
def test_empty_input(self):
assert group_whole_document([]) == []
def test_returns_single_group(self):
elements = [make_element() for _ in range(5)]
result = group_whole_document(elements)
assert len(result) == 1
assert len(result[0]) == 5
def test_preserves_all_elements(self):
elements = [make_element(text=f"text-{i}") for i in range(3)]
result = group_whole_document(elements)
assert result[0] == elements
class TestGroupByHeading:
def test_empty_input(self):
assert group_by_heading([]) == []
def test_no_headings_falls_back(self):
elements = [make_element("NarrativeText") for _ in range(3)]
result = group_by_heading(elements)
assert len(result) == 1
assert len(result[0]) == 3
def test_splits_at_headings(self):
elements = [
make_element("Title", "Heading 1"),
make_element("NarrativeText", "Paragraph 1"),
make_element("NarrativeText", "Paragraph 2"),
make_element("Title", "Heading 2"),
make_element("NarrativeText", "Paragraph 3"),
]
result = group_by_heading(elements)
assert len(result) == 2
assert len(result[0]) == 3 # Heading 1 + 2 paragraphs
assert len(result[1]) == 2 # Heading 2 + 1 paragraph
def test_leading_content_before_first_heading(self):
elements = [
make_element("NarrativeText", "Preamble"),
make_element("Title", "Heading 1"),
make_element("NarrativeText", "Content"),
]
result = group_by_heading(elements)
assert len(result) == 2
assert len(result[0]) == 1 # Preamble
assert len(result[1]) == 2 # Heading + content
def test_consecutive_headings(self):
elements = [
make_element("Title", "H1"),
make_element("Title", "H2"),
make_element("NarrativeText", "Content"),
]
result = group_by_heading(elements)
assert len(result) == 2
class TestGroupByElementType:
def test_empty_input(self):
assert group_by_element_type([]) == []
def test_all_same_type(self):
elements = [make_element("NarrativeText") for _ in range(3)]
result = group_by_element_type(elements)
assert len(result) == 1
def test_splits_at_table_boundary(self):
elements = [
make_element("NarrativeText", "Intro"),
make_element("NarrativeText", "More text"),
make_element("Table", "Table data"),
make_element("NarrativeText", "After table"),
]
result = group_by_element_type(elements)
assert len(result) == 3
assert len(result[0]) == 2 # Two narrative elements
assert len(result[1]) == 1 # One table
assert len(result[2]) == 1 # One narrative
def test_consecutive_tables_stay_grouped(self):
elements = [
make_element("Table", "Table 1"),
make_element("Table", "Table 2"),
]
result = group_by_element_type(elements)
assert len(result) == 1
assert len(result[0]) == 2
class TestGroupByCount:
def test_empty_input(self):
assert group_by_count([]) == []
def test_exact_multiple(self):
elements = [make_element() for _ in range(6)]
result = group_by_count(elements, element_count=3)
assert len(result) == 2
assert all(len(g) == 3 for g in result)
def test_remainder_group(self):
elements = [make_element() for _ in range(7)]
result = group_by_count(elements, element_count=3)
assert len(result) == 3
assert len(result[0]) == 3
assert len(result[1]) == 3
assert len(result[2]) == 1
def test_fewer_than_count(self):
elements = [make_element() for _ in range(2)]
result = group_by_count(elements, element_count=10)
assert len(result) == 1
assert len(result[0]) == 2
class TestGroupBySize:
def test_empty_input(self):
assert group_by_size([]) == []
def test_small_elements_grouped(self):
elements = [make_element(text="Hi") for _ in range(5)]
result = group_by_size(elements, max_size=100)
assert len(result) == 1
def test_splits_at_size_limit(self):
elements = [make_element(text="x" * 100) for _ in range(5)]
result = group_by_size(elements, max_size=250)
# 2 elements per group (200 chars), then split
assert len(result) == 3
assert len(result[0]) == 2
assert len(result[1]) == 2
assert len(result[2]) == 1
def test_large_element_own_group(self):
elements = [
make_element(text="small"),
make_element(text="x" * 5000), # Exceeds max
make_element(text="small"),
]
result = group_by_size(elements, max_size=100)
assert len(result) == 3
def test_respects_element_boundaries(self):
# Each element is 50 chars, max is 120
# Should get 2 per group, not split mid-element
elements = [make_element(text="x" * 50) for _ in range(5)]
result = group_by_size(elements, max_size=120)
assert len(result) == 3
assert len(result[0]) == 2
assert len(result[1]) == 2
assert len(result[2]) == 1
class TestGetStrategy:
def test_all_strategies_accessible(self):
for name in STRATEGIES:
fn = get_strategy(name)
assert callable(fn)
def test_unknown_strategy_raises(self):
with pytest.raises(ValueError, match="Unknown section strategy"):
get_strategy("nonexistent")
def test_returns_correct_function(self):
assert get_strategy("whole-document") is group_whole_document
assert get_strategy("heading") is group_by_heading
assert get_strategy("element-type") is group_by_element_type
assert get_strategy("count") is group_by_count
assert get_strategy("size") is group_by_size

View file

@ -121,11 +121,11 @@ class TestBeginUpload:
assert resp.total_chunks == math.ceil(10_000 / 3000)
@pytest.mark.asyncio
async def test_rejects_invalid_kind(self):
async def test_rejects_empty_kind(self):
lib = _make_librarian()
req = _make_begin_request(kind="image/png")
req = _make_begin_request(kind="")
with pytest.raises(RequestError, match="Invalid document kind"):
with pytest.raises(RequestError, match="MIME type.*required"):
await lib.begin_upload(req)
@pytest.mark.asyncio

View file

@ -10,8 +10,7 @@ from trustgraph.provenance.uris import (
_encode_id,
document_uri,
page_uri,
chunk_uri_from_page,
chunk_uri_from_doc,
chunk_uri,
activity_uri,
subgraph_uri,
agent_uri,
@ -60,31 +59,22 @@ class TestDocumentUris:
assert document_uri(iri) == iri
def test_page_uri_format(self):
result = page_uri("https://example.com/doc/123", 5)
assert result == "https://example.com/doc/123/p5"
result = page_uri()
assert result.startswith("urn:page:")
def test_page_uri_page_zero(self):
result = page_uri("https://example.com/doc/123", 0)
assert result == "https://example.com/doc/123/p0"
def test_page_uri_unique(self):
r1 = page_uri()
r2 = page_uri()
assert r1 != r2
def test_chunk_uri_from_page_format(self):
result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
assert result == "https://example.com/doc/123/p2/c3"
def test_chunk_uri_format(self):
result = chunk_uri()
assert result.startswith("urn:chunk:")
def test_chunk_uri_from_doc_format(self):
result = chunk_uri_from_doc("https://example.com/doc/123", 7)
assert result == "https://example.com/doc/123/c7"
def test_page_uri_preserves_doc_iri(self):
doc = "urn:isbn:978-3-16-148410-0"
result = page_uri(doc, 1)
assert result.startswith(doc)
def test_chunk_from_page_hierarchy(self):
"""Chunk URI should contain both page and chunk identifiers."""
result = chunk_uri_from_page("https://example.com/doc", 3, 5)
assert "/p3/" in result
assert result.endswith("/c5")
def test_chunk_uri_unique(self):
r1 = chunk_uri()
r2 = chunk_uri()
assert r1 != r2
class TestActivityAndSubgraphUris:

View file

@ -449,7 +449,7 @@ class FlowInstance:
def graph_rag(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2,
max_path_length=2, edge_score_limit=30, edge_limit=25,
):
"""
Execute graph-based Retrieval-Augmented Generation (RAG) query.
@ -465,6 +465,8 @@ class FlowInstance:
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 150)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 50)
edge_limit: Max edges after LLM scoring (default: 25)
Returns:
str: Generated response incorporating graph context
@ -492,6 +494,8 @@ class FlowInstance:
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
}
return self.request(

View file

@ -699,9 +699,12 @@ class SocketFlowInstance:
query: str,
user: str,
collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
@ -715,9 +718,12 @@ class SocketFlowInstance:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 50)
edge_limit: Max edges after LLM scoring (default: 25)
streaming: Enable streaming mode (default: False)
**kwargs: Additional parameters passed to the service
@ -743,9 +749,12 @@ class SocketFlowInstance:
"query": query,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": streaming
}
request.update(kwargs)
@ -762,9 +771,12 @@ class SocketFlowInstance:
query: str,
user: str,
collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
@ -778,9 +790,12 @@ class SocketFlowInstance:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 50)
edge_limit: Max edges after LLM scoring (default: 25)
**kwargs: Additional parameters passed to the service
Yields:
@ -823,11 +838,14 @@ class SocketFlowInstance:
"query": query,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": True,
"explainable": True, # Enable explainability mode
"explainable": True,
}
request.update(kwargs)

View file

@ -96,8 +96,6 @@ class ChunkingService(FlowProcessor):
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def fetch_document_content(self, document_id, user, timeout=120):
"""

View file

@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2)),
edge_score_limit=int(data.get("edge-score-limit", 30)),
edge_limit=int(data.get("edge-limit", 25)),
streaming=data.get("streaming", False)
)
@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length,
"edge-score-limit": obj.edge_score_limit,
"edge-limit": obj.edge_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -9,14 +9,14 @@ Provides helpers for:
Usage example:
from trustgraph.provenance import (
document_uri, page_uri, chunk_uri_from_page,
document_uri, page_uri, chunk_uri,
document_triples, derived_entity_triples,
get_vocabulary_triples,
)
# Generate URIs
doc_uri = document_uri("my-doc-123")
page_uri = page_uri("my-doc-123", page_number=1)
pg_uri = page_uri()
# Build provenance triples
triples = document_triples(
@ -35,8 +35,9 @@ from . uris import (
TRUSTGRAPH_BASE,
document_uri,
page_uri,
chunk_uri_from_page,
chunk_uri_from_doc,
section_uri,
chunk_uri,
image_uri,
activity_uri,
subgraph_uri,
agent_uri,
@ -75,8 +76,10 @@ from . namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_SECTION_TYPE, TG_CHUNK_TYPE,
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
@ -138,8 +141,9 @@ __all__ = [
"TRUSTGRAPH_BASE",
"document_uri",
"page_uri",
"chunk_uri_from_page",
"chunk_uri_from_doc",
"section_uri",
"chunk_uri",
"image_uri",
"activity_uri",
"subgraph_uri",
"agent_uri",
@ -171,8 +175,10 @@ __all__ = [
"TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION",
"TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL",
"TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH",
"TG_ELEMENT_TYPES", "TG_TABLE_COUNT", "TG_IMAGE_COUNT",
# Extraction provenance entity types
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE",
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_SECTION_TYPE",
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",

View file

@ -75,9 +75,16 @@ TG_SELECTED_CHUNK = TG + "selectedChunk"
# Extraction provenance entity types
TG_DOCUMENT_TYPE = TG + "Document"
TG_PAGE_TYPE = TG + "Page"
TG_SECTION_TYPE = TG + "Section"
TG_CHUNK_TYPE = TG + "Chunk"
TG_IMAGE_TYPE = TG + "Image"
TG_SUBGRAPH_TYPE = TG + "Subgraph"
# Universal decoder metadata predicates
TG_ELEMENT_TYPES = TG + "elementTypes"
TG_TABLE_COUNT = TG + "tableCount"
TG_IMAGE_COUNT = TG + "imageCount"
# Explainability entity types (shared)
TG_QUESTION = TG + "Question"
TG_GROUNDING = TG + "Grounding"

View file

@ -18,7 +18,10 @@ from . namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
# Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_SECTION_TYPE, TG_CHUNK_TYPE,
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Universal decoder metadata predicates
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
@ -129,15 +132,22 @@ def derived_entity_triples(
component_version: str,
label: Optional[str] = None,
page_number: Optional[int] = None,
section: bool = False,
image: bool = False,
chunk_index: Optional[int] = None,
char_offset: Optional[int] = None,
char_length: Optional[int] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
mime_type: Optional[str] = None,
element_types: Optional[str] = None,
table_count: Optional[int] = None,
image_count: Optional[int] = None,
timestamp: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a derived entity (page or chunk) with full PROV-O provenance.
Build triples for a derived entity (page, section, chunk, or image)
with full PROV-O provenance.
Creates:
- Entity declaration
@ -146,17 +156,23 @@ def derived_entity_triples(
- Agent for the component
Args:
entity_uri: URI of the derived entity (page or chunk)
entity_uri: URI of the derived entity
parent_uri: URI of the parent entity
component_name: Name of TG component (e.g., "pdf-extractor", "chunker")
component_version: Version of the component
label: Human-readable label
page_number: Page number (for pages)
section: True if this is a document section (non-page format)
image: True if this is an image entity
chunk_index: Chunk index (for chunks)
char_offset: Character offset in parent (for chunks)
char_length: Character length (for chunks)
char_offset: Character offset in parent
char_length: Character length
chunk_size: Configured chunk size (for chunking activity)
chunk_overlap: Configured chunk overlap (for chunking activity)
mime_type: Source document MIME type
element_types: Comma-separated unstructured element categories
table_count: Number of tables in this page/section
image_count: Number of images in this page/section
timestamp: ISO timestamp (defaults to now)
Returns:
@ -169,7 +185,11 @@ def derived_entity_triples(
agt_uri = agent_uri(component_name)
# Determine specific type from parameters
if page_number is not None:
if image:
specific_type = TG_IMAGE_TYPE
elif section:
specific_type = TG_SECTION_TYPE
elif page_number is not None:
specific_type = TG_PAGE_TYPE
elif chunk_index is not None:
specific_type = TG_CHUNK_TYPE
@ -225,6 +245,18 @@ def derived_entity_triples(
if chunk_overlap is not None:
triples.append(_triple(act_uri, TG_CHUNK_OVERLAP, _literal(chunk_overlap)))
if mime_type:
triples.append(_triple(entity_uri, TG_MIME_TYPE, _literal(mime_type)))
if element_types:
triples.append(_triple(entity_uri, TG_ELEMENT_TYPES, _literal(element_types)))
if table_count is not None:
triples.append(_triple(entity_uri, TG_TABLE_COUNT, _literal(table_count)))
if image_count is not None:
triples.append(_triple(entity_uri, TG_IMAGE_COUNT, _literal(image_count)))
return triples

View file

@ -1,12 +1,13 @@
"""
URI generation for provenance entities.
Document IDs are already IRIs (e.g., https://trustgraph.ai/doc/abc123).
Child entities (pages, chunks) append path segments to the parent IRI:
- Document: {doc_iri} (as provided)
- Page: {doc_iri}/p{page_number}
- Chunk: {page_iri}/c{chunk_index} (from page)
{doc_iri}/c{chunk_index} (from text doc)
Document IDs are externally provided (e.g., https://trustgraph.ai/doc/abc123).
Child entities (pages, chunks) use UUID-based URNs:
- Document: {doc_iri} (as provided, not generated here)
- Page: urn:page:{uuid}
- Section: urn:section:{uuid}
- Chunk: urn:chunk:{uuid}
- Image: urn:image:{uuid}
- Activity: https://trustgraph.ai/activity/{uuid}
- Subgraph: https://trustgraph.ai/subgraph/{uuid}
"""
@ -28,19 +29,24 @@ def document_uri(doc_iri: str) -> str:
return doc_iri
def page_uri(doc_iri: str, page_number: int) -> str:
"""Generate URI for a page by appending to document IRI."""
return f"{doc_iri}/p{page_number}"
def page_uri() -> str:
"""Generate a unique URI for a page."""
return f"urn:page:{uuid.uuid4()}"
def chunk_uri_from_page(doc_iri: str, page_number: int, chunk_index: int) -> str:
"""Generate URI for a chunk extracted from a page."""
return f"{doc_iri}/p{page_number}/c{chunk_index}"
def section_uri() -> str:
"""Generate a unique URI for a document section."""
return f"urn:section:{uuid.uuid4()}"
def chunk_uri_from_doc(doc_iri: str, chunk_index: int) -> str:
"""Generate URI for a chunk extracted directly from a text document."""
return f"{doc_iri}/c{chunk_index}"
def chunk_uri() -> str:
"""Generate a unique URI for a chunk."""
return f"urn:chunk:{uuid.uuid4()}"
def image_uri() -> str:
"""Generate a unique URI for an image."""
return f"urn:image:{uuid.uuid4()}"
def activity_uri(activity_id: str = None) -> str:

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
triple_limit: int = 0
max_subgraph_size: int = 0
max_path_length: int = 0
edge_score_limit: int = 0
edge_limit: int = 0
streaming: bool = False

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"requests",
"pulsar-client",
"aiohttp",

View file

@ -28,6 +28,8 @@ default_entity_limit = 50
default_triple_limit = 30
default_max_subgraph_size = 150
default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
@ -638,7 +640,8 @@ async def _question_explainable(
def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, token=None, debug=False
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token)
@ -652,9 +655,12 @@ def _question_explainable_api(
query=question_text,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_subgraph_count=5,
max_entity_distance=max_path_length,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
):
if isinstance(item, RAGChunk):
# Print response content
@ -743,7 +749,8 @@ def _question_explainable_api(
def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
explainable=False, debug=False
):
@ -759,6 +766,8 @@ def question(
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
token=token,
debug=debug
)
@ -781,6 +790,8 @@ def question(
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
streaming=True
)
@ -801,7 +812,9 @@ def question(
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
)
print(resp)
@ -876,6 +889,20 @@ def main():
help=f'Max path length (default: {default_max_path_length})'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=default_edge_score_limit,
help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})'
)
parser.add_argument(
'--edge-limit',
type=int,
default=default_edge_limit,
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
)
parser.add_argument(
'--no-streaming',
action='store_true',
@ -908,6 +935,8 @@ def main():
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
edge_score_limit=args.edge_score_limit,
edge_limit=args.edge_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,

View file

@ -214,6 +214,12 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None,
chain_str = format_provenance_chain(chain)
if chain_str:
print(f" Source: {chain_str}")
# Show content ID for the chunk (second item in chain)
for item in chain:
uri = item.get("uri", "")
if uri.startswith("urn:chunk:"):
print(f" Content: {uri}")
break
print()
else:

View file

@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-flow>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"trustgraph-flow>=2.2,<2.3",
"torch",
"urllib3",
"transformers",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"aiohttp",
"anthropic",
"scylla-driver",

View file

@ -189,8 +189,6 @@ class Processor(AgentService):
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""

View file

@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
from ... provenance import (
derived_entity_triples,
chunk_uri as make_chunk_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
@ -124,10 +124,9 @@ class Processor(ChunkingService):
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
# Generate chunk document ID by appending /c{index} to parent
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
chunk_uri = chunk_doc_id # URI is same as document ID
# Generate unique chunk ID
c_uri = make_chunk_uri()
chunk_doc_id = c_uri
parent_uri = parent_doc_id
chunk_content = chunk.page_content.encode("utf-8")
@ -145,7 +144,7 @@ class Processor(ChunkingService):
# Emit provenance triples (stored in source graph for separation from core knowledge)
prov_triples = derived_entity_triples(
entity_uri=chunk_uri,
entity_uri=c_uri,
parent_uri=parent_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
@ -159,7 +158,7 @@ class Processor(ChunkingService):
await flow("triples").send(Triples(
metadata=Metadata(
id=chunk_uri,
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
@ -170,7 +169,7 @@ class Processor(ChunkingService):
# Forward chunk ID + content (post-chunker optimization)
r = Chunk(
metadata=Metadata(
id=chunk_uri,
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,

View file

@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
from ... provenance import (
derived_entity_triples,
chunk_uri as make_chunk_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
@ -122,10 +122,9 @@ class Processor(ChunkingService):
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
# Generate chunk document ID by appending /c{index} to parent
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
chunk_uri = chunk_doc_id # URI is same as document ID
# Generate unique chunk ID
c_uri = make_chunk_uri()
chunk_doc_id = c_uri
parent_uri = parent_doc_id
chunk_content = chunk.page_content.encode("utf-8")
@ -143,7 +142,7 @@ class Processor(ChunkingService):
# Emit provenance triples (stored in source graph for separation from core knowledge)
prov_triples = derived_entity_triples(
entity_uri=chunk_uri,
entity_uri=c_uri,
parent_uri=parent_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
@ -157,7 +156,7 @@ class Processor(ChunkingService):
await flow("triples").send(Triples(
metadata=Metadata(
id=chunk_uri,
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
@ -168,7 +167,7 @@ class Processor(ChunkingService):
# Forward chunk ID + content (post-chunker optimization)
r = Chunk(
metadata=Metadata(
id=chunk_uri,
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,

View file

@ -1,29 +1,48 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
Mistral OCR decoder, accepts PDF documents on input, outputs pages from the
PDF document as markdown text as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents.
"""
from pypdf import PdfWriter, PdfReader
from io import BytesIO
import asyncio
import base64
import uuid
import os
from mistralai import Mistral
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
import logging
logger = logging.getLogger(__name__)
default_ident = "pdf-decoder"
# Component identification for provenance
COMPONENT_NAME = "mistral-ocr-decoder"
COMPONENT_VERSION = "1.0.0"
default_ident = "document-decoder"
default_api_key = os.getenv("MISTRAL_TOKEN")
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
pages_per_chunk = 5
def chunks(lst, n):
@ -48,27 +67,6 @@ def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
)
return markdown_str
def get_combined_markdown(ocr_response: OCRResponse) -> str:
"""
Combine OCR text and images into a single markdown document.
Args:
ocr_response: Response from OCR processing containing text and images
Returns:
Combined markdown string with embedded images
"""
markdowns: list[str] = []
# Extract images from page
for page in ocr_response.pages:
image_data = {}
for img in page.images:
image_data[img.id] = img.image_base64
# Replace image placeholders with actual images
markdowns.append(replace_images_in_markdown(page.markdown, image_data))
return "\n\n".join(markdowns)
class Processor(FlowProcessor):
def __init__(self, **params):
@ -97,6 +95,50 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples,
)
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "librarian-request"
)
self.librarian_request_producer = Producer(
backend = self.pubsub,
topic = librarian_request_q,
schema = LibrarianRequest,
metrics = librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor = id, flow = None, name = "librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = librarian_response_q,
subscriber = f"{id}-librarian",
schema = LibrarianResponse,
handler = self.on_librarian_response,
metrics = librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
if api_key is None:
raise RuntimeError("Mistral API key not specified")
@ -107,15 +149,156 @@ class Processor(FlowProcessor):
logger.info("Mistral OCR processor initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def fetch_document_metadata(self, document_id, user, timeout=120):
"""
Fetch document metadata from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.document_metadata
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching metadata for {document_id}")
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="page", title=None, timeout=120):
"""
Save a child document to the librarian.
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving child document: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving child document {doc_id}")
def ocr(self, blob):
"""
Run Mistral OCR on a PDF blob, returning per-page markdown strings.
Args:
blob: Raw PDF bytes
Returns:
List of (page_markdown, page_number) tuples, 1-indexed
"""
logger.debug("Parse PDF...")
pdfbuf = BytesIO(blob)
pdf = PdfReader(pdfbuf)
pages = []
global_page_num = 0
for chunk in chunks(pdf.pages, pages_per_chunk):
logger.debug("Get next pages...")
part = PdfWriter()
@ -152,11 +335,19 @@ class Processor(FlowProcessor):
logger.debug("Extract markdown...")
markdown = get_combined_markdown(processed)
for page in processed.pages:
global_page_num += 1
image_data = {}
for img in page.images:
image_data[img.id] = img.image_base64
markdown = replace_images_in_markdown(
page.markdown, image_data
)
pages.append((markdown, global_page_num))
logger.info("OCR complete.")
logger.info(f"OCR complete, {len(pages)} pages.")
return markdown
return pages
async def on_message(self, msg, consumer, flow):
@ -166,16 +357,97 @@ class Processor(FlowProcessor):
logger.info(f"Decoding {v.metadata.id}...")
markdown = self.ocr(base64.b64decode(v.data))
# Check MIME type if fetching from librarian
if v.document_id:
doc_meta = await self.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
f"Unsupported MIME type: {doc_meta.kind}. "
f"Mistral OCR decoder only handles application/pdf. "
f"Ignoring document {v.metadata.id}."
)
return
r = TextDocument(
metadata=v.metadata,
text=markdown.encode("utf-8"),
)
# Get PDF content - fetch from librarian or use inline data
if v.document_id:
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
)
if isinstance(content, str):
content = content.encode('utf-8')
blob = base64.b64decode(content)
logger.info(f"Fetched {len(blob)} bytes from librarian")
else:
blob = base64.b64decode(v.data)
await flow("output").send(r)
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
logger.info("Done.")
# Run OCR, get per-page markdown
pages = self.ocr(blob)
for markdown, page_num in pages:
logger.debug(f"Processing page {page_num}")
# Generate unique page ID
pg_uri = make_page_uri()
page_doc_id = pg_uri
page_content = markdown.encode("utf-8")
# Save page as child document in librarian
await self.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
content=page_content,
document_type="page",
title=f"Page {page_num}",
)
# Emit provenance triples
doc_uri = document_uri(source_doc_id)
prov_triples = derived_entity_triples(
entity_uri=pg_uri,
parent_uri=doc_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Page {page_num}",
page_number=page_num,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Forward page document ID to chunker
# Chunker will fetch content from librarian
r = TextDocument(
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,
text=b"", # Empty, chunker will fetch from librarian
)
await flow("output").send(r)
logger.debug("PDF decoding complete")
@staticmethod
def add_args(parser):
@ -188,7 +460,18 @@ class Processor(FlowProcessor):
help=f'Mistral API Key'
)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -23,7 +23,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import (
document_uri, page_uri, derived_entity_triples,
document_uri, page_uri as make_page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
@ -34,7 +34,7 @@ COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "pdf-decoder"
default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
@ -126,8 +126,39 @@ class Processor(FlowProcessor):
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def fetch_document_metadata(self, document_id, user, timeout=120):
"""
Fetch document metadata from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.document_metadata
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching metadata for {document_id}")
async def fetch_document_content(self, document_id, user, timeout=120):
"""
@ -233,6 +264,20 @@ class Processor(FlowProcessor):
logger.info(f"Decoding PDF {v.metadata.id}...")
# Check MIME type if fetching from librarian
if v.document_id:
doc_meta = await self.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
f"Unsupported MIME type: {doc_meta.kind}. "
f"PDF decoder only handles application/pdf. "
f"Ignoring document {v.metadata.id}."
)
return
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
temp_path = fp.name
@ -272,8 +317,9 @@ class Processor(FlowProcessor):
logger.debug(f"Processing page {page_num}")
# Generate page document ID
page_doc_id = f"{source_doc_id}/p{page_num}"
# Generate unique page ID
pg_uri = make_page_uri()
page_doc_id = pg_uri
page_content = page.page_content.encode("utf-8")
# Save page as child document in librarian
@ -288,7 +334,6 @@ class Processor(FlowProcessor):
# Emit provenance triples (stored in source graph for separation from core knowledge)
doc_uri = document_uri(source_doc_id)
pg_uri = page_uri(source_doc_id, page_num)
prov_triples = derived_entity_triples(
entity_uri=pg_uri,

View file

@ -44,12 +44,8 @@ class Librarian:
async def add_document(self, request):
if request.document_metadata.kind not in (
"text/plain", "application/pdf"
):
raise RequestError(
"Invalid document kind: " + request.document_metadata.kind
)
if not request.document_metadata.kind:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,
@ -276,10 +272,8 @@ class Librarian:
"""
logger.info(f"Beginning chunked upload for document {request.document_metadata.id}")
if request.document_metadata.kind not in ("text/plain", "application/pdf"):
raise RequestError(
"Invalid document kind: " + request.document_metadata.kind
)
if not request.document_metadata.kind:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,

View file

@ -284,7 +284,6 @@ class Processor(AsyncProcessor):
pass
# Threshold for sending document_id instead of inline content (2MB)
STREAMING_THRESHOLD = 2 * 1024 * 1024
async def emit_document_provenance(self, document, processing, triples_queue):
"""
@ -360,10 +359,8 @@ class Processor(AsyncProcessor):
if document.kind == "text/plain":
kind = "text-load"
elif document.kind == "application/pdf":
kind = "document-load"
else:
raise RuntimeError("Document with a MIME type I don't know")
kind = "document-load"
q = flow["interfaces"][kind]
@ -374,57 +371,28 @@ class Processor(AsyncProcessor):
)
if kind == "text-load":
# For large text documents, send document_id for streaming retrieval
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Text document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = TextDocument(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
text = b"", # Empty, receiver will fetch via librarian
)
else:
doc = TextDocument(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
text = content,
)
doc = TextDocument(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
text = b"",
)
schema = TextDocument
else:
# For large PDF documents, send document_id for streaming retrieval
# instead of embedding the entire content in the message
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = Document(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
data = b"", # Empty data, receiver will fetch via API
)
else:
doc = Document(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
data = base64.b64encode(content).decode("utf-8")
)
doc = Document(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
data = b"",
)
schema = Document
logger.debug(f"Submitting to queue {q}...")

View file

@ -139,8 +139,6 @@ class Processor(FlowProcessor):
if request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
"""Fetch chunk content from librarian/Garage."""

View file

@ -3,6 +3,7 @@ import asyncio
import hashlib
import json
import logging
import math
import time
import uuid
from collections import OrderedDict
@ -550,7 +551,8 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_limit = 25, streaming = False,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
@ -565,6 +567,8 @@ class GraphRag:
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
@ -628,6 +632,70 @@ class GraphRag:
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}

View file

@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
@ -49,6 +50,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit,
}
)
@ -57,6 +59,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
@ -166,8 +169,6 @@ class Processor(FlowProcessor):
if request_id and request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
"""
@ -295,6 +296,11 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit:
edge_limit = v.edge_limit
else:
@ -330,6 +336,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
@ -344,6 +351,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
@ -432,6 +440,20 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
)
# Note: Explainability triples are now stored in the user's collection
# with the named graph urn:graph:retrieval (no separate collection needed)

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -2,21 +2,41 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents.
"""
import tempfile
import asyncio
import base64
import logging
import uuid
import pytesseract
from pdf2image import convert_from_bytes
from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
# Component identification for provenance
COMPONENT_NAME = "tesseract-ocr-decoder"
COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "pdf-decoder"
default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@ -45,8 +65,181 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples,
)
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "librarian-request"
)
self.librarian_request_producer = Producer(
backend = self.pubsub,
topic = librarian_request_q,
schema = LibrarianRequest,
metrics = librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor = id, flow = None, name = "librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = librarian_response_q,
subscriber = f"{id}-librarian",
schema = LibrarianResponse,
handler = self.on_librarian_response,
metrics = librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.info("PDF OCR processor initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def fetch_document_metadata(self, document_id, user, timeout=120):
"""
Fetch document metadata from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.document_metadata
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching metadata for {document_id}")
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="page", title=None, timeout=120):
"""
Save a child document to the librarian.
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving child document: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving child document {doc_id}")
async def on_message(self, msg, consumer, flow):
logger.info("PDF message received")
@ -55,21 +248,99 @@ class Processor(FlowProcessor):
logger.info(f"Decoding {v.metadata.id}...")
blob = base64.b64decode(v.data)
# Check MIME type if fetching from librarian
if v.document_id:
doc_meta = await self.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
f"Unsupported MIME type: {doc_meta.kind}. "
f"Tesseract OCR decoder only handles application/pdf. "
f"Ignoring document {v.metadata.id}."
)
return
# Get PDF content - fetch from librarian or use inline data
if v.document_id:
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
)
if isinstance(content, str):
content = content.encode('utf-8')
blob = base64.b64decode(content)
logger.info(f"Fetched {len(blob)} bytes from librarian")
else:
blob = base64.b64decode(v.data)
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
pages = convert_from_bytes(blob)
for ix, page in enumerate(pages):
page_num = ix + 1 # 1-indexed
try:
text = pytesseract.image_to_string(page, lang='eng')
except Exception as e:
logger.warning(f"Page did not OCR: {e}")
logger.warning(f"Page {page_num} did not OCR: {e}")
continue
logger.debug(f"Processing page {page_num}")
# Generate unique page ID
pg_uri = make_page_uri()
page_doc_id = pg_uri
page_content = text.encode("utf-8")
# Save page as child document in librarian
await self.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
content=page_content,
document_type="page",
title=f"Page {page_num}",
)
# Emit provenance triples
doc_uri = document_uri(source_doc_id)
prov_triples = derived_entity_triples(
entity_uri=pg_uri,
parent_uri=doc_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Page {page_num}",
page_number=page_num,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Forward page document ID to chunker
# Chunker will fetch content from librarian
r = TextDocument(
metadata=v.metadata,
text=text.encode("utf-8"),
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,
text=b"", # Empty, chunker will fetch from librarian
)
await flow("output").send(r)
@ -78,9 +349,21 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,34 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "trustgraph-unstructured"
dynamic = ["version"]
authors = [{name = "trustgraph.ai", email = "security@trustgraph.ai"}]
description = "TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline."
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.2,<2.3",
"pulsar-client",
"prometheus-client",
"python-magic",
"unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]",
]
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]
[project.urls]
Homepage = "https://github.com/trustgraph-ai/trustgraph"
[project.scripts]
universal-decoder = "trustgraph.decoding.universal:run"
[tool.setuptools.packages.find]
include = ["trustgraph*"]
[tool.setuptools.dynamic]
version = {attr = "trustgraph.unstructured_version.__version__"}

View file

@ -0,0 +1,2 @@
from . processor import *

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,710 @@
"""
Universal document decoder powered by the unstructured library.
Accepts documents in any common format (PDF, DOCX, XLSX, HTML, Markdown,
plain text, PPTX, etc.) on input, outputs pages or sections as text
as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents. Fetches document metadata from the librarian to
determine mime type for format detection.
Tables are preserved as HTML markup for better downstream extraction.
Images are stored in the librarian but not sent to the text pipeline.
"""
import asyncio
import base64
import logging
import magic
import tempfile
import os
import uuid
from unstructured.partition.auto import partition
from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import (
document_uri, page_uri as make_page_uri,
section_uri as make_section_uri, image_uri as make_image_uri,
derived_entity_triples, set_graph, GRAPH_SOURCE,
)
from . strategies import get_strategy
# Component identification for provenance
COMPONENT_NAME = "universal-decoder"
COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
# Mime type to unstructured content_type mapping
# unstructured auto-detects most formats, but we pass the hint when available
MIME_EXTENSIONS = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"text/html": ".html",
"text/markdown": ".md",
"text/plain": ".txt",
"text/csv": ".csv",
"text/tab-separated-values": ".tsv",
"application/rtf": ".rtf",
"text/x-rst": ".rst",
"application/vnd.oasis.opendocument.text": ".odt",
}
# Formats that have natural page boundaries
PAGE_BASED_FORMATS = {
"application/pdf",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-excel",
}
def assemble_section_text(elements):
"""
Assemble text from a list of unstructured elements.
- Text elements: plain text, joined with double newlines
- Table elements: HTML table markup from text_as_html
- Image elements: skipped (stored separately, not in text output)
Returns:
tuple: (assembled_text, element_types_set, table_count, image_count)
"""
parts = []
element_types = set()
table_count = 0
image_count = 0
for el in elements:
category = getattr(el, 'category', 'UncategorizedText')
element_types.add(category)
if category == 'Image':
image_count += 1
continue # Images are NOT included in text output
if category == 'Table':
table_count += 1
# Prefer HTML representation for tables
html = getattr(el.metadata, 'text_as_html', None) if hasattr(el, 'metadata') else None
if html:
parts.append(html)
else:
# Fallback to plain text
text = getattr(el, 'text', '') or ''
if text:
parts.append(text)
else:
text = getattr(el, 'text', '') or ''
if text:
parts.append(text)
return '\n\n'.join(parts), element_types, table_count, image_count
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
self.partition_strategy = params.get("strategy", "auto")
self.languages = params.get("languages", "eng").split(",")
self.section_strategy_name = params.get(
"section_strategy", "whole-document"
)
self.section_element_count = params.get("section_element_count", 20)
self.section_max_size = params.get("section_max_size", 4000)
self.section_within_pages = params.get("section_within_pages", False)
self.section_strategy = get_strategy(self.section_strategy_name)
super(Processor, self).__init__(
**params | {
"id": id,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=Document,
handler=self.on_message,
)
)
self.register_specification(
ProducerSpec(
name="output",
schema=TextDocument,
)
)
self.register_specification(
ProducerSpec(
name="triples",
schema=Triples,
)
)
# Librarian client for fetching/storing document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.info("Universal decoder initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def _librarian_request(self, request, timeout=120):
"""Send a request to the librarian and wait for response."""
request_id = str(uuid.uuid4())
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: "
f"{response.error.message}"
)
return response
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError("Timeout waiting for librarian response")
async def fetch_document_metadata(self, document_id, user):
"""Fetch document metadata from the librarian."""
request = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
response = await self._librarian_request(request)
return response.document_metadata
async def fetch_document_content(self, document_id, user):
"""Fetch document content from the librarian."""
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
response = await self._librarian_request(request)
return response.content
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="page", title=None,
kind="text/plain"):
"""Save a child document to the librarian."""
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind=kind,
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
await self._librarian_request(request)
return doc_id
def extract_elements(self, blob, mime_type=None):
"""
Extract elements from a document using unstructured.
Args:
blob: Raw document bytes
mime_type: Optional mime type hint
Returns:
List of unstructured Element objects
"""
# Determine file extension for unstructured
suffix = MIME_EXTENSIONS.get(mime_type, "") if mime_type else ""
if not suffix:
suffix = ".bin"
with tempfile.NamedTemporaryFile(
delete=False, suffix=suffix
) as fp:
fp.write(blob)
temp_path = fp.name
try:
kwargs = {
"filename": temp_path,
"strategy": self.partition_strategy,
"languages": self.languages,
}
# For hi_res strategy, request image extraction
if self.partition_strategy == "hi_res":
kwargs["extract_image_block_to_payload"] = True
elements = partition(**kwargs)
logger.info(
f"Extracted {len(elements)} elements "
f"(strategy: {self.partition_strategy})"
)
return elements
finally:
try:
os.unlink(temp_path)
except OSError:
pass
def group_by_page(self, elements):
"""
Group elements by page number.
Returns list of (page_number, elements) tuples.
"""
pages = {}
for el in elements:
page_num = getattr(
el.metadata, 'page_number', None
) if hasattr(el, 'metadata') else None
if page_num is None:
page_num = 1
if page_num not in pages:
pages[page_num] = []
pages[page_num].append(el)
return sorted(pages.items())
async def emit_section(self, elements, parent_doc_id, doc_uri_str,
metadata, flow, mime_type=None,
page_number=None, section_index=None):
"""
Process a group of elements as a page or section.
Assembles text, saves to librarian, emits provenance, sends
TextDocument downstream. Returns the entity URI.
"""
text, element_types, table_count, image_count = (
assemble_section_text(elements)
)
if not text.strip():
logger.debug("Skipping empty section")
return None
is_page = page_number is not None
char_length = len(text)
if is_page:
entity_uri = make_page_uri()
label = f"Page {page_number}"
else:
entity_uri = make_section_uri()
label = f"Section {section_index}" if section_index else "Section"
doc_id = entity_uri
page_content = text.encode("utf-8")
# Save to librarian
await self.save_child_document(
doc_id=doc_id,
parent_id=parent_doc_id,
user=metadata.user,
content=page_content,
document_type="page" if is_page else "section",
title=label,
)
# Emit provenance triples
element_types_str = ",".join(sorted(element_types)) if element_types else None
prov_triples = derived_entity_triples(
entity_uri=entity_uri,
parent_uri=doc_uri_str,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=label,
page_number=page_number,
section=not is_page,
char_length=char_length,
mime_type=mime_type,
element_types=element_types_str,
table_count=table_count if table_count > 0 else None,
image_count=image_count if image_count > 0 else None,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=entity_uri,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Send TextDocument downstream (chunker will fetch from librarian)
r = TextDocument(
metadata=Metadata(
id=entity_uri,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
document_id=doc_id,
text=b"",
)
await flow("output").send(r)
return entity_uri
async def emit_image(self, element, parent_uri, parent_doc_id,
metadata, flow, mime_type=None, page_number=None):
"""
Store an image element in the librarian with provenance.
Images are stored but NOT sent downstream to the text pipeline.
"""
img_uri = make_image_uri()
# Get image data
img_data = None
if hasattr(element, 'metadata'):
img_data = getattr(element.metadata, 'image_base64', None)
if not img_data:
# No image payload available, just record provenance
logger.debug("Image element without payload, recording provenance only")
img_content = b""
img_kind = "image/unknown"
else:
if isinstance(img_data, str):
img_content = base64.b64decode(img_data)
else:
img_content = img_data
img_kind = "image/png" # unstructured typically extracts as PNG
# Save to librarian
if img_content:
await self.save_child_document(
doc_id=img_uri,
parent_id=parent_doc_id,
user=metadata.user,
content=img_content,
document_type="image",
title=f"Image from page {page_number}" if page_number else "Image",
kind=img_kind,
)
# Emit provenance triples
prov_triples = derived_entity_triples(
entity_uri=img_uri,
parent_uri=parent_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Image from page {page_number}" if page_number else "Image",
image=True,
page_number=page_number,
mime_type=mime_type,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=img_uri,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
async def on_message(self, msg, consumer, flow):
logger.debug("Document message received")
v = msg.value()
logger.info(f"Decoding {v.metadata.id}...")
# Determine content and mime type
mime_type = None
if v.document_id:
# Librarian path: fetch metadata then content
logger.info(
f"Fetching document {v.document_id} from librarian..."
)
doc_meta = await self.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
)
mime_type = doc_meta.kind if doc_meta else None
content = await self.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
)
if isinstance(content, str):
content = content.encode('utf-8')
blob = base64.b64decode(content)
logger.info(
f"Fetched {len(blob)} bytes, mime: {mime_type}"
)
else:
# Inline path: detect format from content
blob = base64.b64decode(v.data)
try:
mime_type = magic.from_buffer(blob, mime=True)
logger.info(f"Detected mime type: {mime_type}")
except Exception as e:
logger.warning(f"Could not detect mime type: {e}")
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
doc_uri_str = document_uri(source_doc_id)
# Extract elements using unstructured
elements = self.extract_elements(blob, mime_type)
if not elements:
logger.warning("No elements extracted from document")
return
# Determine if this is a page-based format
is_page_based = mime_type in PAGE_BASED_FORMATS if mime_type else False
# Also check if elements actually have page numbers
if not is_page_based:
has_pages = any(
getattr(el.metadata, 'page_number', None) is not None
for el in elements
if hasattr(el, 'metadata')
)
if has_pages:
is_page_based = True
if is_page_based:
# Group by page
page_groups = self.group_by_page(elements)
for page_num, page_elements in page_groups:
# Extract and store images separately
image_elements = [
el for el in page_elements
if getattr(el, 'category', '') == 'Image'
]
text_elements = [
el for el in page_elements
if getattr(el, 'category', '') != 'Image'
]
# Emit the page as a text section
page_uri_str = await self.emit_section(
text_elements, source_doc_id, doc_uri_str,
v.metadata, flow,
mime_type=mime_type, page_number=page_num,
)
# Store images (not sent to text pipeline)
for img_el in image_elements:
await self.emit_image(
img_el,
page_uri_str or doc_uri_str,
source_doc_id,
v.metadata, flow,
mime_type=mime_type, page_number=page_num,
)
else:
# Non-page format: use section strategy
# Separate images from text elements
image_elements = [
el for el in elements
if getattr(el, 'category', '') == 'Image'
]
text_elements = [
el for el in elements
if getattr(el, 'category', '') != 'Image'
]
# Apply section strategy to text elements
strategy_kwargs = {
'element_count': self.section_element_count,
'max_size': self.section_max_size,
}
groups = self.section_strategy(text_elements, **strategy_kwargs)
for idx, group in enumerate(groups):
section_idx = idx + 1
await self.emit_section(
group, source_doc_id, doc_uri_str,
v.metadata, flow,
mime_type=mime_type, section_index=section_idx,
)
# Store images (not sent to text pipeline)
for img_el in image_elements:
await self.emit_image(
img_el, doc_uri_str, source_doc_id,
v.metadata, flow,
mime_type=mime_type,
)
logger.info("Document decoding complete")
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--strategy',
default='auto',
choices=['auto', 'hi_res', 'fast'],
help='Partitioning strategy (default: auto)',
)
parser.add_argument(
'--languages',
default='eng',
help='Comma-separated OCR language codes (default: eng)',
)
parser.add_argument(
'--section-strategy',
default='whole-document',
choices=[
'whole-document', 'heading', 'element-type', 'count', 'size'
],
help='Section grouping strategy for non-page formats '
'(default: whole-document)',
)
parser.add_argument(
'--section-element-count',
type=int,
default=20,
help='Elements per section for count strategy (default: 20)',
)
parser.add_argument(
'--section-max-size',
type=int,
default=4000,
help='Max chars per section for size strategy (default: 4000)',
)
parser.add_argument(
'--section-within-pages',
action='store_true',
default=False,
help='Apply section strategy within pages too (default: false)',
)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue '
f'(default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue '
f'(default: {default_librarian_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,171 @@
"""
Section grouping strategies for the universal document decoder.
Each strategy takes a list of unstructured elements and returns a list
of element groups. Each group becomes one TextDocument output.
"""
import logging
logger = logging.getLogger(__name__)
def group_whole_document(elements, **kwargs):
"""
Emit the entire document as a single section.
The downstream chunker handles all splitting.
"""
if not elements:
return []
return [elements]
def group_by_heading(elements, **kwargs):
"""
Split at heading elements (Title category).
Each section is a heading plus all content until the next heading.
Falls back to whole-document if no headings are found.
"""
if not elements:
return []
# Check if any headings exist
has_headings = any(
getattr(el, 'category', '') == 'Title' for el in elements
)
if not has_headings:
logger.debug("No headings found, falling back to whole-document")
return group_whole_document(elements)
groups = []
current_group = []
for el in elements:
if getattr(el, 'category', '') == 'Title' and current_group:
groups.append(current_group)
current_group = []
current_group.append(el)
if current_group:
groups.append(current_group)
return groups
def group_by_element_type(elements, **kwargs):
"""
Split on transitions between narrative text and tables.
Consecutive elements of the same broad category stay grouped.
"""
if not elements:
return []
def is_table(el):
return getattr(el, 'category', '') == 'Table'
groups = []
current_group = []
current_is_table = None
for el in elements:
el_is_table = is_table(el)
if current_is_table is not None and el_is_table != current_is_table:
groups.append(current_group)
current_group = []
current_group.append(el)
current_is_table = el_is_table
if current_group:
groups.append(current_group)
return groups
def group_by_count(elements, element_count=20, **kwargs):
"""
Group a fixed number of elements per section.
Args:
elements: List of unstructured elements
element_count: Number of elements per group (default: 20)
"""
if not elements:
return []
groups = []
for i in range(0, len(elements), element_count):
groups.append(elements[i:i + element_count])
return groups
def group_by_size(elements, max_size=4000, **kwargs):
"""
Accumulate elements until a character limit is reached.
Respects element boundaries never splits mid-element. If a
single element exceeds the limit, it becomes its own section.
Args:
elements: List of unstructured elements
max_size: Max characters per section (default: 4000)
"""
if not elements:
return []
groups = []
current_group = []
current_size = 0
for el in elements:
el_text = getattr(el, 'text', '') or ''
el_size = len(el_text)
if current_group and current_size + el_size > max_size:
groups.append(current_group)
current_group = []
current_size = 0
current_group.append(el)
current_size += el_size
if current_group:
groups.append(current_group)
return groups
# Strategy registry
STRATEGIES = {
'whole-document': group_whole_document,
'heading': group_by_heading,
'element-type': group_by_element_type,
'count': group_by_count,
'size': group_by_size,
}
def get_strategy(name):
"""
Get a section grouping strategy by name.
Args:
name: Strategy name (whole-document, heading, element-type, count, size)
Returns:
Strategy function
Raises:
ValueError: If strategy name is not recognized
"""
if name not in STRATEGIES:
raise ValueError(
f"Unknown section strategy: {name}. "
f"Available: {', '.join(STRATEGIES.keys())}"
)
return STRATEGIES[name]

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.1,<2.2",
"trustgraph-base>=2.2,<2.3",
"pulsar-client",
"google-genai",
"google-api-core",