mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Merge pull request #712 from trustgraph-ai/release/v2.2
release/v2.2 -> master
This commit is contained in:
commit
3ccff800c7
45 changed files with 3110 additions and 400 deletions
5
.github/workflows/pull-request.yaml
vendored
5
.github/workflows/pull-request.yaml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup packages
|
||||
run: make update-package-versions VERSION=2.1.999
|
||||
run: make update-package-versions VERSION=2.2.999
|
||||
|
||||
- name: Setup environment
|
||||
run: python3 -m venv env
|
||||
|
|
@ -39,6 +39,9 @@ jobs:
|
|||
- name: Install trustgraph-flow
|
||||
run: (cd trustgraph-flow; pip install .)
|
||||
|
||||
- name: Install trustgraph-unstructured
|
||||
run: (cd trustgraph-unstructured; pip install .)
|
||||
|
||||
- name: Install trustgraph-vertexai
|
||||
run: (cd trustgraph-vertexai; pip install .)
|
||||
|
||||
|
|
|
|||
1
.github/workflows/release.yaml
vendored
1
.github/workflows/release.yaml
vendored
|
|
@ -58,6 +58,7 @@ jobs:
|
|||
- trustgraph-vertexai
|
||||
- trustgraph-hf
|
||||
- trustgraph-ocr
|
||||
- trustgraph-unstructured
|
||||
- trustgraph-mcp
|
||||
|
||||
steps:
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -13,5 +13,6 @@ trustgraph-flow/trustgraph/flow_version.py
|
|||
trustgraph-ocr/trustgraph/ocr_version.py
|
||||
trustgraph-parquet/trustgraph/parquet_version.py
|
||||
trustgraph-vertexai/trustgraph/vertexai_version.py
|
||||
trustgraph-unstructured/trustgraph/unstructured_version.py
|
||||
trustgraph-mcp/trustgraph/mcp_version.py
|
||||
vertexai/
|
||||
14
Makefile
14
Makefile
|
|
@ -17,6 +17,7 @@ wheels:
|
|||
pip3 wheel --no-deps --wheel-dir dist trustgraph-embeddings-hf/
|
||||
pip3 wheel --no-deps --wheel-dir dist trustgraph-cli/
|
||||
pip3 wheel --no-deps --wheel-dir dist trustgraph-ocr/
|
||||
pip3 wheel --no-deps --wheel-dir dist trustgraph-unstructured/
|
||||
pip3 wheel --no-deps --wheel-dir dist trustgraph-mcp/
|
||||
|
||||
packages: update-package-versions
|
||||
|
|
@ -29,6 +30,7 @@ packages: update-package-versions
|
|||
cd trustgraph-embeddings-hf && python -m build --sdist --outdir ../dist/
|
||||
cd trustgraph-cli && python -m build --sdist --outdir ../dist/
|
||||
cd trustgraph-ocr && python -m build --sdist --outdir ../dist/
|
||||
cd trustgraph-unstructured && python -m build --sdist --outdir ../dist/
|
||||
cd trustgraph-mcp && python -m build --sdist --outdir ../dist/
|
||||
|
||||
pypi-upload:
|
||||
|
|
@ -46,6 +48,7 @@ update-package-versions:
|
|||
echo __version__ = \"${VERSION}\" > trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-cli/trustgraph/cli_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-ocr/trustgraph/ocr_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-unstructured/trustgraph/unstructured_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
|
||||
|
||||
|
|
@ -64,6 +67,8 @@ containers: FORCE
|
|||
-t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.ocr \
|
||||
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.mcp \
|
||||
-t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
|
||||
|
|
@ -72,6 +77,8 @@ some-containers:
|
|||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.mcp \
|
||||
|
|
@ -98,6 +105,7 @@ push:
|
|||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION}
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION}
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
|
||||
|
||||
# Individual container build targets
|
||||
|
|
@ -119,6 +127,9 @@ container-trustgraph-hf: update-package-versions
|
|||
container-trustgraph-ocr: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
|
||||
|
||||
container-trustgraph-unstructured: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
|
||||
container-trustgraph-mcp: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
|
||||
|
|
@ -141,6 +152,9 @@ push-trustgraph-hf:
|
|||
push-trustgraph-ocr:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
|
||||
|
||||
push-trustgraph-unstructured:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
|
||||
|
||||
push-trustgraph-mcp:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
|
||||
|
||||
|
|
|
|||
48
containers/Containerfile.unstructured
Normal file
48
containers/Containerfile.unstructured
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Base container with system dependencies
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
FROM docker.io/fedora:42 AS base
|
||||
|
||||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||
|
||||
RUN dnf install -y python3.13 && \
|
||||
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
||||
python -m ensurepip --upgrade && \
|
||||
pip3 install --no-cache-dir build wheel aiohttp && \
|
||||
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
|
||||
dnf clean all
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Build a container which contains the built Python packages. The build
|
||||
# creates a bunch of left-over cruft, a separate phase means this is only
|
||||
# needed to support package build
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
FROM base AS build
|
||||
|
||||
COPY trustgraph-base/ /root/build/trustgraph-base/
|
||||
COPY trustgraph-unstructured/ /root/build/trustgraph-unstructured/
|
||||
|
||||
WORKDIR /root/build/
|
||||
|
||||
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/
|
||||
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-unstructured/
|
||||
|
||||
RUN ls /root/wheels
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Finally, the target container. Start with base and add the package.
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
FROM base
|
||||
|
||||
COPY --from=build /root/wheels /root/wheels
|
||||
|
||||
RUN \
|
||||
pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \
|
||||
pip3 install --no-cache-dir /root/wheels/trustgraph_unstructured-* && \
|
||||
rm -rf /root/wheels
|
||||
|
||||
WORKDIR /
|
||||
396
docs/tech-specs/universal-decoder.md
Normal file
396
docs/tech-specs/universal-decoder.md
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
# Universal Document Decoder
|
||||
|
||||
## Headline
|
||||
|
||||
Universal document decoder powered by `unstructured` — ingest any common
|
||||
document format through a single service with full provenance and librarian
|
||||
integration, recording source positions as knowledge graph metadata for
|
||||
end-to-end traceability.
|
||||
|
||||
## Problem
|
||||
|
||||
TrustGraph currently has a PDF-specific decoder. Supporting additional
|
||||
formats (DOCX, XLSX, HTML, Markdown, plain text, PPTX, etc.) requires
|
||||
either writing a new decoder per format or adopting a universal extraction
|
||||
library. Each format has different structure — some are page-based, some
|
||||
aren't — and the provenance chain must record where in the original
|
||||
document each piece of extracted text originated.
|
||||
|
||||
## Approach
|
||||
|
||||
### Library: `unstructured`
|
||||
|
||||
Use `unstructured.partition.auto.partition()` which auto-detects format
|
||||
from mime type or file extension and extracts structured elements
|
||||
(Title, NarrativeText, Table, ListItem, etc.). Each element carries
|
||||
metadata including:
|
||||
|
||||
- `page_number` (for page-based formats like PDF, PPTX)
|
||||
- `element_id` (unique per element)
|
||||
- `coordinates` (bounding box for PDFs)
|
||||
- `text` (the extracted text content)
|
||||
- `category` (element type: Title, NarrativeText, Table, etc.)
|
||||
|
||||
### Element Types
|
||||
|
||||
`unstructured` extracts typed elements from documents. Each element has
|
||||
a category and associated metadata:
|
||||
|
||||
**Text elements:**
|
||||
- `Title` — section headings
|
||||
- `NarrativeText` — body paragraphs
|
||||
- `ListItem` — bullet/numbered list items
|
||||
- `Header`, `Footer` — page headers/footers
|
||||
- `FigureCaption` — captions for figures/images
|
||||
- `Formula` — mathematical expressions
|
||||
- `Address`, `EmailAddress` — contact information
|
||||
- `CodeSnippet` — code blocks (from markdown)
|
||||
|
||||
**Tables:**
|
||||
- `Table` — structured tabular data. `unstructured` provides both
|
||||
`element.text` (plain text) and `element.metadata.text_as_html`
|
||||
(full HTML `<table>` with rows, columns, and headers preserved).
|
||||
For formats with explicit table structure (DOCX, XLSX, HTML), the
|
||||
extraction is highly reliable. For PDFs, table detection depends on
|
||||
the `hi_res` strategy with layout analysis.
|
||||
|
||||
**Images:**
|
||||
- `Image` — embedded images detected via layout analysis (requires
|
||||
`hi_res` strategy). With `extract_image_block_to_payload=True`,
|
||||
returns the image data as base64 in `element.metadata.image_base64`.
|
||||
OCR text from the image is available in `element.text`.
|
||||
|
||||
### Table Handling
|
||||
|
||||
Tables are a first-class output. When the decoder encounters a `Table`
|
||||
element, it preserves the HTML structure rather than flattening to
|
||||
plain text. This gives the downstream LLM extractor much better input
|
||||
for pulling structured knowledge out of tabular data.
|
||||
|
||||
The page/section text is assembled as follows:
|
||||
- Text elements: plain text, joined with newlines
|
||||
- Table elements: HTML table markup from `text_as_html`, wrapped in a
|
||||
`<table>` marker so the LLM can distinguish tables from narrative
|
||||
|
||||
For example, a page with a heading, paragraph, and table produces:
|
||||
|
||||
```
|
||||
Financial Overview
|
||||
|
||||
Revenue grew 15% year-over-year driven by enterprise adoption.
|
||||
|
||||
<table>
|
||||
<tr><th>Quarter</th><th>Revenue</th><th>Growth</th></tr>
|
||||
<tr><td>Q1</td><td>$12M</td><td>12%</td></tr>
|
||||
<tr><td>Q2</td><td>$14M</td><td>17%</td></tr>
|
||||
</table>
|
||||
```
|
||||
|
||||
This preserves table structure through chunking and into the extraction
|
||||
pipeline, where the LLM can extract relationships directly from
|
||||
structured cells rather than guessing at column alignment from
|
||||
whitespace.
|
||||
|
||||
### Image Handling
|
||||
|
||||
Images are extracted and stored in the librarian as child documents
|
||||
with `document_type="image"` and a `urn:image:{uuid}` ID. They get
|
||||
provenance triples with type `tg:Image`, linked to their parent
|
||||
page/section via `prov:wasDerivedFrom`. Image metadata (coordinates,
|
||||
dimensions, element_id) is recorded in provenance.
|
||||
|
||||
**Crucially, images are NOT emitted as TextDocument outputs.** They are
|
||||
stored only — not sent downstream to the chunker or any text processing
|
||||
pipeline. This is intentional:
|
||||
|
||||
1. There is no image processing pipeline yet (vision model integration
|
||||
is future work)
|
||||
2. Feeding base64 image data or OCR fragments into the text extraction
|
||||
pipeline would produce garbage KG triples
|
||||
|
||||
Images are also excluded from the assembled page text — any `Image`
|
||||
elements are silently skipped when concatenating element texts for a
|
||||
page/section. The provenance chain records that images exist and where
|
||||
they appeared in the document, so they can be picked up by a future
|
||||
image processing pipeline without re-ingesting the document.
|
||||
|
||||
#### Future work
|
||||
|
||||
- Route `tg:Image` entities to a vision model for description,
|
||||
diagram interpretation, or chart data extraction
|
||||
- Store image descriptions as text child documents that feed into
|
||||
the standard chunking/extraction pipeline
|
||||
- Link extracted knowledge back to source images via provenance
|
||||
|
||||
### Section Strategies
|
||||
|
||||
For page-based formats (PDF, PPTX, XLSX), elements are always grouped
|
||||
by page/slide/sheet first. For non-page formats (DOCX, HTML, Markdown,
|
||||
etc.), the decoder needs a strategy for splitting the document into
|
||||
sections. This is runtime-configurable via `--section-strategy`.
|
||||
|
||||
Each strategy is a grouping function over the list of `unstructured`
|
||||
elements. The output is a list of element groups; the rest of the
|
||||
pipeline (text assembly, librarian storage, provenance, TextDocument
|
||||
emission) is identical regardless of strategy.
|
||||
|
||||
#### `whole-document` (default)
|
||||
|
||||
Emit the entire document as a single section. Let the downstream
|
||||
chunker handle all splitting.
|
||||
|
||||
- Simplest approach, good baseline
|
||||
- May produce very large TextDocument for big files, but the chunker
|
||||
handles that
|
||||
- Best when you want maximum context per section
|
||||
|
||||
#### `heading`
|
||||
|
||||
Split at heading elements (`Title`). Each section is a heading plus
|
||||
all content until the next heading of equal or higher level. Nested
|
||||
headings create nested sections.
|
||||
|
||||
- Produces topically coherent units
|
||||
- Works well for structured documents (reports, manuals, specs)
|
||||
- Gives the extraction LLM heading context alongside content
|
||||
- Falls back to `whole-document` if no headings are found
|
||||
|
||||
#### `element-type`
|
||||
|
||||
Split when the element type changes significantly — specifically,
|
||||
start a new section at transitions between narrative text and tables.
|
||||
Consecutive elements of the same broad category (text, text, text or
|
||||
table, table) stay grouped.
|
||||
|
||||
- Keeps tables as standalone sections
|
||||
- Good for documents with mixed content (reports with data tables)
|
||||
- Tables get dedicated extraction attention
|
||||
|
||||
#### `count`
|
||||
|
||||
Group a fixed number of elements per section. Configurable via
|
||||
`--section-element-count` (default: 20).
|
||||
|
||||
- Simple and predictable
|
||||
- Doesn't respect document structure
|
||||
- Useful as a fallback or for experimentation
|
||||
|
||||
#### `size`
|
||||
|
||||
Accumulate elements until a character limit is reached, then start a
|
||||
new section. Respects element boundaries — never splits mid-element.
|
||||
Configurable via `--section-max-size` (default: 4000 characters).
|
||||
|
||||
- Produces roughly uniform section sizes
|
||||
- Respects element boundaries (unlike the downstream chunker)
|
||||
- Good compromise between structure and size control
|
||||
- If a single element exceeds the limit, it becomes its own section
|
||||
|
||||
#### Page-based format interaction
|
||||
|
||||
For page-based formats, the page grouping always takes priority.
|
||||
Section strategies can optionally apply *within* a page if it's very
|
||||
large (e.g. a PDF page with an enormous table), controlled by
|
||||
`--section-within-pages` (default: false). When false, each page is
|
||||
always one section regardless of size.
|
||||
|
||||
### Format Detection
|
||||
|
||||
The decoder needs to know the document's mime type to pass to
|
||||
`unstructured`'s `partition()`. Two paths:
|
||||
|
||||
- **Librarian path** (`document_id` set): fetch document metadata
|
||||
from the librarian first — this gives us the `kind` (mime type)
|
||||
that was recorded at upload time. Then fetch document content.
|
||||
Two librarian calls, but the metadata fetch is lightweight.
|
||||
- **Inline path** (backward compat, `data` set): no metadata
|
||||
available on the message. Use `python-magic` to detect format
|
||||
from content bytes as a fallback.
|
||||
|
||||
No changes to the `Document` schema are needed — the librarian
|
||||
already stores the mime type.
|
||||
|
||||
### Architecture
|
||||
|
||||
A single `universal-decoder` service that:
|
||||
|
||||
1. Receives a `Document` message (inline or via librarian reference)
|
||||
2. If librarian path: fetch document metadata (get mime type), then
|
||||
fetch content. If inline path: detect format from content bytes.
|
||||
3. Calls `partition()` to extract elements
|
||||
4. Groups elements: by page for page-based formats, by configured
|
||||
section strategy for non-page formats
|
||||
5. For each page/section:
|
||||
- Generates a `urn:page:{uuid}` or `urn:section:{uuid}` ID
|
||||
- Assembles page text: narrative as plain text, tables as HTML,
|
||||
images skipped
|
||||
- Computes character offsets for each element within the page text
|
||||
- Saves to librarian as child document
|
||||
- Emits provenance triples with positional metadata
|
||||
- Sends `TextDocument` downstream for chunking
|
||||
6. For each image element:
|
||||
- Generates a `urn:image:{uuid}` ID
|
||||
- Saves image data to librarian as child document
|
||||
- Emits provenance triples (stored only, not sent downstream)
|
||||
|
||||
### Format Handling
|
||||
|
||||
| Format | Mime Type | Page-based | Notes |
|
||||
|----------|------------------------------------|------------|--------------------------------|
|
||||
| PDF | application/pdf | Yes | Per-page grouping |
|
||||
| DOCX | application/vnd.openxmlformats... | No | Uses section strategy |
|
||||
| PPTX | application/vnd.openxmlformats... | Yes | Per-slide grouping |
|
||||
| XLSX/XLS | application/vnd.openxmlformats... (XLSX), application/vnd.ms-excel (XLS) | Yes | Per-sheet grouping |
|
||||
| HTML | text/html | No | Uses section strategy |
|
||||
| Markdown | text/markdown | No | Uses section strategy |
|
||||
| Plain | text/plain | No | Uses section strategy |
|
||||
| CSV | text/csv | No | Uses section strategy |
|
||||
| RST | text/x-rst | No | Uses section strategy |
|
||||
| RTF | application/rtf | No | Uses section strategy |
|
||||
| ODT | application/vnd.oasis... | No | Uses section strategy |
|
||||
| TSV | text/tab-separated-values | No | Uses section strategy |
|
||||
|
||||
### Provenance Metadata
|
||||
|
||||
Each page/section entity records positional metadata as provenance
|
||||
triples in `GRAPH_SOURCE`, enabling full traceability from KG triples
|
||||
back to source document positions.
|
||||
|
||||
#### Existing fields (already in `derived_entity_triples`)
|
||||
|
||||
- `page_number` — page/sheet/slide number (1-indexed, page-based only)
|
||||
- `char_offset` — character offset of this page/section within the
|
||||
full document text
|
||||
- `char_length` — character length of this page/section's text
|
||||
|
||||
#### New fields (extend `derived_entity_triples`)
|
||||
|
||||
- `mime_type` — original document format (e.g. `application/pdf`)
|
||||
- `element_types` — comma-separated list of `unstructured` element
|
||||
categories found in this page/section (e.g. "Title,NarrativeText,Table")
|
||||
- `table_count` — number of tables in this page/section
|
||||
- `image_count` — number of images in this page/section
|
||||
|
||||
These require new TG namespace constants (entity types and predicates):
|
||||
|
||||
```
|
||||
TG_SECTION_TYPE = "https://trustgraph.ai/ns/Section"
|
||||
TG_IMAGE_TYPE = "https://trustgraph.ai/ns/Image"
|
||||
TG_ELEMENT_TYPES = "https://trustgraph.ai/ns/elementTypes"
|
||||
TG_TABLE_COUNT = "https://trustgraph.ai/ns/tableCount"
|
||||
TG_IMAGE_COUNT = "https://trustgraph.ai/ns/imageCount"
|
||||
```
|
||||
|
||||
Image URN scheme: `urn:image:{uuid}`
|
||||
|
||||
(`TG_MIME_TYPE` already exists.)
|
||||
|
||||
#### New entity type
|
||||
|
||||
For non-page formats (DOCX, HTML, Markdown, etc.) where the decoder
|
||||
emits sections (per the configured section strategy) rather than splitting by
|
||||
page, the entity gets a new type:
|
||||
|
||||
```
|
||||
TG_SECTION_TYPE = "https://trustgraph.ai/ns/Section"
|
||||
```
|
||||
|
||||
This distinguishes sections from pages when querying provenance:
|
||||
|
||||
| Entity | Type | When used |
|
||||
|----------|-----------------------------|----------------------------------------|
|
||||
| Document | `tg:Document` | Original uploaded file |
|
||||
| Page | `tg:Page` | Page-based formats (PDF, PPTX, XLSX) |
|
||||
| Section | `tg:Section` | Non-page formats (DOCX, HTML, MD, etc) |
|
||||
| Image | `tg:Image` | Embedded images (stored, not processed)|
|
||||
| Chunk | `tg:Chunk` | Output of chunker |
|
||||
| Subgraph | `tg:Subgraph` | KG extraction output |
|
||||
|
||||
The type is set by the decoder based on whether it's grouping by page
|
||||
or emitting a whole-document section. `derived_entity_triples` gains
|
||||
an optional `section` boolean parameter — when true, the entity is
|
||||
typed as `tg:Section` instead of `tg:Page`.
|
||||
|
||||
#### Full provenance chain
|
||||
|
||||
```
|
||||
KG triple
|
||||
→ subgraph (extraction provenance)
|
||||
→ chunk (char_offset, char_length within page/section)
|
||||
→ page/section (page_number, char_offset, char_length within doc, mime_type, element_types)
|
||||
→ document (original file in librarian)
|
||||
```
|
||||
|
||||
Every link is a set of triples in the `GRAPH_SOURCE` named graph.
|
||||
|
||||
### Service Configuration
|
||||
|
||||
Command-line arguments:
|
||||
|
||||
```
|
||||
--strategy Partitioning strategy: auto, hi_res, fast (default: auto)
|
||||
--languages Comma-separated OCR language codes (default: eng)
|
||||
--section-strategy Section grouping: whole-document, heading, element-type,
|
||||
count, size (default: whole-document)
|
||||
--section-element-count Elements per section for 'count' strategy (default: 20)
|
||||
--section-max-size Max chars per section for 'size' strategy (default: 4000)
|
||||
--section-within-pages Apply section strategy within pages too (default: false)
|
||||
```
|
||||
|
||||
Plus the standard `FlowProcessor` and librarian queue arguments.
|
||||
|
||||
### Flow Integration
|
||||
|
||||
The universal decoder occupies the same position in the processing flow
|
||||
as the current PDF decoder:
|
||||
|
||||
```
|
||||
Document → [universal-decoder] → TextDocument → [chunker] → Chunk → ...
|
||||
```
|
||||
|
||||
It registers:
|
||||
- `input` consumer (Document schema)
|
||||
- `output` producer (TextDocument schema)
|
||||
- `triples` producer (Triples schema)
|
||||
- Librarian request/response (for fetch and child document storage)
|
||||
|
||||
### Deployment
|
||||
|
||||
- New container: `trustgraph-flow-universal-decoder`
|
||||
- Dependency: `unstructured[all-docs]` (includes PDF, DOCX, PPTX, etc.)
|
||||
- Can run alongside or replace the existing PDF decoder depending on
|
||||
flow configuration
|
||||
- The existing PDF decoder remains available for environments where
|
||||
`unstructured` dependencies are too heavy
|
||||
|
||||
### What Changes
|
||||
|
||||
| Component | Change |
|
||||
|------------------------------|-------------------------------------------------|
|
||||
| `provenance/namespaces.py` | Add `TG_SECTION_TYPE`, `TG_IMAGE_TYPE`, `TG_ELEMENT_TYPES`, `TG_TABLE_COUNT`, `TG_IMAGE_COUNT` |
|
||||
| `provenance/triples.py` | Add `section`, `mime_type`, `element_types`, `table_count`, `image_count` kwargs |
|
||||
| `provenance/__init__.py` | Export new constants |
|
||||
| New: `decoding/universal/` | New decoder service module |
|
||||
| `setup.cfg` / `pyproject` | Add `unstructured[all-docs]` dependency |
|
||||
| Docker | New container image |
|
||||
| Flow definitions | Wire universal-decoder as document input |
|
||||
|
||||
### What Doesn't Change
|
||||
|
||||
- Chunker (receives TextDocument, works as before)
|
||||
- Downstream extractors (receive Chunk, unchanged)
|
||||
- Librarian (stores child documents, unchanged)
|
||||
- Schema (Document, TextDocument, Chunk unchanged)
|
||||
- Query-time provenance (unchanged)
|
||||
|
||||
## Risks
|
||||
|
||||
- `unstructured[all-docs]` has heavy dependencies (poppler, tesseract,
|
||||
libreoffice for some formats). Container image will be larger.
|
||||
Mitigation: offer a `[light]` variant without OCR/office deps.
|
||||
- Some formats may produce poor text extraction (scanned PDFs without
|
||||
OCR, complex XLSX layouts). Mitigation: configurable `strategy`
|
||||
parameter, and the existing Mistral OCR decoder remains available
|
||||
for high-quality PDF OCR.
|
||||
- `unstructured` version updates may change element metadata.
|
||||
Mitigation: pin version, test extraction quality per format.
|
||||
|
|
@ -10,20 +10,30 @@ from unittest import IsolatedAsyncioTestCase
|
|||
from io import BytesIO
|
||||
|
||||
from trustgraph.decoding.mistral_ocr.processor import Processor
|
||||
from trustgraph.schema import Document, TextDocument, Metadata
|
||||
from trustgraph.schema import Document, TextDocument, Metadata, Triples
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
self.pubsub = MagicMock()
|
||||
self.taskgroup = params.get('taskgroup', MagicMock())
|
||||
|
||||
|
||||
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral OCR processor functionality"""
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_processor_initialization_with_api_key(self, mock_flow_init, mock_mistral_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_with_api_key(
|
||||
self, mock_mistral_class, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test Mistral OCR processor initialization with API key"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
mock_mistral = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral
|
||||
mock_mistral_class.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'id': 'test-mistral-ocr',
|
||||
|
|
@ -31,56 +41,39 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(Processor, 'register_specification') as mock_register:
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
mock_flow_init.assert_called_once()
|
||||
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
# Verify register_specification was called twice (consumer and producer)
|
||||
assert mock_register.call_count == 2
|
||||
# Check specs registered: input consumer, output producer, triples producer
|
||||
consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')]
|
||||
assert len(consumer_specs) >= 1
|
||||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
# Check consumer spec
|
||||
consumer_call = mock_register.call_args_list[0]
|
||||
consumer_spec = consumer_call[0][0]
|
||||
assert consumer_spec.name == "input"
|
||||
assert consumer_spec.schema == Document
|
||||
assert consumer_spec.handler == processor.on_message
|
||||
|
||||
# Check producer spec
|
||||
producer_call = mock_register.call_args_list[1]
|
||||
producer_spec = producer_call[0][0]
|
||||
assert producer_spec.name == "output"
|
||||
assert producer_spec.schema == TextDocument
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_flow_init):
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_without_api_key(
|
||||
self, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test Mistral OCR processor initialization without API key raises error"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-mistral-ocr',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
|
||||
processor = Processor(**config)
|
||||
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
|
||||
Processor(**config)
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_ocr_single_chunk(self, mock_flow_init, mock_mistral_class, mock_uuid):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_ocr_single_chunk(
|
||||
self, mock_mistral_class, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test OCR processing with a single chunk (less than 5 pages)"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
mock_uuid.return_value = "test-uuid-1234"
|
||||
|
||||
# Mock Mistral client
|
||||
mock_mistral = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral
|
||||
|
||||
|
|
@ -92,17 +85,21 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
mock_signed_url = MagicMock(url="https://example.com/signed-url")
|
||||
mock_mistral.files.get_signed_url.return_value = mock_signed_url
|
||||
|
||||
# Mock OCR response
|
||||
mock_page = MagicMock(
|
||||
# Mock OCR response with 2 pages
|
||||
mock_page1 = MagicMock(
|
||||
markdown="# Page 1\nContent ",
|
||||
images=[MagicMock(id="img1", image_base64="data:image/png;base64,abc123")]
|
||||
)
|
||||
mock_ocr_response = MagicMock(pages=[mock_page])
|
||||
mock_page2 = MagicMock(
|
||||
markdown="# Page 2\nMore content",
|
||||
images=[]
|
||||
)
|
||||
mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2])
|
||||
mock_mistral.ocr.process.return_value = mock_ocr_response
|
||||
|
||||
# Mock PyPDF
|
||||
mock_pdf_reader = MagicMock()
|
||||
mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()] # 3 pages
|
||||
mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()]
|
||||
|
||||
config = {
|
||||
'id': 'test-mistral-ocr',
|
||||
|
|
@ -110,58 +107,39 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader):
|
||||
with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class:
|
||||
mock_pdf_writer = MagicMock()
|
||||
mock_pdf_writer_class.return_value = mock_pdf_writer
|
||||
with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader):
|
||||
with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class:
|
||||
mock_pdf_writer = MagicMock()
|
||||
mock_pdf_writer_class.return_value = mock_pdf_writer
|
||||
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
result = processor.ocr(b"fake pdf content")
|
||||
|
||||
# Act
|
||||
result = processor.ocr(b"fake pdf content")
|
||||
# Returns list of (markdown, page_num) tuples
|
||||
assert len(result) == 2
|
||||
assert result[0] == ("# Page 1\nContent ", 1)
|
||||
assert result[1] == ("# Page 2\nMore content", 2)
|
||||
|
||||
# Assert
|
||||
assert result == "# Page 1\nContent "
|
||||
|
||||
# Verify PDF writer was used to create chunk
|
||||
# Verify PDF writer was used
|
||||
assert mock_pdf_writer.add_page.call_count == 3
|
||||
mock_pdf_writer.write_stream.assert_called_once()
|
||||
|
||||
# Verify Mistral API calls
|
||||
mock_mistral.files.upload.assert_called_once()
|
||||
upload_call = mock_mistral.files.upload.call_args[1]
|
||||
assert upload_call['file']['file_name'] == "test-uuid-1234"
|
||||
assert upload_call['purpose'] == 'ocr'
|
||||
|
||||
mock_mistral.files.get_signed_url.assert_called_once_with(
|
||||
file_id="file-123", expiry=1
|
||||
)
|
||||
mock_mistral.ocr.process.assert_called_once()
|
||||
|
||||
mock_mistral.ocr.process.assert_called_once_with(
|
||||
model="mistral-ocr-latest",
|
||||
include_image_base64=True,
|
||||
document={
|
||||
"type": "document_url",
|
||||
"document_url": "https://example.com/signed-url",
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_on_message_success(self, mock_flow_init, mock_mistral_class, mock_uuid):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(
|
||||
self, mock_mistral_class, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test successful message processing"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
mock_uuid.return_value = "test-uuid-5678"
|
||||
|
||||
# Mock Mistral client with simple OCR response
|
||||
mock_mistral = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral
|
||||
|
||||
# Mock the ocr method to return simple markdown
|
||||
ocr_result = "# Document Title\nThis is the OCR content"
|
||||
mock_mistral_class.return_value = MagicMock()
|
||||
|
||||
# Mock message
|
||||
pdf_content = b"fake pdf content"
|
||||
|
|
@ -171,9 +149,13 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
# Mock flow - needs to be a callable that returns an object with send method
|
||||
# Mock flow
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_flow = MagicMock(return_value=mock_output_flow)
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
config = {
|
||||
'id': 'test-mistral-ocr',
|
||||
|
|
@ -181,47 +163,49 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock the ocr method
|
||||
with patch.object(processor, 'ocr', return_value=ocr_result):
|
||||
# Act
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
# Mock ocr to return per-page results
|
||||
ocr_result = [
|
||||
("# Page 1\nContent", 1),
|
||||
("# Page 2\nMore content", 2),
|
||||
]
|
||||
|
||||
# Assert
|
||||
# Verify output was sent
|
||||
mock_output_flow.send.assert_called_once()
|
||||
# Mock save_child_document
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Check output
|
||||
call_args = mock_output_flow.send.call_args[0][0]
|
||||
with patch.object(processor, 'ocr', return_value=ocr_result):
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Verify output was sent for each page
|
||||
assert mock_output_flow.send.call_count == 2
|
||||
# Verify triples were sent for each page
|
||||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
# Check output uses UUID-based page URNs
|
||||
call_args = mock_output_flow.send.call_args_list[0][0][0]
|
||||
assert isinstance(call_args, TextDocument)
|
||||
assert call_args.metadata == mock_metadata
|
||||
assert call_args.text == ocr_result.encode('utf-8')
|
||||
assert call_args.document_id.startswith("urn:page:")
|
||||
assert call_args.text == b"" # Content stored in librarian
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_chunks_function(self, mock_flow_init, mock_mistral_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunks_function(self, mock_mistral_class):
|
||||
"""Test the chunks utility function"""
|
||||
# Arrange
|
||||
from trustgraph.decoding.mistral_ocr.processor import chunks
|
||||
|
||||
test_list = list(range(12))
|
||||
|
||||
# Act
|
||||
result = list(chunks(test_list, 5))
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result[0] == [0, 1, 2, 3, 4]
|
||||
assert result[1] == [5, 6, 7, 8, 9]
|
||||
assert result[2] == [10, 11]
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_replace_images_in_markdown(self, mock_flow_init, mock_mistral_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_replace_images_in_markdown(self, mock_mistral_class):
|
||||
"""Test the replace_images_in_markdown function"""
|
||||
# Arrange
|
||||
from trustgraph.decoding.mistral_ocr.processor import replace_images_in_markdown
|
||||
|
||||
markdown = "# Title\n\nSome text\n"
|
||||
|
|
@ -230,66 +214,34 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
"image2": "data:image/png;base64,def456"
|
||||
}
|
||||
|
||||
# Act
|
||||
result = replace_images_in_markdown(markdown, images_dict)
|
||||
|
||||
# Assert
|
||||
expected = "# Title\n\nSome text\n"
|
||||
assert result == expected
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_get_combined_markdown(self, mock_flow_init, mock_mistral_class):
|
||||
"""Test the get_combined_markdown function"""
|
||||
# Arrange
|
||||
from trustgraph.decoding.mistral_ocr.processor import get_combined_markdown
|
||||
from mistralai.models import OCRResponse
|
||||
|
||||
# Mock OCR response with multiple pages
|
||||
mock_page1 = MagicMock(
|
||||
markdown="# Page 1\n",
|
||||
images=[MagicMock(id="img1", image_base64="base64_img1")]
|
||||
)
|
||||
mock_page2 = MagicMock(
|
||||
markdown="# Page 2\n",
|
||||
images=[MagicMock(id="img2", image_base64="base64_img2")]
|
||||
)
|
||||
mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2])
|
||||
|
||||
# Act
|
||||
result = get_combined_markdown(mock_ocr_response)
|
||||
|
||||
# Assert
|
||||
expected = "# Page 1\n\n\n# Page 2\n"
|
||||
assert result == expected
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
def test_add_args(self, mock_parent_add_args):
|
||||
"""Test add_args adds API key argument"""
|
||||
# Arrange
|
||||
"""Test add_args adds expected arguments"""
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
mock_parser.add_argument.assert_called_once_with(
|
||||
'-k', '--api-key',
|
||||
default=None, # default_api_key is None in test environment
|
||||
help='Mistral API Key'
|
||||
)
|
||||
assert mock_parser.add_argument.call_count == 3
|
||||
# Check the API key arg is among them
|
||||
call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list]
|
||||
assert ('-k', '--api-key') in call_args_list
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Processor.launch')
|
||||
def test_run(self, mock_launch):
|
||||
"""Test run function"""
|
||||
# Act
|
||||
from trustgraph.decoding.mistral_ocr.processor import run
|
||||
run()
|
||||
|
||||
# Assert
|
||||
mock_launch.assert_called_once_with("pdf-decoder",
|
||||
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n")
|
||||
mock_launch.assert_called_once()
|
||||
args = mock_launch.call_args[0]
|
||||
assert args[0] == "document-decoder"
|
||||
assert "Mistral OCR decoder" in args[1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -171,8 +171,8 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_output_flow.send.assert_called_once()
|
||||
call_args = mock_output_flow.send.call_args[0][0]
|
||||
# PDF decoder now forwards document_id, chunker fetches content from librarian
|
||||
assert call_args.document_id == "test-doc/p1"
|
||||
# PDF decoder now forwards document_id with UUID-based URN
|
||||
assert call_args.document_id.startswith("urn:page:")
|
||||
assert call_args.text == b"" # Content stored in librarian, not inline
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
|
|
@ -187,7 +187,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
"""Test run function"""
|
||||
from trustgraph.decoding.pdf.pdf_decoder import run
|
||||
run()
|
||||
mock_launch.assert_called_once_with("pdf-decoder",
|
||||
mock_launch.assert_called_once_with("document-decoder",
|
||||
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n\nSupports both inline document data and fetching from librarian via Pulsar\nfor large documents.\n")
|
||||
|
||||
|
||||
|
|
|
|||
412
tests/unit/test_decoding/test_universal_processor.py
Normal file
412
tests/unit/test_decoding/test_universal_processor.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
"""
|
||||
Unit tests for trustgraph.decoding.universal.processor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.decoding.universal.processor import (
|
||||
Processor, assemble_section_text, MIME_EXTENSIONS, PAGE_BASED_FORMATS,
|
||||
)
|
||||
from trustgraph.schema import Document, TextDocument, Metadata, Triples
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
self.pubsub = MagicMock()
|
||||
self.taskgroup = params.get('taskgroup', MagicMock())
|
||||
|
||||
|
||||
def make_element(category="NarrativeText", text="Some text",
|
||||
page_number=None, text_as_html=None, image_base64=None):
|
||||
"""Create a mock unstructured element."""
|
||||
el = MagicMock()
|
||||
el.category = category
|
||||
el.text = text
|
||||
el.metadata = MagicMock()
|
||||
el.metadata.page_number = page_number
|
||||
el.metadata.text_as_html = text_as_html
|
||||
el.metadata.image_base64 = image_base64
|
||||
return el
|
||||
|
||||
|
||||
class TestAssembleSectionText:
|
||||
"""Test the text assembly function."""
|
||||
|
||||
def test_narrative_text(self):
|
||||
elements = [
|
||||
make_element("NarrativeText", "Paragraph one."),
|
||||
make_element("NarrativeText", "Paragraph two."),
|
||||
]
|
||||
text, types, tables, images = assemble_section_text(elements)
|
||||
assert text == "Paragraph one.\n\nParagraph two."
|
||||
assert "NarrativeText" in types
|
||||
assert tables == 0
|
||||
assert images == 0
|
||||
|
||||
def test_table_with_html(self):
|
||||
elements = [
|
||||
make_element("NarrativeText", "Before table."),
|
||||
make_element(
|
||||
"Table", "Col1 Col2",
|
||||
text_as_html="<table><tr><td>Col1</td><td>Col2</td></tr></table>"
|
||||
),
|
||||
]
|
||||
text, types, tables, images = assemble_section_text(elements)
|
||||
assert "<table>" in text
|
||||
assert "Before table." in text
|
||||
assert tables == 1
|
||||
assert "Table" in types
|
||||
|
||||
def test_table_without_html_fallback(self):
|
||||
el = make_element("Table", "plain table text")
|
||||
el.metadata.text_as_html = None
|
||||
elements = [el]
|
||||
text, types, tables, images = assemble_section_text(elements)
|
||||
assert text == "plain table text"
|
||||
assert tables == 1
|
||||
|
||||
def test_images_skipped(self):
|
||||
elements = [
|
||||
make_element("NarrativeText", "Text content"),
|
||||
make_element("Image", "OCR text from image"),
|
||||
]
|
||||
text, types, tables, images = assemble_section_text(elements)
|
||||
assert "OCR text" not in text
|
||||
assert "Text content" in text
|
||||
assert images == 1
|
||||
assert "Image" in types
|
||||
|
||||
def test_empty_elements(self):
|
||||
text, types, tables, images = assemble_section_text([])
|
||||
assert text == ""
|
||||
assert len(types) == 0
|
||||
assert tables == 0
|
||||
assert images == 0
|
||||
|
||||
def test_mixed_elements(self):
|
||||
elements = [
|
||||
make_element("Title", "Section Heading"),
|
||||
make_element("NarrativeText", "Body text."),
|
||||
make_element(
|
||||
"Table", "data",
|
||||
text_as_html="<table><tr><td>data</td></tr></table>"
|
||||
),
|
||||
make_element("Image", "img text"),
|
||||
make_element("ListItem", "- item one"),
|
||||
]
|
||||
text, types, tables, images = assemble_section_text(elements)
|
||||
assert "Section Heading" in text
|
||||
assert "Body text." in text
|
||||
assert "<table>" in text
|
||||
assert "img text" not in text
|
||||
assert "- item one" in text
|
||||
assert tables == 1
|
||||
assert images == 1
|
||||
assert {"Title", "NarrativeText", "Table", "Image", "ListItem"} == types
|
||||
|
||||
|
||||
class TestMimeExtensions:
|
||||
"""Test the mime type to extension mapping."""
|
||||
|
||||
def test_pdf_extension(self):
|
||||
assert MIME_EXTENSIONS["application/pdf"] == ".pdf"
|
||||
|
||||
def test_docx_extension(self):
|
||||
key = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
assert MIME_EXTENSIONS[key] == ".docx"
|
||||
|
||||
def test_html_extension(self):
|
||||
assert MIME_EXTENSIONS["text/html"] == ".html"
|
||||
|
||||
|
||||
class TestPageBasedFormats:
|
||||
"""Test page-based format detection."""
|
||||
|
||||
def test_pdf_is_page_based(self):
|
||||
assert "application/pdf" in PAGE_BASED_FORMATS
|
||||
|
||||
def test_html_is_not_page_based(self):
|
||||
assert "text/html" not in PAGE_BASED_FORMATS
|
||||
|
||||
def test_pptx_is_page_based(self):
|
||||
pptx = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
assert pptx in PAGE_BASED_FORMATS
|
||||
|
||||
|
||||
class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test universal decoder processor."""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(
|
||||
self, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test processor initialization with defaults."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert processor.partition_strategy == "auto"
|
||||
assert processor.section_strategy_name == "whole-document"
|
||||
assert processor.section_element_count == 20
|
||||
assert processor.section_max_size == 4000
|
||||
|
||||
# Check specs: input consumer, output producer, triples producer
|
||||
consumer_specs = [
|
||||
s for s in processor.specifications if hasattr(s, 'handler')
|
||||
]
|
||||
assert len(consumer_specs) >= 1
|
||||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_custom_strategy(
|
||||
self, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test processor initialization with custom section strategy."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
'section_strategy': 'heading',
|
||||
'strategy': 'hi_res',
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert processor.partition_strategy == "hi_res"
|
||||
assert processor.section_strategy_name == "heading"
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_group_by_page(self, mock_producer, mock_consumer):
|
||||
"""Test page grouping of elements."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
elements = [
|
||||
make_element("NarrativeText", "Page 1 text", page_number=1),
|
||||
make_element("NarrativeText", "More page 1", page_number=1),
|
||||
make_element("NarrativeText", "Page 2 text", page_number=2),
|
||||
]
|
||||
|
||||
result = processor.group_by_page(elements)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0][0] == 1 # page number
|
||||
assert len(result[0][1]) == 2 # 2 elements on page 1
|
||||
assert result[1][0] == 2
|
||||
assert len(result[1][1]) == 1
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_inline_non_page(
|
||||
self, mock_partition, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test processing an inline non-page document."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock partition to return elements without page numbers
|
||||
mock_partition.return_value = [
|
||||
make_element("Title", "Document Title"),
|
||||
make_element("NarrativeText", "Body text content."),
|
||||
]
|
||||
|
||||
# Mock message with inline data
|
||||
content = b"# Document Title\nBody text content."
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
data=base64.b64encode(content).decode('utf-8'),
|
||||
)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
# Mock flow
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
# Mock save_child_document and magic
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "text/markdown"
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Should emit one section (whole-document strategy)
|
||||
assert mock_output_flow.send.call_count == 1
|
||||
assert mock_triples_flow.send.call_count == 1
|
||||
|
||||
# Check output
|
||||
call_args = mock_output_flow.send.call_args[0][0]
|
||||
assert isinstance(call_args, TextDocument)
|
||||
assert call_args.document_id.startswith("urn:section:")
|
||||
assert call_args.text == b""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_page_based(
|
||||
self, mock_partition, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test processing a page-based document."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock partition to return elements with page numbers
|
||||
mock_partition.return_value = [
|
||||
make_element("NarrativeText", "Page 1 content", page_number=1),
|
||||
make_element("NarrativeText", "Page 2 content", page_number=2),
|
||||
]
|
||||
|
||||
# Mock message
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
data=base64.b64encode(content).decode('utf-8'),
|
||||
)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Should emit two pages
|
||||
assert mock_output_flow.send.call_count == 2
|
||||
|
||||
# Check first output uses page URI
|
||||
call_args = mock_output_flow.send.call_args_list[0][0][0]
|
||||
assert call_args.document_id.startswith("urn:page:")
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_images_stored_not_emitted(
|
||||
self, mock_partition, mock_producer, mock_consumer
|
||||
):
|
||||
"""Test that images are stored but not sent to text pipeline."""
|
||||
config = {
|
||||
'id': 'test-universal',
|
||||
'taskgroup': AsyncMock(),
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
mock_partition.return_value = [
|
||||
make_element("NarrativeText", "Some text", page_number=1),
|
||||
make_element("Image", "img ocr", page_number=1,
|
||||
image_base64="aW1hZ2VkYXRh"),
|
||||
]
|
||||
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
data=base64.b64encode(content).decode('utf-8'),
|
||||
)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Only 1 TextDocument output (the page text, not the image)
|
||||
assert mock_output_flow.send.call_count == 1
|
||||
|
||||
# But 2 triples outputs (page provenance + image provenance)
|
||||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
# save_child_document called twice (page + image)
|
||||
assert processor.save_child_document.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
def test_add_args(self, mock_parent_add_args):
|
||||
"""Test add_args registers all expected arguments."""
|
||||
mock_parser = MagicMock()
|
||||
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Check key arguments are registered
|
||||
arg_names = [
|
||||
c[0] for c in mock_parser.add_argument.call_args_list
|
||||
]
|
||||
assert ('--strategy',) in arg_names
|
||||
assert ('--languages',) in arg_names
|
||||
assert ('--section-strategy',) in arg_names
|
||||
assert ('--section-element-count',) in arg_names
|
||||
assert ('--section-max-size',) in arg_names
|
||||
assert ('--section-within-pages',) in arg_names
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Processor.launch')
|
||||
def test_run(self, mock_launch):
|
||||
"""Test run function."""
|
||||
from trustgraph.decoding.universal.processor import run
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once()
|
||||
args = mock_launch.call_args[0]
|
||||
assert args[0] == "document-decoder"
|
||||
assert "Universal document decoder" in args[1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
204
tests/unit/test_decoding/test_universal_strategies.py
Normal file
204
tests/unit/test_decoding/test_universal_strategies.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""
|
||||
Unit tests for universal decoder section grouping strategies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
from trustgraph.decoding.universal.strategies import (
|
||||
group_whole_document,
|
||||
group_by_heading,
|
||||
group_by_element_type,
|
||||
group_by_count,
|
||||
group_by_size,
|
||||
get_strategy,
|
||||
STRATEGIES,
|
||||
)
|
||||
|
||||
|
||||
def make_element(category="NarrativeText", text="Some text"):
|
||||
"""Create a mock unstructured element."""
|
||||
el = MagicMock()
|
||||
el.category = category
|
||||
el.text = text
|
||||
return el
|
||||
|
||||
|
||||
class TestGroupWholeDocument:
|
||||
|
||||
def test_empty_input(self):
|
||||
assert group_whole_document([]) == []
|
||||
|
||||
def test_returns_single_group(self):
|
||||
elements = [make_element() for _ in range(5)]
|
||||
result = group_whole_document(elements)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 5
|
||||
|
||||
def test_preserves_all_elements(self):
|
||||
elements = [make_element(text=f"text-{i}") for i in range(3)]
|
||||
result = group_whole_document(elements)
|
||||
assert result[0] == elements
|
||||
|
||||
|
||||
class TestGroupByHeading:
|
||||
|
||||
def test_empty_input(self):
|
||||
assert group_by_heading([]) == []
|
||||
|
||||
def test_no_headings_falls_back(self):
|
||||
elements = [make_element("NarrativeText") for _ in range(3)]
|
||||
result = group_by_heading(elements)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 3
|
||||
|
||||
def test_splits_at_headings(self):
|
||||
elements = [
|
||||
make_element("Title", "Heading 1"),
|
||||
make_element("NarrativeText", "Paragraph 1"),
|
||||
make_element("NarrativeText", "Paragraph 2"),
|
||||
make_element("Title", "Heading 2"),
|
||||
make_element("NarrativeText", "Paragraph 3"),
|
||||
]
|
||||
result = group_by_heading(elements)
|
||||
assert len(result) == 2
|
||||
assert len(result[0]) == 3 # Heading 1 + 2 paragraphs
|
||||
assert len(result[1]) == 2 # Heading 2 + 1 paragraph
|
||||
|
||||
def test_leading_content_before_first_heading(self):
|
||||
elements = [
|
||||
make_element("NarrativeText", "Preamble"),
|
||||
make_element("Title", "Heading 1"),
|
||||
make_element("NarrativeText", "Content"),
|
||||
]
|
||||
result = group_by_heading(elements)
|
||||
assert len(result) == 2
|
||||
assert len(result[0]) == 1 # Preamble
|
||||
assert len(result[1]) == 2 # Heading + content
|
||||
|
||||
def test_consecutive_headings(self):
|
||||
elements = [
|
||||
make_element("Title", "H1"),
|
||||
make_element("Title", "H2"),
|
||||
make_element("NarrativeText", "Content"),
|
||||
]
|
||||
result = group_by_heading(elements)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestGroupByElementType:
|
||||
|
||||
def test_empty_input(self):
|
||||
assert group_by_element_type([]) == []
|
||||
|
||||
def test_all_same_type(self):
|
||||
elements = [make_element("NarrativeText") for _ in range(3)]
|
||||
result = group_by_element_type(elements)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_splits_at_table_boundary(self):
|
||||
elements = [
|
||||
make_element("NarrativeText", "Intro"),
|
||||
make_element("NarrativeText", "More text"),
|
||||
make_element("Table", "Table data"),
|
||||
make_element("NarrativeText", "After table"),
|
||||
]
|
||||
result = group_by_element_type(elements)
|
||||
assert len(result) == 3
|
||||
assert len(result[0]) == 2 # Two narrative elements
|
||||
assert len(result[1]) == 1 # One table
|
||||
assert len(result[2]) == 1 # One narrative
|
||||
|
||||
def test_consecutive_tables_stay_grouped(self):
|
||||
elements = [
|
||||
make_element("Table", "Table 1"),
|
||||
make_element("Table", "Table 2"),
|
||||
]
|
||||
result = group_by_element_type(elements)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 2
|
||||
|
||||
|
||||
class TestGroupByCount:
|
||||
|
||||
def test_empty_input(self):
|
||||
assert group_by_count([]) == []
|
||||
|
||||
def test_exact_multiple(self):
|
||||
elements = [make_element() for _ in range(6)]
|
||||
result = group_by_count(elements, element_count=3)
|
||||
assert len(result) == 2
|
||||
assert all(len(g) == 3 for g in result)
|
||||
|
||||
def test_remainder_group(self):
|
||||
elements = [make_element() for _ in range(7)]
|
||||
result = group_by_count(elements, element_count=3)
|
||||
assert len(result) == 3
|
||||
assert len(result[0]) == 3
|
||||
assert len(result[1]) == 3
|
||||
assert len(result[2]) == 1
|
||||
|
||||
def test_fewer_than_count(self):
|
||||
elements = [make_element() for _ in range(2)]
|
||||
result = group_by_count(elements, element_count=10)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 2
|
||||
|
||||
|
||||
class TestGroupBySize:
|
||||
|
||||
def test_empty_input(self):
|
||||
assert group_by_size([]) == []
|
||||
|
||||
def test_small_elements_grouped(self):
|
||||
elements = [make_element(text="Hi") for _ in range(5)]
|
||||
result = group_by_size(elements, max_size=100)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_splits_at_size_limit(self):
|
||||
elements = [make_element(text="x" * 100) for _ in range(5)]
|
||||
result = group_by_size(elements, max_size=250)
|
||||
# 2 elements per group (200 chars), then split
|
||||
assert len(result) == 3
|
||||
assert len(result[0]) == 2
|
||||
assert len(result[1]) == 2
|
||||
assert len(result[2]) == 1
|
||||
|
||||
def test_large_element_own_group(self):
|
||||
elements = [
|
||||
make_element(text="small"),
|
||||
make_element(text="x" * 5000), # Exceeds max
|
||||
make_element(text="small"),
|
||||
]
|
||||
result = group_by_size(elements, max_size=100)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_respects_element_boundaries(self):
|
||||
# Each element is 50 chars, max is 120
|
||||
# Should get 2 per group, not split mid-element
|
||||
elements = [make_element(text="x" * 50) for _ in range(5)]
|
||||
result = group_by_size(elements, max_size=120)
|
||||
assert len(result) == 3
|
||||
assert len(result[0]) == 2
|
||||
assert len(result[1]) == 2
|
||||
assert len(result[2]) == 1
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
|
||||
def test_all_strategies_accessible(self):
|
||||
for name in STRATEGIES:
|
||||
fn = get_strategy(name)
|
||||
assert callable(fn)
|
||||
|
||||
def test_unknown_strategy_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown section strategy"):
|
||||
get_strategy("nonexistent")
|
||||
|
||||
def test_returns_correct_function(self):
|
||||
assert get_strategy("whole-document") is group_whole_document
|
||||
assert get_strategy("heading") is group_by_heading
|
||||
assert get_strategy("element-type") is group_by_element_type
|
||||
assert get_strategy("count") is group_by_count
|
||||
assert get_strategy("size") is group_by_size
|
||||
|
|
@ -121,11 +121,11 @@ class TestBeginUpload:
|
|||
assert resp.total_chunks == math.ceil(10_000 / 3000)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_invalid_kind(self):
|
||||
async def test_rejects_empty_kind(self):
|
||||
lib = _make_librarian()
|
||||
req = _make_begin_request(kind="image/png")
|
||||
req = _make_begin_request(kind="")
|
||||
|
||||
with pytest.raises(RequestError, match="Invalid document kind"):
|
||||
with pytest.raises(RequestError, match="MIME type.*required"):
|
||||
await lib.begin_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ from trustgraph.provenance.uris import (
|
|||
_encode_id,
|
||||
document_uri,
|
||||
page_uri,
|
||||
chunk_uri_from_page,
|
||||
chunk_uri_from_doc,
|
||||
chunk_uri,
|
||||
activity_uri,
|
||||
subgraph_uri,
|
||||
agent_uri,
|
||||
|
|
@ -60,31 +59,22 @@ class TestDocumentUris:
|
|||
assert document_uri(iri) == iri
|
||||
|
||||
def test_page_uri_format(self):
|
||||
result = page_uri("https://example.com/doc/123", 5)
|
||||
assert result == "https://example.com/doc/123/p5"
|
||||
result = page_uri()
|
||||
assert result.startswith("urn:page:")
|
||||
|
||||
def test_page_uri_page_zero(self):
|
||||
result = page_uri("https://example.com/doc/123", 0)
|
||||
assert result == "https://example.com/doc/123/p0"
|
||||
def test_page_uri_unique(self):
|
||||
r1 = page_uri()
|
||||
r2 = page_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_chunk_uri_from_page_format(self):
|
||||
result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
|
||||
assert result == "https://example.com/doc/123/p2/c3"
|
||||
def test_chunk_uri_format(self):
|
||||
result = chunk_uri()
|
||||
assert result.startswith("urn:chunk:")
|
||||
|
||||
def test_chunk_uri_from_doc_format(self):
|
||||
result = chunk_uri_from_doc("https://example.com/doc/123", 7)
|
||||
assert result == "https://example.com/doc/123/c7"
|
||||
|
||||
def test_page_uri_preserves_doc_iri(self):
|
||||
doc = "urn:isbn:978-3-16-148410-0"
|
||||
result = page_uri(doc, 1)
|
||||
assert result.startswith(doc)
|
||||
|
||||
def test_chunk_from_page_hierarchy(self):
|
||||
"""Chunk URI should contain both page and chunk identifiers."""
|
||||
result = chunk_uri_from_page("https://example.com/doc", 3, 5)
|
||||
assert "/p3/" in result
|
||||
assert result.endswith("/c5")
|
||||
def test_chunk_uri_unique(self):
|
||||
r1 = chunk_uri()
|
||||
r2 = chunk_uri()
|
||||
assert r1 != r2
|
||||
|
||||
|
||||
class TestActivityAndSubgraphUris:
|
||||
|
|
|
|||
|
|
@ -449,7 +449,7 @@ class FlowInstance:
|
|||
def graph_rag(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
max_path_length=2, edge_score_limit=30, edge_limit=25,
|
||||
):
|
||||
"""
|
||||
Execute graph-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -465,6 +465,8 @@ class FlowInstance:
|
|||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 150)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating graph context
|
||||
|
|
@ -492,6 +494,8 @@ class FlowInstance:
|
|||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -699,9 +699,12 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
entity_limit: int = 50,
|
||||
triple_limit: int = 30,
|
||||
max_subgraph_size: int = 1000,
|
||||
max_subgraph_count: int = 5,
|
||||
max_entity_distance: int = 3,
|
||||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
|
|
@ -715,9 +718,12 @@ class SocketFlowInstance:
|
|||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Maximum entities to retrieve (default: 50)
|
||||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
|
||||
max_subgraph_count: Maximum number of subgraphs (default: 5)
|
||||
max_entity_distance: Maximum traversal depth (default: 3)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
streaming: Enable streaming mode (default: False)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
|
|
@ -743,9 +749,12 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"entity-limit": entity_limit,
|
||||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-subgraph-count": max_subgraph_count,
|
||||
"max-entity-distance": max_entity_distance,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -762,9 +771,12 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
entity_limit: int = 50,
|
||||
triple_limit: int = 30,
|
||||
max_subgraph_size: int = 1000,
|
||||
max_subgraph_count: int = 5,
|
||||
max_entity_distance: int = 3,
|
||||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""
|
||||
|
|
@ -778,9 +790,12 @@ class SocketFlowInstance:
|
|||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Maximum entities to retrieve (default: 50)
|
||||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
|
||||
max_subgraph_count: Maximum number of subgraphs (default: 5)
|
||||
max_entity_distance: Maximum traversal depth (default: 3)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Yields:
|
||||
|
|
@ -823,11 +838,14 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"entity-limit": entity_limit,
|
||||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-subgraph-count": max_subgraph_count,
|
||||
"max-entity-distance": max_entity_distance,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"streaming": True,
|
||||
"explainable": True, # Enable explainability mode
|
||||
"explainable": True,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -96,8 +96,6 @@ class ChunkingService(FlowProcessor):
|
|||
if request_id and request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
triple_limit=int(data.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(data.get("max-path-length", 2)),
|
||||
edge_score_limit=int(data.get("edge-score-limit", 30)),
|
||||
edge_limit=int(data.get("edge-limit", 25)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
|
@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
"triple-limit": obj.triple_limit,
|
||||
"max-subgraph-size": obj.max_subgraph_size,
|
||||
"max-path-length": obj.max_path_length,
|
||||
"edge-score-limit": obj.edge_score_limit,
|
||||
"edge-limit": obj.edge_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ Provides helpers for:
|
|||
Usage example:
|
||||
|
||||
from trustgraph.provenance import (
|
||||
document_uri, page_uri, chunk_uri_from_page,
|
||||
document_uri, page_uri, chunk_uri,
|
||||
document_triples, derived_entity_triples,
|
||||
get_vocabulary_triples,
|
||||
)
|
||||
|
||||
# Generate URIs
|
||||
doc_uri = document_uri("my-doc-123")
|
||||
page_uri = page_uri("my-doc-123", page_number=1)
|
||||
pg_uri = page_uri()
|
||||
|
||||
# Build provenance triples
|
||||
triples = document_triples(
|
||||
|
|
@ -35,8 +35,9 @@ from . uris import (
|
|||
TRUSTGRAPH_BASE,
|
||||
document_uri,
|
||||
page_uri,
|
||||
chunk_uri_from_page,
|
||||
chunk_uri_from_doc,
|
||||
section_uri,
|
||||
chunk_uri,
|
||||
image_uri,
|
||||
activity_uri,
|
||||
subgraph_uri,
|
||||
agent_uri,
|
||||
|
|
@ -75,8 +76,10 @@ from . namespaces import (
|
|||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
|
||||
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
|
||||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_SECTION_TYPE, TG_CHUNK_TYPE,
|
||||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
|
|
@ -138,8 +141,9 @@ __all__ = [
|
|||
"TRUSTGRAPH_BASE",
|
||||
"document_uri",
|
||||
"page_uri",
|
||||
"chunk_uri_from_page",
|
||||
"chunk_uri_from_doc",
|
||||
"section_uri",
|
||||
"chunk_uri",
|
||||
"image_uri",
|
||||
"activity_uri",
|
||||
"subgraph_uri",
|
||||
"agent_uri",
|
||||
|
|
@ -171,8 +175,10 @@ __all__ = [
|
|||
"TG_CHUNK_SIZE", "TG_CHUNK_OVERLAP", "TG_COMPONENT_VERSION",
|
||||
"TG_LLM_MODEL", "TG_ONTOLOGY", "TG_EMBEDDING_MODEL",
|
||||
"TG_SOURCE_TEXT", "TG_SOURCE_CHAR_OFFSET", "TG_SOURCE_CHAR_LENGTH",
|
||||
"TG_ELEMENT_TYPES", "TG_TABLE_COUNT", "TG_IMAGE_COUNT",
|
||||
# Extraction provenance entity types
|
||||
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_SECTION_TYPE",
|
||||
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
|
|
|
|||
|
|
@ -75,9 +75,16 @@ TG_SELECTED_CHUNK = TG + "selectedChunk"
|
|||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE = TG + "Document"
|
||||
TG_PAGE_TYPE = TG + "Page"
|
||||
TG_SECTION_TYPE = TG + "Section"
|
||||
TG_CHUNK_TYPE = TG + "Chunk"
|
||||
TG_IMAGE_TYPE = TG + "Image"
|
||||
TG_SUBGRAPH_TYPE = TG + "Subgraph"
|
||||
|
||||
# Universal decoder metadata predicates
|
||||
TG_ELEMENT_TYPES = TG + "elementTypes"
|
||||
TG_TABLE_COUNT = TG + "tableCount"
|
||||
TG_IMAGE_COUNT = TG + "imageCount"
|
||||
|
||||
# Explainability entity types (shared)
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_GROUNDING = TG + "Grounding"
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ from . namespaces import (
|
|||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_SECTION_TYPE, TG_CHUNK_TYPE,
|
||||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Universal decoder metadata predicates
|
||||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
|
|
@ -129,15 +132,22 @@ def derived_entity_triples(
|
|||
component_version: str,
|
||||
label: Optional[str] = None,
|
||||
page_number: Optional[int] = None,
|
||||
section: bool = False,
|
||||
image: bool = False,
|
||||
chunk_index: Optional[int] = None,
|
||||
char_offset: Optional[int] = None,
|
||||
char_length: Optional[int] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
element_types: Optional[str] = None,
|
||||
table_count: Optional[int] = None,
|
||||
image_count: Optional[int] = None,
|
||||
timestamp: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a derived entity (page or chunk) with full PROV-O provenance.
|
||||
Build triples for a derived entity (page, section, chunk, or image)
|
||||
with full PROV-O provenance.
|
||||
|
||||
Creates:
|
||||
- Entity declaration
|
||||
|
|
@ -146,17 +156,23 @@ def derived_entity_triples(
|
|||
- Agent for the component
|
||||
|
||||
Args:
|
||||
entity_uri: URI of the derived entity (page or chunk)
|
||||
entity_uri: URI of the derived entity
|
||||
parent_uri: URI of the parent entity
|
||||
component_name: Name of TG component (e.g., "pdf-extractor", "chunker")
|
||||
component_version: Version of the component
|
||||
label: Human-readable label
|
||||
page_number: Page number (for pages)
|
||||
section: True if this is a document section (non-page format)
|
||||
image: True if this is an image entity
|
||||
chunk_index: Chunk index (for chunks)
|
||||
char_offset: Character offset in parent (for chunks)
|
||||
char_length: Character length (for chunks)
|
||||
char_offset: Character offset in parent
|
||||
char_length: Character length
|
||||
chunk_size: Configured chunk size (for chunking activity)
|
||||
chunk_overlap: Configured chunk overlap (for chunking activity)
|
||||
mime_type: Source document MIME type
|
||||
element_types: Comma-separated unstructured element categories
|
||||
table_count: Number of tables in this page/section
|
||||
image_count: Number of images in this page/section
|
||||
timestamp: ISO timestamp (defaults to now)
|
||||
|
||||
Returns:
|
||||
|
|
@ -169,7 +185,11 @@ def derived_entity_triples(
|
|||
agt_uri = agent_uri(component_name)
|
||||
|
||||
# Determine specific type from parameters
|
||||
if page_number is not None:
|
||||
if image:
|
||||
specific_type = TG_IMAGE_TYPE
|
||||
elif section:
|
||||
specific_type = TG_SECTION_TYPE
|
||||
elif page_number is not None:
|
||||
specific_type = TG_PAGE_TYPE
|
||||
elif chunk_index is not None:
|
||||
specific_type = TG_CHUNK_TYPE
|
||||
|
|
@ -225,6 +245,18 @@ def derived_entity_triples(
|
|||
if chunk_overlap is not None:
|
||||
triples.append(_triple(act_uri, TG_CHUNK_OVERLAP, _literal(chunk_overlap)))
|
||||
|
||||
if mime_type:
|
||||
triples.append(_triple(entity_uri, TG_MIME_TYPE, _literal(mime_type)))
|
||||
|
||||
if element_types:
|
||||
triples.append(_triple(entity_uri, TG_ELEMENT_TYPES, _literal(element_types)))
|
||||
|
||||
if table_count is not None:
|
||||
triples.append(_triple(entity_uri, TG_TABLE_COUNT, _literal(table_count)))
|
||||
|
||||
if image_count is not None:
|
||||
triples.append(_triple(entity_uri, TG_IMAGE_COUNT, _literal(image_count)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
"""
|
||||
URI generation for provenance entities.
|
||||
|
||||
Document IDs are already IRIs (e.g., https://trustgraph.ai/doc/abc123).
|
||||
Child entities (pages, chunks) append path segments to the parent IRI:
|
||||
- Document: {doc_iri} (as provided)
|
||||
- Page: {doc_iri}/p{page_number}
|
||||
- Chunk: {page_iri}/c{chunk_index} (from page)
|
||||
{doc_iri}/c{chunk_index} (from text doc)
|
||||
Document IDs are externally provided (e.g., https://trustgraph.ai/doc/abc123).
|
||||
Child entities (pages, sections, chunks, images) use UUID-based URNs:
|
||||
- Document: {doc_iri} (as provided, not generated here)
|
||||
- Page: urn:page:{uuid}
|
||||
- Section: urn:section:{uuid}
|
||||
- Chunk: urn:chunk:{uuid}
|
||||
- Image: urn:image:{uuid}
|
||||
- Activity: https://trustgraph.ai/activity/{uuid}
|
||||
- Subgraph: https://trustgraph.ai/subgraph/{uuid}
|
||||
"""
|
||||
|
|
@ -28,19 +29,24 @@ def document_uri(doc_iri: str) -> str:
|
|||
return doc_iri
|
||||
|
||||
|
||||
def page_uri(doc_iri: str, page_number: int) -> str:
|
||||
"""Generate URI for a page by appending to document IRI."""
|
||||
return f"{doc_iri}/p{page_number}"
|
||||
def page_uri() -> str:
|
||||
"""Generate a unique URI for a page."""
|
||||
return f"urn:page:{uuid.uuid4()}"
|
||||
|
||||
|
||||
def chunk_uri_from_page(doc_iri: str, page_number: int, chunk_index: int) -> str:
|
||||
"""Generate URI for a chunk extracted from a page."""
|
||||
return f"{doc_iri}/p{page_number}/c{chunk_index}"
|
||||
def section_uri() -> str:
|
||||
"""Generate a unique URI for a document section."""
|
||||
return f"urn:section:{uuid.uuid4()}"
|
||||
|
||||
|
||||
def chunk_uri_from_doc(doc_iri: str, chunk_index: int) -> str:
|
||||
"""Generate URI for a chunk extracted directly from a text document."""
|
||||
return f"{doc_iri}/c{chunk_index}"
|
||||
def chunk_uri() -> str:
|
||||
"""Generate a unique URI for a chunk."""
|
||||
return f"urn:chunk:{uuid.uuid4()}"
|
||||
|
||||
|
||||
def image_uri() -> str:
|
||||
"""Generate a unique URI for an image."""
|
||||
return f"urn:image:{uuid.uuid4()}"
|
||||
|
||||
|
||||
def activity_uri(activity_id: str = None) -> str:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GraphRagQuery:
|
|||
triple_limit: int = 0
|
||||
max_subgraph_size: int = 0
|
||||
max_path_length: int = 0
|
||||
edge_score_limit: int = 0
|
||||
edge_limit: int = 0
|
||||
streaming: bool = False
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"boto3",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"requests",
|
||||
"pulsar-client",
|
||||
"aiohttp",
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ default_entity_limit = 50
|
|||
default_triple_limit = 30
|
||||
default_max_subgraph_size = 150
|
||||
default_max_path_length = 2
|
||||
default_edge_score_limit = 30
|
||||
default_edge_limit = 25
|
||||
|
||||
# Provenance predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
|
|
@ -638,7 +640,8 @@ async def _question_explainable(
|
|||
|
||||
def _question_explainable_api(
|
||||
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, token=None, debug=False
|
||||
max_subgraph_size, max_path_length, edge_score_limit=30,
|
||||
edge_limit=25, token=None, debug=False
|
||||
):
|
||||
"""Execute graph RAG with explainability using the new API classes."""
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -652,9 +655,12 @@ def _question_explainable_api(
|
|||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_subgraph_count=5,
|
||||
max_entity_distance=max_path_length,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -743,7 +749,8 @@ def _question_explainable_api(
|
|||
|
||||
def question(
|
||||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None,
|
||||
max_subgraph_size, max_path_length, edge_score_limit=50,
|
||||
edge_limit=25, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
):
|
||||
|
||||
|
|
@ -759,6 +766,8 @@ def question(
|
|||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
|
|
@ -781,6 +790,8 @@ def question(
|
|||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -801,7 +812,9 @@ def question(
|
|||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
)
|
||||
print(resp)
|
||||
|
||||
|
|
@ -876,6 +889,20 @@ def main():
|
|||
help=f'Max path length (default: {default_max_path_length})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=default_edge_score_limit,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=default_edge_limit,
|
||||
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action='store_true',
|
||||
|
|
@ -908,6 +935,8 @@ def main():
|
|||
triple_limit=args.triple_limit,
|
||||
max_subgraph_size=args.max_subgraph_size,
|
||||
max_path_length=args.max_path_length,
|
||||
edge_score_limit=args.edge_score_limit,
|
||||
edge_limit=args.edge_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -214,6 +214,12 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None,
|
|||
chain_str = format_provenance_chain(chain)
|
||||
if chain_str:
|
||||
print(f" Source: {chain_str}")
|
||||
# Show the content ID of the first chunk URI (urn:chunk:...) found in the chain
|
||||
for item in chain:
|
||||
uri = item.get("uri", "")
|
||||
if uri.startswith("urn:chunk:"):
|
||||
print(f" Content: {uri}")
|
||||
break
|
||||
|
||||
print()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-flow>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"trustgraph-flow>=2.2,<2.3",
|
||||
"torch",
|
||||
"urllib3",
|
||||
"transformers",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"aiohttp",
|
||||
"anthropic",
|
||||
"scylla-driver",
|
||||
|
|
|
|||
|
|
@ -189,8 +189,6 @@ class Processor(AgentService):
|
|||
if request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
|
|||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
derived_entity_triples,
|
||||
chunk_uri as make_chunk_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
|
|
@ -124,10 +124,9 @@ class Processor(ChunkingService):
|
|||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
# Generate unique chunk ID
|
||||
c_uri = make_chunk_uri()
|
||||
chunk_doc_id = c_uri
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
|
|
@ -145,7 +144,7 @@ class Processor(ChunkingService):
|
|||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=chunk_uri,
|
||||
entity_uri=c_uri,
|
||||
parent_uri=parent_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
|
|
@ -159,7 +158,7 @@ class Processor(ChunkingService):
|
|||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
id=c_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
|
|
@ -170,7 +169,7 @@ class Processor(ChunkingService):
|
|||
# Forward chunk ID + content (post-chunker optimization)
|
||||
r = Chunk(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
id=c_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
|
|||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
derived_entity_triples,
|
||||
chunk_uri as make_chunk_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
|
|
@ -122,10 +122,9 @@ class Processor(ChunkingService):
|
|||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
# Generate unique chunk ID
|
||||
c_uri = make_chunk_uri()
|
||||
chunk_doc_id = c_uri
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
|
|
@ -143,7 +142,7 @@ class Processor(ChunkingService):
|
|||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=chunk_uri,
|
||||
entity_uri=c_uri,
|
||||
parent_uri=parent_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
|
|
@ -157,7 +156,7 @@ class Processor(ChunkingService):
|
|||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
id=c_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
|
|
@ -168,7 +167,7 @@ class Processor(ChunkingService):
|
|||
# Forward chunk ID + content (post-chunker optimization)
|
||||
r = Chunk(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
id=c_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
|
|
|
|||
|
|
@ -1,29 +1,48 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
Mistral OCR decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as markdown text as separate output objects.
|
||||
|
||||
Supports both inline document data and fetching from librarian via Pulsar
|
||||
for large documents.
|
||||
"""
|
||||
|
||||
from pypdf import PdfWriter, PdfReader
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from mistralai import Mistral
|
||||
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
|
||||
from mistralai.models import OCRResponse
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "pdf-decoder"
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "mistral-ocr-decoder"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
default_ident = "document-decoder"
|
||||
default_api_key = os.getenv("MISTRAL_TOKEN")
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
pages_per_chunk = 5
|
||||
|
||||
def chunks(lst, n):
|
||||
|
|
@ -48,27 +67,6 @@ def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
|
|||
)
|
||||
return markdown_str
|
||||
|
||||
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """
    Render every OCR page to markdown and join the pages into one string.

    Args:
        ocr_response: Response from OCR processing containing text and images

    Returns:
        The per-page markdown, with image placeholders substituted, joined
        by a blank line between pages
    """
    rendered_pages: list[str] = [
        replace_images_in_markdown(
            page.markdown,
            # Image id -> base64 payload for the images on this page
            {img.id: img.image_base64 for img in page.images},
        )
        for page in ocr_response.pages
    ]

    return "\n\n".join(rendered_pages)
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -97,6 +95,50 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching document content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor = id, flow = None, name = "librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend = self.pubsub,
|
||||
topic = librarian_request_q,
|
||||
schema = LibrarianRequest,
|
||||
metrics = librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor = id, flow = None, name = "librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
backend = self.pubsub,
|
||||
flow = None,
|
||||
topic = librarian_response_q,
|
||||
subscriber = f"{id}-librarian",
|
||||
schema = LibrarianResponse,
|
||||
handler = self.on_librarian_response,
|
||||
metrics = librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
if api_key is None:
|
||||
raise RuntimeError("Mistral API key not specified")
|
||||
|
||||
|
|
@ -107,13 +149,154 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info("Mistral OCR processor initialized")
|
||||
|
||||
async def start(self):
    """Start the base processor, then the librarian producer and consumer."""
    await super(Processor, self).start()

    # Request producer first, response consumer second.
    for endpoint in (
        self.librarian_request_producer,
        self.librarian_response_consumer,
    ):
        await endpoint.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
    """Handle responses from the librarian service.

    Looks up the future registered under the message's ``id`` property in
    ``self.pending_requests``, removes it, and resolves it with the
    message value. Messages with no ``id`` or an unknown ``id`` are
    ignored.

    Args:
        msg: Incoming message; ``value()`` and ``properties()`` are read.
        consumer: Unused.
        flow: Unused.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    if not request_id:
        return

    future = self.pending_requests.pop(request_id, None)

    # A waiter that timed out or was cancelled can leave an already-finished
    # future registered for a moment; calling set_result() on it would raise
    # InvalidStateError inside the consumer handler, so drop it quietly.
    if future is None or future.done():
        return

    future.set_result(response)
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
    """
    Fetch document metadata from librarian via Pulsar.

    Sends a ``get-document-metadata`` request tagged with a fresh request
    ID and waits for the future registered under that ID in
    ``self.pending_requests`` to be resolved.

    Args:
        document_id: ID of the document whose metadata is wanted.
        user: User passed through in the librarian request.
        timeout: Seconds to wait for the response.

    Returns:
        The ``document_metadata`` field of the librarian response.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-metadata",
        document_id=document_id,
        user=user,
    )

    # Already inside a coroutine, so bind the future to the running loop.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching metadata for {document_id}")

    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not pile up in the table.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.document_metadata
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
    """
    Fetch document content from librarian via Pulsar.

    Sends a ``get-document-content`` request tagged with a fresh request
    ID and waits for the future registered under that ID in
    ``self.pending_requests`` to be resolved.

    Args:
        document_id: ID of the document to fetch.
        user: User passed through in the librarian request.
        timeout: Seconds to wait for the response.

    Returns:
        The ``content`` field of the librarian response.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-content",
        document_id=document_id,
        user=user,
    )

    # Create future for response, bound to the loop we are running on
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        # Send request
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        # Wait for response
        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching document {document_id}")

    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not pile up in the table.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.content
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
                              document_type="page", title=None, timeout=120):
    """
    Save a child document to the librarian.

    Builds ``text/plain`` document metadata, sends an
    ``add-child-document`` request carrying the base64-encoded content,
    and waits for the librarian's reply.

    Args:
        doc_id: ID to store the child document under.
        parent_id: ID of the parent document.
        user: User recorded in the document metadata.
        content: Bytes-like document body; base64-encoded before sending.
        document_type: Value for the metadata ``document_type`` field.
        title: Document title; ``doc_id`` is used when this is empty.
        timeout: Seconds to wait for the response.

    Returns:
        ``doc_id``, once the librarian has replied without an error.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or doc_id,
        parent_id=parent_id,
        document_type=document_type,
    )

    request = LibrarianRequest(
        operation="add-child-document",
        document_metadata=doc_metadata,
        content=base64.b64encode(content).decode("utf-8"),
    )

    # Create future for response, bound to the loop we are running on
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        # Send request
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        # Wait for response
        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout saving child document {doc_id}")

    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not pile up in the table.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving child document: {response.error.type}: {response.error.message}"
        )

    return doc_id
|
||||
|
||||
def ocr(self, blob):
|
||||
"""
|
||||
Run Mistral OCR on a PDF blob, returning per-page markdown strings.
|
||||
|
||||
Args:
|
||||
blob: Raw PDF bytes
|
||||
|
||||
Returns:
|
||||
List of (page_markdown, page_number) tuples, 1-indexed
|
||||
"""
|
||||
|
||||
logger.debug("Parse PDF...")
|
||||
|
||||
pdfbuf = BytesIO(blob)
|
||||
pdf = PdfReader(pdfbuf)
|
||||
|
||||
pages = []
|
||||
global_page_num = 0
|
||||
|
||||
for chunk in chunks(pdf.pages, pages_per_chunk):
|
||||
|
||||
logger.debug("Get next pages...")
|
||||
|
|
@ -152,11 +335,19 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug("Extract markdown...")
|
||||
|
||||
markdown = get_combined_markdown(processed)
|
||||
for page in processed.pages:
|
||||
global_page_num += 1
|
||||
image_data = {}
|
||||
for img in page.images:
|
||||
image_data[img.id] = img.image_base64
|
||||
markdown = replace_images_in_markdown(
|
||||
page.markdown, image_data
|
||||
)
|
||||
pages.append((markdown, global_page_num))
|
||||
|
||||
logger.info("OCR complete.")
|
||||
logger.info(f"OCR complete, {len(pages)} pages.")
|
||||
|
||||
return markdown
|
||||
return pages
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
|
|
@ -166,16 +357,97 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info(f"Decoding {v.metadata.id}...")
|
||||
|
||||
markdown = self.ocr(base64.b64decode(v.data))
|
||||
# Check MIME type if fetching from librarian
|
||||
if v.document_id:
|
||||
doc_meta = await self.fetch_document_metadata(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
|
||||
logger.error(
|
||||
f"Unsupported MIME type: {doc_meta.kind}. "
|
||||
f"Mistral OCR decoder only handles application/pdf. "
|
||||
f"Ignoring document {v.metadata.id}."
|
||||
)
|
||||
return
|
||||
|
||||
r = TextDocument(
|
||||
metadata=v.metadata,
|
||||
text=markdown.encode("utf-8"),
|
||||
)
|
||||
# Get PDF content - fetch from librarian or use inline data
|
||||
if v.document_id:
|
||||
logger.info(f"Fetching document {v.document_id} from librarian...")
|
||||
content = await self.fetch_document_content(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
blob = base64.b64decode(content)
|
||||
logger.info(f"Fetched {len(blob)} bytes from librarian")
|
||||
else:
|
||||
blob = base64.b64decode(v.data)
|
||||
|
||||
await flow("output").send(r)
|
||||
# Get the source document ID
|
||||
source_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
logger.info("Done.")
|
||||
# Run OCR, get per-page markdown
|
||||
pages = self.ocr(blob)
|
||||
|
||||
for markdown, page_num in pages:
|
||||
|
||||
logger.debug(f"Processing page {page_num}")
|
||||
|
||||
# Generate unique page ID
|
||||
pg_uri = make_page_uri()
|
||||
page_doc_id = pg_uri
|
||||
page_content = markdown.encode("utf-8")
|
||||
|
||||
# Save page as child document in librarian
|
||||
await self.save_child_document(
|
||||
doc_id=page_doc_id,
|
||||
parent_id=source_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=page_content,
|
||||
document_type="page",
|
||||
title=f"Page {page_num}",
|
||||
)
|
||||
|
||||
# Emit provenance triples
|
||||
doc_uri = document_uri(source_doc_id)
|
||||
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=pg_uri,
|
||||
parent_uri=doc_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Page {page_num}",
|
||||
page_number=page_num,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward page document ID to chunker
|
||||
# Chunker will fetch content from librarian
|
||||
r = TextDocument(
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
document_id=page_doc_id,
|
||||
text=b"", # Empty, chunker will fetch from librarian
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
||||
logger.debug("PDF decoding complete")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -188,7 +460,18 @@ class Processor(FlowProcessor):
|
|||
help=f'Mistral API Key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-request-queue',
|
||||
default=default_librarian_request_queue,
|
||||
help=f'Librarian request queue (default: {default_librarian_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-response-queue',
|
||||
default=default_librarian_response_queue,
|
||||
help=f'Librarian response queue (default: {default_librarian_response_queue})',
|
||||
)
|
||||
|
||||
def run():
    """Launch the processor with the default identifier and module docstring."""

    Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
|||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri, derived_entity_triples,
|
||||
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ COMPONENT_VERSION = "1.0.0"
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "pdf-decoder"
|
||||
default_ident = "document-decoder"
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
|
@ -126,8 +126,39 @@ class Processor(FlowProcessor):
|
|||
if request_id and request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
    """
    Fetch document metadata from librarian via Pulsar.

    Sends a ``get-document-metadata`` request tagged with a fresh request
    ID and waits for the future registered under that ID in
    ``self.pending_requests`` to be resolved.

    Args:
        document_id: ID of the document whose metadata is wanted.
        user: User passed through in the librarian request.
        timeout: Seconds to wait for the response.

    Returns:
        The ``document_metadata`` field of the librarian response.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-metadata",
        document_id=document_id,
        user=user,
    )

    # Already inside a coroutine, so bind the future to the running loop.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching metadata for {document_id}")

    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not pile up in the table.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.document_metadata
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""
|
||||
|
|
@ -233,6 +264,20 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info(f"Decoding PDF {v.metadata.id}...")
|
||||
|
||||
# Check MIME type if fetching from librarian
|
||||
if v.document_id:
|
||||
doc_meta = await self.fetch_document_metadata(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
|
||||
logger.error(
|
||||
f"Unsupported MIME type: {doc_meta.kind}. "
|
||||
f"PDF decoder only handles application/pdf. "
|
||||
f"Ignoring document {v.metadata.id}."
|
||||
)
|
||||
return
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
|
||||
temp_path = fp.name
|
||||
|
||||
|
|
@ -272,8 +317,9 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug(f"Processing page {page_num}")
|
||||
|
||||
# Generate page document ID
|
||||
page_doc_id = f"{source_doc_id}/p{page_num}"
|
||||
# Generate unique page ID
|
||||
pg_uri = make_page_uri()
|
||||
page_doc_id = pg_uri
|
||||
page_content = page.page_content.encode("utf-8")
|
||||
|
||||
# Save page as child document in librarian
|
||||
|
|
@ -288,7 +334,6 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
doc_uri = document_uri(source_doc_id)
|
||||
pg_uri = page_uri(source_doc_id, page_num)
|
||||
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=pg_uri,
|
||||
|
|
|
|||
|
|
@ -44,12 +44,8 @@ class Librarian:
|
|||
|
||||
async def add_document(self, request):
|
||||
|
||||
if request.document_metadata.kind not in (
|
||||
"text/plain", "application/pdf"
|
||||
):
|
||||
raise RequestError(
|
||||
"Invalid document kind: " + request.document_metadata.kind
|
||||
)
|
||||
if not request.document_metadata.kind:
|
||||
raise RequestError("Document kind (MIME type) is required")
|
||||
|
||||
if await self.table_store.document_exists(
|
||||
request.document_metadata.user,
|
||||
|
|
@ -276,10 +272,8 @@ class Librarian:
|
|||
"""
|
||||
logger.info(f"Beginning chunked upload for document {request.document_metadata.id}")
|
||||
|
||||
if request.document_metadata.kind not in ("text/plain", "application/pdf"):
|
||||
raise RequestError(
|
||||
"Invalid document kind: " + request.document_metadata.kind
|
||||
)
|
||||
if not request.document_metadata.kind:
|
||||
raise RequestError("Document kind (MIME type) is required")
|
||||
|
||||
if await self.table_store.document_exists(
|
||||
request.document_metadata.user,
|
||||
|
|
|
|||
|
|
@ -284,7 +284,6 @@ class Processor(AsyncProcessor):
|
|||
pass
|
||||
|
||||
# Threshold for sending document_id instead of inline content (2MB)
|
||||
STREAMING_THRESHOLD = 2 * 1024 * 1024
|
||||
|
||||
async def emit_document_provenance(self, document, processing, triples_queue):
|
||||
"""
|
||||
|
|
@ -360,10 +359,8 @@ class Processor(AsyncProcessor):
|
|||
|
||||
if document.kind == "text/plain":
|
||||
kind = "text-load"
|
||||
elif document.kind == "application/pdf":
|
||||
kind = "document-load"
|
||||
else:
|
||||
raise RuntimeError("Document with a MIME type I don't know")
|
||||
kind = "document-load"
|
||||
|
||||
q = flow["interfaces"][kind]
|
||||
|
||||
|
|
@ -374,57 +371,28 @@ class Processor(AsyncProcessor):
|
|||
)
|
||||
|
||||
if kind == "text-load":
|
||||
# For large text documents, send document_id for streaming retrieval
|
||||
if len(content) >= self.STREAMING_THRESHOLD:
|
||||
logger.info(f"Text document {document.id} is large ({len(content)} bytes), "
|
||||
f"sending document_id for streaming retrieval")
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
text = b"", # Empty, receiver will fetch via librarian
|
||||
)
|
||||
else:
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
text = content,
|
||||
)
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
text = b"",
|
||||
)
|
||||
schema = TextDocument
|
||||
else:
|
||||
# For large PDF documents, send document_id for streaming retrieval
|
||||
# instead of embedding the entire content in the message
|
||||
if len(content) >= self.STREAMING_THRESHOLD:
|
||||
logger.info(f"Document {document.id} is large ({len(content)} bytes), "
|
||||
f"sending document_id for streaming retrieval")
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
data = b"", # Empty data, receiver will fetch via API
|
||||
)
|
||||
else:
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
data = base64.b64encode(content).decode("utf-8")
|
||||
)
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
data = b"",
|
||||
)
|
||||
schema = Document
|
||||
|
||||
logger.debug(f"Submitting to queue {q}...")
|
||||
|
|
|
|||
|
|
@ -139,8 +139,6 @@ class Processor(FlowProcessor):
|
|||
if request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
|
||||
"""Fetch chunk content from librarian/Garage."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
|
|
@ -550,7 +551,8 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_limit = 25, streaming = False,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
|
|
@ -565,6 +567,8 @@ class GraphRag:
|
|||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
|
|
@ -628,6 +632,70 @@ class GraphRag:
|
|||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors a and b.

    Yields 0.0 when either vector has zero magnitude, so the division
    below is never by zero.
    """
    magnitude_a = math.sqrt(sum(v * v for v in a))
    magnitude_b = math.sqrt(sum(v * v for v in b))

    if not (magnitude_a and magnitude_b):
        return 0.0

    inner = sum(p * q for p, q in zip(a, b))
    return inner / (magnitude_a * magnitude_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -49,6 +50,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -57,6 +59,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
|
|
@ -166,8 +169,6 @@ class Processor(FlowProcessor):
|
|||
if request_id and request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
|
|
@ -295,6 +296,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -330,6 +336,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -344,6 +351,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@ -432,6 +440,20 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=25,
|
||||
help=f'Max edges after LLM scoring (default: 25)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the user's collection
|
||||
# with the named graph urn:graph:retrieval (no separate collection needed)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"boto3",
|
||||
|
|
|
|||
|
|
@ -2,21 +2,41 @@
|
|||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
|
||||
Supports both inline document data and fetching from librarian via Pulsar
|
||||
for large documents.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
import pytesseract
|
||||
from pdf2image import convert_from_bytes
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "tesseract-ocr-decoder"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "pdf-decoder"
|
||||
default_ident = "document-decoder"
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -45,8 +65,181 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching document content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor = id, flow = None, name = "librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend = self.pubsub,
|
||||
topic = librarian_request_q,
|
||||
schema = LibrarianRequest,
|
||||
metrics = librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor = id, flow = None, name = "librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
backend = self.pubsub,
|
||||
flow = None,
|
||||
topic = librarian_response_q,
|
||||
subscriber = f"{id}-librarian",
|
||||
schema = LibrarianResponse,
|
||||
handler = self.on_librarian_response,
|
||||
metrics = librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
logger.info("PDF OCR processor initialized")
|
||||
|
||||
async def start(self):
    """Start the base processor, then the librarian producer and consumer."""
    await super(Processor, self).start()

    # Request producer first, response consumer second.
    for endpoint in (
        self.librarian_request_producer,
        self.librarian_response_consumer,
    ):
        await endpoint.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
    """Handle responses from the librarian service.

    Looks up the future registered under the message's ``id`` property in
    ``self.pending_requests``, removes it, and resolves it with the
    message value. Messages with no ``id`` or an unknown ``id`` are
    ignored.

    Args:
        msg: Incoming message; ``value()`` and ``properties()`` are read.
        consumer: Unused.
        flow: Unused.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    if not request_id:
        return

    future = self.pending_requests.pop(request_id, None)

    # A waiter that timed out or was cancelled can leave an already-finished
    # future registered for a moment; calling set_result() on it would raise
    # InvalidStateError inside the consumer handler, so drop it quietly.
    if future is None or future.done():
        return

    future.set_result(response)
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
    """
    Fetch document metadata from librarian via Pulsar.

    Sends a ``get-document-metadata`` request tagged with a fresh request
    ID and waits for the future registered under that ID in
    ``self.pending_requests`` to be resolved.

    Args:
        document_id: ID of the document whose metadata is wanted.
        user: User passed through in the librarian request.
        timeout: Seconds to wait for the response.

    Returns:
        The ``document_metadata`` field of the librarian response.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-metadata",
        document_id=document_id,
        user=user,
    )

    # Already inside a coroutine, so bind the future to the running loop.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching metadata for {document_id}")

    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not pile up in the table.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.document_metadata
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
    """
    Fetch document content from librarian via Pulsar.

    Sends a "get-document-content" request tagged with a fresh request
    id, then waits for the matching response to be delivered through
    self.pending_requests.

    Args:
        document_id: Identifier of the document to fetch.
        user: User the request is made on behalf of.
        timeout: Seconds to wait for the response (default: 120).

    Returns:
        The content field of the librarian response.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within the timeout.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-content",
        document_id=document_id,
        user=user,
    )

    # Future resolved by on_librarian_response when the reply arrives
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching document {document_id}")

    finally:
        # Drop the pending entry on every exit path (send failure,
        # cancellation, timeout) so abandoned futures do not accumulate.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.content
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
                              document_type="page", title=None, timeout=120):
    """
    Save a child document to the librarian.

    The content is base64-encoded and sent in an "add-child-document"
    request together with metadata of kind "text/plain".

    Args:
        doc_id: Identifier for the new child document.
        parent_id: Identifier of the parent document.
        user: User the document belongs to.
        content: Document content as bytes; a str is encoded as UTF-8.
        document_type: Document type recorded in the metadata
            (default: "page").
        title: Title for the document; doc_id is used when omitted.
        timeout: Seconds to wait for the response (default: 120).

    Returns:
        doc_id, once the librarian has acknowledged the request.

    Raises:
        RuntimeError: If the librarian reports an error, or no response
            arrives within the timeout.
    """
    # b64encode() needs bytes; accept text as well for convenience.
    if isinstance(content, str):
        content = content.encode("utf-8")

    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or doc_id,
        parent_id=parent_id,
        document_type=document_type,
    )

    request = LibrarianRequest(
        operation="add-child-document",
        document_metadata=doc_metadata,
        content=base64.b64encode(content).decode("utf-8"),
    )

    # Future resolved by on_librarian_response when the reply arrives
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )

        response = await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout saving child document {doc_id}")

    finally:
        # Drop the pending entry on every exit path (send failure,
        # cancellation, timeout) so abandoned futures do not accumulate.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving child document: {response.error.type}: {response.error.message}"
        )

    return doc_id
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
logger.info("PDF message received")
|
||||
|
|
@ -55,21 +248,99 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info(f"Decoding {v.metadata.id}...")
|
||||
|
||||
blob = base64.b64decode(v.data)
|
||||
# Check MIME type if fetching from librarian
|
||||
if v.document_id:
|
||||
doc_meta = await self.fetch_document_metadata(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
|
||||
logger.error(
|
||||
f"Unsupported MIME type: {doc_meta.kind}. "
|
||||
f"Tesseract OCR decoder only handles application/pdf. "
|
||||
f"Ignoring document {v.metadata.id}."
|
||||
)
|
||||
return
|
||||
|
||||
# Get PDF content - fetch from librarian or use inline data
|
||||
if v.document_id:
|
||||
logger.info(f"Fetching document {v.document_id} from librarian...")
|
||||
content = await self.fetch_document_content(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
blob = base64.b64decode(content)
|
||||
logger.info(f"Fetched {len(blob)} bytes from librarian")
|
||||
else:
|
||||
blob = base64.b64decode(v.data)
|
||||
|
||||
# Get the source document ID
|
||||
source_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
pages = convert_from_bytes(blob)
|
||||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
page_num = ix + 1 # 1-indexed
|
||||
|
||||
try:
|
||||
text = pytesseract.image_to_string(page, lang='eng')
|
||||
except Exception as e:
|
||||
logger.warning(f"Page did not OCR: {e}")
|
||||
logger.warning(f"Page {page_num} did not OCR: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Processing page {page_num}")
|
||||
|
||||
# Generate unique page ID
|
||||
pg_uri = make_page_uri()
|
||||
page_doc_id = pg_uri
|
||||
page_content = text.encode("utf-8")
|
||||
|
||||
# Save page as child document in librarian
|
||||
await self.save_child_document(
|
||||
doc_id=page_doc_id,
|
||||
parent_id=source_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=page_content,
|
||||
document_type="page",
|
||||
title=f"Page {page_num}",
|
||||
)
|
||||
|
||||
# Emit provenance triples
|
||||
doc_uri = document_uri(source_doc_id)
|
||||
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=pg_uri,
|
||||
parent_uri=doc_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Page {page_num}",
|
||||
page_number=page_num,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward page document ID to chunker
|
||||
# Chunker will fetch content from librarian
|
||||
r = TextDocument(
|
||||
metadata=v.metadata,
|
||||
text=text.encode("utf-8"),
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
document_id=page_doc_id,
|
||||
text=b"", # Empty, chunker will fetch from librarian
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
|
@ -78,9 +349,21 @@ class Processor(FlowProcessor):
|
|||
|
||||
@staticmethod
def add_args(parser):
    """
    Register command-line arguments for this processor.

    Adds the FlowProcessor arguments, then the two options naming the
    librarian request and response queues.
    """
    FlowProcessor.add_args(parser)

    queue_options = (
        (
            '--librarian-request-queue',
            default_librarian_request_queue,
            'Librarian request queue',
        ),
        (
            '--librarian-response-queue',
            default_librarian_response_queue,
            'Librarian response queue',
        ),
    )

    for flag, default, what in queue_options:
        parser.add_argument(
            flag,
            default=default,
            help=f'{what} (default: {default})',
        )
|
||||
|
||||
def run():
    """Launch the processor, using the module docstring as description."""
    description = __doc__
    Processor.launch(default_ident, description)
|
||||
|
||||
|
|
|
|||
34
trustgraph-unstructured/pyproject.toml
Normal file
34
trustgraph-unstructured/pyproject.toml
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trustgraph-unstructured"
|
||||
dynamic = ["version"]
|
||||
authors = [{name = "trustgraph.ai", email = "security@trustgraph.ai"}]
|
||||
description = "TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"python-magic",
|
||||
"unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/trustgraph-ai/trustgraph"
|
||||
|
||||
[project.scripts]
|
||||
universal-decoder = "trustgraph.decoding.universal:run"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["trustgraph*"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "trustgraph.unstructured_version.__version__"}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . processor import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
|
@ -0,0 +1,710 @@
|
|||
|
||||
"""
|
||||
Universal document decoder powered by the unstructured library.
|
||||
|
||||
Accepts documents in any common format (PDF, DOCX, XLSX, HTML, Markdown,
|
||||
plain text, PPTX, etc.) on input, outputs pages or sections as text
|
||||
as separate output objects.
|
||||
|
||||
Supports both inline document data and fetching from librarian via Pulsar
|
||||
for large documents. Fetches document metadata from the librarian to
|
||||
determine mime type for format detection.
|
||||
|
||||
Tables are preserved as HTML markup for better downstream extraction.
|
||||
Images are stored in the librarian but not sent to the text pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import magic
|
||||
import tempfile
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from unstructured.partition.auto import partition
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri as make_page_uri,
|
||||
section_uri as make_section_uri, image_uri as make_image_uri,
|
||||
derived_entity_triples, set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
from . strategies import get_strategy
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "universal-decoder"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "document-decoder"
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
# Mime type to file extension mapping (used as the temp-file suffix)
|
||||
# unstructured auto-detects most formats, but we pass the hint when available
|
||||
MIME_EXTENSIONS = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/tab-separated-values": ".tsv",
|
||||
"application/rtf": ".rtf",
|
||||
"text/x-rst": ".rst",
|
||||
"application/vnd.oasis.opendocument.text": ".odt",
|
||||
}
|
||||
|
||||
# Formats that have natural page boundaries
|
||||
PAGE_BASED_FORMATS = {
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-excel",
|
||||
}
|
||||
|
||||
|
||||
def assemble_section_text(elements):
    """
    Build the text of a section from unstructured elements.

    - Image elements are counted but contribute no text
    - Table elements contribute metadata.text_as_html when present,
      otherwise their plain text
    - Every other element contributes its plain text
    - Elements with no text contribute nothing

    Pieces are joined with a blank line between them.

    Returns:
        tuple: (assembled_text, element_types_set, table_count, image_count)
    """
    chunks = []
    seen_categories = set()
    tables = 0
    images = 0

    for element in elements:
        kind = getattr(element, 'category', 'UncategorizedText')
        seen_categories.add(kind)

        if kind == 'Image':
            # Counted, but never part of the text output
            images += 1
            continue

        rendered = None

        if kind == 'Table':
            tables += 1
            # HTML markup is preferred for tables
            if hasattr(element, 'metadata'):
                rendered = getattr(element.metadata, 'text_as_html', None)

        # Plain text for non-tables, and for tables without HTML
        if not rendered:
            rendered = getattr(element, 'text', '') or ''

        if rendered:
            chunks.append(rendered)

    return '\n\n'.join(chunks), seen_categories, tables, images
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """
    Configure the universal decoder.

    Reads the partitioning and section-grouping options from params,
    registers the "input" consumer and the "output" and "triples"
    producers, and creates the librarian request producer and response
    consumer.

    Recognised params (all optional):
        id: Processor identity (default: default_ident).
        strategy: Partitioning strategy (default: "auto").
        languages: Comma-separated language codes (default: "eng").
        section_strategy: Section grouping strategy name
            (default: "whole-document").
        section_element_count: Elements per section (default: 20).
        section_max_size: Max characters per section (default: 4000).
        section_within_pages: Stored as-is (default: False).
        librarian_request_queue / librarian_response_queue: Queue
            names for talking to the librarian.
    """
    ident = params.get("id", default_ident)

    self.partition_strategy = params.get("strategy", "auto")

    # Tolerate whitespace and stray commas ("eng, deu" / "eng,") so
    # that no blank or padded language code is passed on.
    self.languages = [
        code.strip()
        for code in params.get("languages", "eng").split(",")
        if code.strip()
    ] or ["eng"]

    self.section_strategy_name = params.get(
        "section_strategy", "whole-document"
    )
    self.section_element_count = params.get("section_element_count", 20)
    self.section_max_size = params.get("section_max_size", 4000)
    self.section_within_pages = params.get("section_within_pages", False)

    self.section_strategy = get_strategy(self.section_strategy_name)

    super(Processor, self).__init__(
        **params | {
            "id": ident,
        }
    )

    self.register_specification(
        ConsumerSpec(
            name="input",
            schema=Document,
            handler=self.on_message,
        )
    )

    self.register_specification(
        ProducerSpec(
            name="output",
            schema=TextDocument,
        )
    )

    self.register_specification(
        ProducerSpec(
            name="triples",
            schema=Triples,
        )
    )

    # Librarian client for fetching/storing document content
    librarian_request_q = params.get(
        "librarian_request_queue", default_librarian_request_queue
    )
    librarian_response_q = params.get(
        "librarian_response_queue", default_librarian_response_queue
    )

    librarian_request_metrics = ProducerMetrics(
        processor=ident, flow=None, name="librarian-request"
    )

    self.librarian_request_producer = Producer(
        backend=self.pubsub,
        topic=librarian_request_q,
        schema=LibrarianRequest,
        metrics=librarian_request_metrics,
    )

    librarian_response_metrics = ConsumerMetrics(
        processor=ident, flow=None, name="librarian-response"
    )

    self.librarian_response_consumer = Consumer(
        taskgroup=self.taskgroup,
        backend=self.pubsub,
        flow=None,
        topic=librarian_response_q,
        subscriber=f"{ident}-librarian",
        schema=LibrarianResponse,
        handler=self.on_librarian_response,
        metrics=librarian_response_metrics,
    )

    # Pending librarian requests: request_id -> asyncio.Future
    self.pending_requests = {}

    logger.info("Universal decoder initialized")
|
||||
|
||||
async def start(self):
    """
    Start the processor and its librarian messaging endpoints.

    Runs the base-class startup first, then starts the librarian
    request producer followed by the librarian response consumer.
    """
    await super(Processor, self).start()

    # Producer first, then consumer.
    for endpoint in (
        self.librarian_request_producer,
        self.librarian_response_consumer,
    ):
        await endpoint.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
    """
    Handle responses from the librarian service.

    The response is matched to its waiting request through the "id"
    message property and delivered by resolving the future stored in
    self.pending_requests. Responses without an id, or whose id is not
    pending, are ignored.

    Args:
        msg: Incoming message; value() yields the response and
            properties() yields the message properties.
        consumer: Not used by this handler.
        flow: Not used by this handler.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    if not request_id:
        return

    future = self.pending_requests.pop(request_id, None)
    if future is None:
        return

    # The waiter may already have timed out or been cancelled; calling
    # set_result() on a finished future raises InvalidStateError, which
    # would surface as a failure of this consumer handler.
    if not future.done():
        future.set_result(response)
|
||||
|
||||
async def _librarian_request(self, request, timeout=120):
|
||||
"""Send a request to the librarian and wait for response."""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: "
|
||||
f"{response.error.message}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError("Timeout waiting for librarian response")
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user):
    """
    Ask the librarian for a document's metadata.

    Issues a "get-document-metadata" request and returns the
    document_metadata field of the response.
    """
    reply = await self._librarian_request(
        LibrarianRequest(
            operation="get-document-metadata",
            document_id=document_id,
            user=user,
        )
    )
    return reply.document_metadata
|
||||
|
||||
async def fetch_document_content(self, document_id, user):
    """
    Ask the librarian for a document's content.

    Issues a "get-document-content" request and returns the content
    field of the response.
    """
    reply = await self._librarian_request(
        LibrarianRequest(
            operation="get-document-content",
            document_id=document_id,
            user=user,
        )
    )
    return reply.content
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
                              document_type="page", title=None,
                              kind="text/plain"):
    """
    Store a child document in the librarian.

    The content (a str is first encoded as UTF-8) is base64-encoded and
    sent in an "add-child-document" request. The title falls back to
    doc_id when not given.

    Returns:
        doc_id, once the librarian request has completed.
    """
    payload = content.encode("utf-8") if isinstance(content, str) else content

    request = LibrarianRequest(
        operation="add-child-document",
        document_metadata=DocumentMetadata(
            id=doc_id,
            user=user,
            kind=kind,
            title=title or doc_id,
            parent_id=parent_id,
            document_type=document_type,
        ),
        content=base64.b64encode(payload).decode("utf-8"),
    )

    await self._librarian_request(request)

    return doc_id
|
||||
|
||||
def extract_elements(self, blob, mime_type=None):
    """
    Extract elements from a document using unstructured.

    The bytes are written to a temporary file whose suffix is taken
    from MIME_EXTENSIONS (".bin" when the mime type is unknown), the
    file is partitioned, and the temporary file is removed again.

    Args:
        blob: Raw document bytes
        mime_type: Optional mime type hint

    Returns:
        List of unstructured Element objects
    """
    # Determine file extension for unstructured
    suffix = MIME_EXTENSIONS.get(mime_type) or ".bin"

    fp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    temp_path = fp.name

    try:
        # Writing happens inside the try so that a failed write still
        # removes the temporary file in the finally clause below.
        with fp:
            fp.write(blob)

        kwargs = {
            "filename": temp_path,
            "strategy": self.partition_strategy,
            "languages": self.languages,
        }

        # For hi_res strategy, request image extraction.
        # extract_image_block_to_payload only takes effect when
        # extract_image_block_types names the element types to extract.
        if self.partition_strategy == "hi_res":
            kwargs["extract_image_block_types"] = ["Image"]
            kwargs["extract_image_block_to_payload"] = True

        elements = partition(**kwargs)

        logger.info(
            f"Extracted {len(elements)} elements "
            f"(strategy: {self.partition_strategy})"
        )

        return elements

    finally:
        # Best-effort cleanup: a failure to remove the temp file must
        # not mask the extraction result or error.
        try:
            os.unlink(temp_path)
        except OSError as e:
            logger.debug(f"Could not remove temp file {temp_path}: {e}")
|
||||
|
||||
def group_by_page(self, elements):
    """
    Bucket elements by their metadata.page_number.

    Elements with no metadata, or with no page number, are filed under
    page 1. Within a page, elements keep their input order.

    Returns list of (page_number, elements) tuples, sorted by page.
    """
    by_page = {}

    for element in elements:
        number = None
        if hasattr(element, 'metadata'):
            number = getattr(element.metadata, 'page_number', None)

        key = 1 if number is None else number
        by_page.setdefault(key, []).append(element)

    return sorted(by_page.items())
|
||||
|
||||
async def emit_section(self, elements, parent_doc_id, doc_uri_str,
                       metadata, flow, mime_type=None,
                       page_number=None, section_index=None):
    """
    Emit one page or section built from a group of elements.

    The group is treated as a page when page_number is given, otherwise
    as a section. Its text is assembled, saved to the librarian as a
    child of parent_doc_id, described with provenance triples on the
    "triples" output, and announced on the "output" output as a
    TextDocument carrying only the document id.

    Returns:
        The new entity URI, or None when the assembled text is blank
        (in which case nothing is saved or sent).
    """
    text, element_types, table_count, image_count = (
        assemble_section_text(elements)
    )

    if not text.strip():
        logger.debug("Skipping empty section")
        return None

    is_page = page_number is not None

    if is_page:
        entity_uri = make_page_uri()
        label = f"Page {page_number}"
        document_type = "page"
    else:
        entity_uri = make_section_uri()
        label = f"Section {section_index}" if section_index else "Section"
        document_type = "section"

    def entity_metadata():
        # A fresh Metadata for each outgoing message, keyed by the
        # new entity rather than the source document.
        return Metadata(
            id=entity_uri,
            root=metadata.root,
            user=metadata.user,
            collection=metadata.collection,
        )

    # Save to librarian; the entity URI doubles as the document id
    await self.save_child_document(
        doc_id=entity_uri,
        parent_id=parent_doc_id,
        user=metadata.user,
        content=text.encode("utf-8"),
        document_type=document_type,
        title=label,
    )

    # Emit provenance triples; zero counts are passed as None
    prov_triples = derived_entity_triples(
        entity_uri=entity_uri,
        parent_uri=doc_uri_str,
        component_name=COMPONENT_NAME,
        component_version=COMPONENT_VERSION,
        label=label,
        page_number=page_number,
        section=not is_page,
        char_length=len(text),
        mime_type=mime_type,
        element_types=(
            ",".join(sorted(element_types)) if element_types else None
        ),
        table_count=table_count or None,
        image_count=image_count or None,
    )

    await flow("triples").send(Triples(
        metadata=entity_metadata(),
        triples=set_graph(prov_triples, GRAPH_SOURCE),
    ))

    # Send TextDocument downstream (chunker will fetch from librarian)
    await flow("output").send(TextDocument(
        metadata=entity_metadata(),
        document_id=entity_uri,
        text=b"",
    ))

    return entity_uri
|
||||
|
||||
async def emit_image(self, element, parent_uri, parent_doc_id,
                     metadata, flow, mime_type=None, page_number=None):
    """
    Record an image element: store its payload and emit provenance.

    When the element carries metadata.image_base64, the decoded bytes
    are saved to the librarian as an "image/png" child document of
    parent_doc_id. Provenance triples are sent on the "triples" output
    whether or not a payload was available. Nothing is sent on the
    text output.
    """
    img_uri = make_image_uri()

    payload = None
    if hasattr(element, 'metadata'):
        payload = getattr(element.metadata, 'image_base64', None)

    if payload:
        # A str payload is base64 text; anything else is used as-is
        if isinstance(payload, str):
            image_bytes = base64.b64decode(payload)
        else:
            image_bytes = payload
    else:
        # No image payload available, just record provenance
        logger.debug("Image element without payload, recording provenance only")
        image_bytes = b""

    title = f"Image from page {page_number}" if page_number else "Image"

    # Save to librarian
    if image_bytes:
        await self.save_child_document(
            doc_id=img_uri,
            parent_id=parent_doc_id,
            user=metadata.user,
            content=image_bytes,
            document_type="image",
            title=title,
            kind="image/png",  # unstructured typically extracts as PNG
        )

    # Emit provenance triples
    prov_triples = derived_entity_triples(
        entity_uri=img_uri,
        parent_uri=parent_uri,
        component_name=COMPONENT_NAME,
        component_version=COMPONENT_VERSION,
        label=title,
        image=True,
        page_number=page_number,
        mime_type=mime_type,
    )

    image_metadata = Metadata(
        id=img_uri,
        root=metadata.root,
        user=metadata.user,
        collection=metadata.collection,
    )

    await flow("triples").send(Triples(
        metadata=image_metadata,
        triples=set_graph(prov_triples, GRAPH_SOURCE),
    ))
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """
    Decode one incoming document into pages or sections.

    The document bytes come either from the librarian (when the message
    carries a document_id, in which case the mime type is taken from
    the librarian metadata) or from the inline base64 data (in which
    case the mime type is sniffed from the bytes). The extracted
    elements are emitted page by page when the format is page-based or
    any element carries a page number; otherwise they are grouped by
    the configured section strategy. Image elements are stored
    separately and never enter the text output.

    NOTE(review): self.section_within_pages is not consulted here;
    page-based documents are always emitted as one unit per page.
    """

    logger.debug("Document message received")

    v = msg.value()

    logger.info(f"Decoding {v.metadata.id}...")

    def split_images(items):
        # Partition elements into (images, everything else),
        # preserving order within each list.
        images, others = [], []
        for el in items:
            if getattr(el, 'category', '') == 'Image':
                images.append(el)
            else:
                others.append(el)
        return images, others

    # Determine content and mime type
    mime_type = None

    if v.document_id:
        # Librarian path: fetch metadata then content
        logger.info(
            f"Fetching document {v.document_id} from librarian..."
        )

        doc_meta = await self.fetch_document_metadata(
            document_id=v.document_id,
            user=v.metadata.user,
        )
        if doc_meta:
            mime_type = doc_meta.kind

        content = await self.fetch_document_content(
            document_id=v.document_id,
            user=v.metadata.user,
        )

        if isinstance(content, str):
            content = content.encode('utf-8')
        blob = base64.b64decode(content)

        logger.info(
            f"Fetched {len(blob)} bytes, mime: {mime_type}"
        )
    else:
        # Inline path: detect format from content
        blob = base64.b64decode(v.data)
        try:
            mime_type = magic.from_buffer(blob, mime=True)
            logger.info(f"Detected mime type: {mime_type}")
        except Exception as e:
            logger.warning(f"Could not detect mime type: {e}")

    # Get the source document ID
    source_doc_id = v.document_id or v.metadata.id
    doc_uri_str = document_uri(source_doc_id)

    # Extract elements using unstructured
    elements = self.extract_elements(blob, mime_type)

    if not elements:
        logger.warning("No elements extracted from document")
        return

    # Page-based if the format says so...
    is_page_based = bool(mime_type) and mime_type in PAGE_BASED_FORMATS

    # ...or if the elements actually have page numbers
    if not is_page_based:
        is_page_based = any(
            getattr(el.metadata, 'page_number', None) is not None
            for el in elements
            if hasattr(el, 'metadata')
        )

    if is_page_based:

        for page_num, page_elements in self.group_by_page(elements):

            image_elements, text_elements = split_images(page_elements)

            # Emit the page as a text section
            page_uri_str = await self.emit_section(
                text_elements, source_doc_id, doc_uri_str,
                v.metadata, flow,
                mime_type=mime_type, page_number=page_num,
            )

            # Store images (not sent to text pipeline); they hang off
            # the page when one was emitted, else off the document
            image_parent_uri = page_uri_str or doc_uri_str

            for img_el in image_elements:
                await self.emit_image(
                    img_el,
                    image_parent_uri,
                    source_doc_id,
                    v.metadata, flow,
                    mime_type=mime_type, page_number=page_num,
                )

    else:
        # Non-page format: use section strategy

        image_elements, text_elements = split_images(elements)

        # Apply section strategy to text elements
        groups = self.section_strategy(
            text_elements,
            element_count=self.section_element_count,
            max_size=self.section_max_size,
        )

        # Sections are numbered from 1
        for section_idx, group in enumerate(groups, start=1):
            await self.emit_section(
                group, source_doc_id, doc_uri_str,
                v.metadata, flow,
                mime_type=mime_type, section_index=section_idx,
            )

        # Store images (not sent to text pipeline)
        for img_el in image_elements:
            await self.emit_image(
                img_el, doc_uri_str, source_doc_id,
                v.metadata, flow,
                mime_type=mime_type,
            )

    logger.info("Document decoding complete")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """
    Register command-line arguments for the universal decoder.

    Adds the FlowProcessor arguments, then the partitioning options,
    the section-grouping options and the librarian queue names.
    """
    FlowProcessor.add_args(parser)

    add = parser.add_argument

    # Partitioning
    add(
        '--strategy',
        default='auto',
        choices=['auto', 'hi_res', 'fast'],
        help='Partitioning strategy (default: auto)',
    )
    add(
        '--languages',
        default='eng',
        help='Comma-separated OCR language codes (default: eng)',
    )

    # Section grouping
    add(
        '--section-strategy',
        default='whole-document',
        choices=[
            'whole-document', 'heading', 'element-type', 'count', 'size',
        ],
        help='Section grouping strategy for non-page formats '
             '(default: whole-document)',
    )
    add(
        '--section-element-count',
        type=int,
        default=20,
        help='Elements per section for count strategy (default: 20)',
    )
    add(
        '--section-max-size',
        type=int,
        default=4000,
        help='Max chars per section for size strategy (default: 4000)',
    )
    add(
        '--section-within-pages',
        action='store_true',
        default=False,
        help='Apply section strategy within pages too (default: false)',
    )

    # Librarian queues
    add(
        '--librarian-request-queue',
        default=default_librarian_request_queue,
        help=f'Librarian request queue '
             f'(default: {default_librarian_request_queue})',
    )
    add(
        '--librarian-response-queue',
        default=default_librarian_response_queue,
        help=f'Librarian response queue '
             f'(default: {default_librarian_response_queue})',
    )
|
||||
|
||||
|
||||
def run():
    """Launch the processor, using the module docstring as description."""
    description = __doc__
    Processor.launch(default_ident, description)
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
|
||||
"""
|
||||
Section grouping strategies for the universal document decoder.
|
||||
|
||||
Each strategy takes a list of unstructured elements and returns a list
|
||||
of element groups. Each group becomes one TextDocument output.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def group_whole_document(elements, **kwargs):
    """
    Return all elements as one single group.

    An empty input yields no groups at all. Extra keyword arguments are
    accepted and ignored.
    """
    return [elements] if elements else []
|
||||
|
||||
|
||||
def group_by_heading(elements, **kwargs):
    """
    Start a new section at every heading (category 'Title').

    Content before the first heading forms its own leading section.
    When the input has no headings at all, the whole-document grouping
    is used instead. Empty input yields no groups.
    """
    if not elements:
        return []

    def is_heading(el):
        return getattr(el, 'category', '') == 'Title'

    if not any(is_heading(el) for el in elements):
        logger.debug("No headings found, falling back to whole-document")
        return group_whole_document(elements)

    sections = []

    for el in elements:
        # The very first element opens a section; after that only
        # headings do.
        if not sections or is_heading(el):
            sections.append([])
        sections[-1].append(el)

    return sections
|
||||
|
||||
|
||||
def group_by_element_type(elements, **kwargs):
    """
    Split wherever the input switches between table and non-table.

    Each group is a maximal run of consecutive elements that are either
    all tables (category 'Table') or all non-tables. Empty input yields
    no groups.
    """
    if not elements:
        return []

    runs = []
    previous_kind = None

    for el in elements:
        kind = getattr(el, 'category', '') == 'Table'

        # Open a new run for the first element and at every switch
        if previous_kind is None or kind != previous_kind:
            runs.append([])

        runs[-1].append(el)
        previous_kind = kind

    return runs
|
||||
|
||||
|
||||
def group_by_count(elements, element_count=20, **kwargs):
    """
    Group a fixed number of elements per section.

    The final group holds the remainder and may be shorter than
    element_count.

    Args:
        elements: List of unstructured elements
        element_count: Number of elements per group (default: 20)
        **kwargs: Ignored

    Returns:
        List of element groups, in document order.

    Raises:
        ValueError: If element_count is less than 1
    """
    if not elements:
        return []

    # A zero step makes range() fail with an unhelpful message, and a
    # negative step yields no slices at all, which would silently drop
    # every element of a non-empty document.
    if element_count < 1:
        raise ValueError(
            f"element_count must be a positive integer, "
            f"got {element_count!r}"
        )

    return [
        elements[start:start + element_count]
        for start in range(0, len(elements), element_count)
    ]
|
||||
|
||||
|
||||
def group_by_size(elements, max_size=4000, **kwargs):
    """
    Pack consecutive elements into sections bounded by a character
    budget.

    Elements are never split. An element is added to the open section
    unless doing so would push the section's total text length past
    max_size, in which case it opens a new section. An element that is
    larger than max_size by itself therefore ends up in a section of
    its own.

    Args:
        elements: List of unstructured elements
        max_size: Max characters per section (default: 4000)
        **kwargs: Ignored

    Returns:
        List of element groups, in document order.
    """
    if not elements:
        return []

    sections = []
    used = 0

    for element in elements:
        # A missing or None text attribute counts as zero characters
        length = len(getattr(element, 'text', '') or '')

        if not sections or used + length > max_size:
            sections.append([])
            used = 0

        sections[-1].append(element)
        used += length

    return sections
|
||||
|
||||
|
||||
# Registry of section grouping strategies, keyed by the strategy name
STRATEGIES = {
    'whole-document': group_whole_document,
    'heading': group_by_heading,
    'element-type': group_by_element_type,
    'count': group_by_count,
    'size': group_by_size,
}


def get_strategy(name):
    """
    Look up a section grouping strategy in the registry.

    Args:
        name: Strategy name (whole-document, heading, element-type,
            count, size)

    Returns:
        The grouping function registered under that name

    Raises:
        ValueError: If no strategy is registered under that name; the
            message lists the available names
    """
    try:
        return STRATEGIES[name]
    except KeyError:
        available = ', '.join(STRATEGIES.keys())
        # 'from None' keeps the internal KeyError out of the traceback
        raise ValueError(
            f"Unknown section strategy: {name}. "
            f"Available: {available}"
        ) from None
|
||||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"trustgraph-base>=2.2,<2.3",
|
||||
"pulsar-client",
|
||||
"google-genai",
|
||||
"google-api-core",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue