trustgraph/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
cybermaggedon 6b1dd16f9f
fix: large document handling and Cassandra query pagination (#969)
- Paginate heavy Cassandra reads (triples, graph/document embeddings)
  using synchronous session.execute() in run_in_executor with fetch_size
  paging, preventing materialization hang on large result sets
- Fix document stream endpoint to use workspace-scoped librarian queues
- Add decoder error handling for PDF/OCR/unstructured processors
- Add WebSocket mux guards for missing auth fields
- Add null check in librarian document streaming
- Rewrite get_document_content CLI to stream via librarian
- Add Poppler dependency to unstructured container
2026-06-01 22:39:30 +01:00

300 lines
8.3 KiB
Python
Executable file

"""
Mistral OCR decoder, accepts PDF documents on input, outputs pages from the
PDF document as markdown text as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents.
"""
from pypdf import PdfWriter, PdfReader
from io import BytesIO
import base64
import uuid
import os
from mistralai import Mistral
from ... schema import Document, TextDocument, Metadata
from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
import logging
logger = logging.getLogger(__name__)
# Component identification for provenance
COMPONENT_NAME = "mistral-ocr-decoder"
COMPONENT_VERSION = "1.0.0"
# Processor identity used when no "id" parameter is supplied
default_ident = "document-decoder"
# Default Mistral API key, read from the MISTRAL_TOKEN environment variable
# at import time (None when the variable is not set)
default_api_key = os.getenv("MISTRAL_TOKEN")
# Number of PDF pages bundled into each upload sent for OCR
pages_per_chunk = 5
def chunks(lst, n):
    """
    Lazily split a sliceable sequence into consecutive pieces.

    Args:
        lst: Sequence supporting len() and slicing
        n: Maximum number of items per piece

    Yields:
        Slices of lst holding n items each; the final slice may be shorter
    """
    yield from (
        lst[start:start + n]
        for start in range(0, len(lst), n)
    )
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """
    Substitute image placeholders in markdown with their image data.

    A placeholder has the form ``![<id>](<id>)``; every occurrence is
    rewritten to ``![<id>](<data>)`` using the data mapped to that ID.
    Replacements are applied one ID at a time, in dictionary order.

    Args:
        markdown_str: Markdown text containing image placeholders
        images_dict: Dictionary mapping image IDs to base64 strings

    Returns:
        Markdown text with images replaced by base64 data
    """
    rendered = markdown_str

    for image_id, image_data in images_dict.items():
        placeholder = f"![{image_id}]({image_id})"
        embedded = f"![{image_id}]({image_data})"
        rendered = rendered.replace(placeholder, embedded)

    return rendered
class Processor(FlowProcessor):
    """
    Flow processor that decodes PDF documents into per-page markdown using
    the Mistral OCR API.

    Consumes Document messages on "input".  For every page returned by OCR
    it saves the markdown as a child document through the librarian, emits
    provenance triples on "triples", and forwards a TextDocument that
    references the saved page on "output".
    """

    def __init__(self, **params):
        """
        Register the flow specifications and create the Mistral client.

        Args:
            **params: Processor parameters, forwarded to the base class.
                "id" falls back to default_ident and "api_key" falls back
                to default_api_key when absent.

        Raises:
            RuntimeError: If no Mistral API key is available.
        """

        # Named `ident` rather than `id` so the builtin is not shadowed
        ident = params.get("id", default_ident)
        api_key = params.get("api_key", default_api_key)

        super(Processor, self).__init__(
            **params | {
                "id": ident,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=Document,
                handler=self.on_message,
            )
        )

        self.register_specification(
            ProducerSpec(
                name="output",
                schema=TextDocument,
            )
        )

        self.register_specification(
            ProducerSpec(
                name="triples",
                schema=Triples,
            )
        )

        self.register_specification(
            LibrarianSpec()
        )

        if api_key is None:
            raise RuntimeError("Mistral API key not specified")

        self.mistral = Mistral(api_key=api_key)

        # Used with Mistral doc upload
        self.unique_id = str(uuid.uuid4())

        logger.info("Mistral OCR processor initialized")

    def _discard_upload(self, file_id):
        """
        Best-effort removal of an uploaded chunk from Mistral file storage.

        A failure here is logged and otherwise ignored so that it can never
        mask the OCR result (or the OCR error) of the chunk itself.

        Args:
            file_id: ID of the uploaded file to delete
        """
        try:
            self.mistral.files.delete(file_id=file_id)
        except Exception as e:
            logger.warning(
                f"Could not delete uploaded OCR chunk {file_id}: "
                f"{type(e).__name__}: {e}"
            )

    def ocr(self, blob):
        """
        Run Mistral OCR on a PDF blob, returning per-page markdown strings.

        The PDF is split into chunks of pages_per_chunk pages.  Each chunk
        is uploaded, OCR'd through a signed URL, and then deleted from
        Mistral file storage again.  This method is synchronous and blocks
        on network I/O; async callers should run it in a worker thread.

        Args:
            blob: Raw PDF bytes

        Returns:
            List of (page_markdown, page_number) tuples, 1-indexed.  Page
            numbers count the pages returned by OCR, in order.
        """

        logger.debug("Parse PDF...")

        pdf = PdfReader(BytesIO(blob))

        pages = []
        global_page_num = 0

        for chunk in chunks(pdf.pages, pages_per_chunk):

            logger.debug("Get next pages...")

            # Build a standalone PDF holding just this chunk's pages
            part = PdfWriter()
            for page in chunk:
                part.add_page(page)

            buf = BytesIO()
            part.write_stream(buf)

            logger.debug("Upload chunk...")

            uploaded_file = self.mistral.files.upload(
                file={
                    "file_name": self.unique_id,
                    "content": buf.getvalue(),
                },
                purpose="ocr",
            )

            # The upload is only needed for the duration of the OCR call;
            # always remove it afterwards so chunks do not pile up remotely.
            try:

                signed_url = self.mistral.files.get_signed_url(
                    file_id=uploaded_file.id, expiry=1
                )

                logger.debug("OCR...")

                processed = self.mistral.ocr.process(
                    model="mistral-ocr-latest",
                    include_image_base64=True,
                    document={
                        "type": "document_url",
                        "document_url": signed_url.url,
                    }
                )

            finally:
                self._discard_upload(uploaded_file.id)

            logger.debug("Extract markdown...")

            for page in processed.pages:
                global_page_num += 1

                image_data = {
                    img.id: img.image_base64
                    for img in page.images
                }

                markdown = replace_images_in_markdown(
                    page.markdown, image_data
                )

                pages.append((markdown, global_page_num))

        logger.info(f"OCR complete, {len(pages)} pages.")

        return pages

    async def on_message(self, msg, consumer, flow):
        """
        Handle one incoming Document message.

        Obtains the PDF bytes (from the librarian when the message carries
        a document_id, otherwise from the inline base64 data), runs OCR,
        and for each page: saves it as a child document, emits provenance
        triples, and forwards a TextDocument referencing the page.

        Documents whose librarian metadata reports a kind other than
        application/pdf are logged and ignored, as are documents for which
        OCR fails.

        Args:
            msg: Incoming message; msg.value() yields the Document
            consumer: Consumer that delivered the message (unused here)
            flow: Flow context giving access to the librarian and to the
                "triples" and "output" producers
        """

        # Function-scope import keeps this block self-contained
        import asyncio

        logger.debug("PDF message received")

        v = msg.value()

        logger.info(f"Decoding {v.metadata.id}...")

        if v.document_id:

            # Check MIME type if fetching from librarian
            doc_meta = await flow.librarian.fetch_document_metadata(
                document_id=v.document_id,
            )

            if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                logger.error(
                    f"Unsupported MIME type: {doc_meta.kind}. "
                    f"Mistral OCR decoder only handles application/pdf. "
                    f"Ignoring document {v.metadata.id}."
                )
                return

            # Get PDF content from the librarian
            logger.info(f"Fetching document {v.document_id} from librarian...")

            content = await flow.librarian.fetch_document_content(
                document_id=v.document_id,
            )

            if isinstance(content, str):
                content = content.encode('utf-8')

            blob = base64.b64decode(content)

            logger.info(f"Fetched {len(blob)} bytes from librarian")

        else:

            # Inline document data
            blob = base64.b64decode(v.data)

        # Get the source document ID
        source_doc_id = v.document_id or v.metadata.id

        # Run OCR, get per-page markdown.  ocr() blocks on PDF parsing and
        # several HTTP round-trips per chunk, so it runs in a worker thread
        # to keep the event loop responsive while a large document is
        # being processed.
        try:
            pages = await asyncio.to_thread(self.ocr, blob)
        except Exception as e:
            logger.error(
                f"Failed to decode PDF {source_doc_id}: "
                f"{type(e).__name__}: {e}"
            )
            return

        # Same parent URI for every page, so compute it once
        doc_uri = document_uri(source_doc_id)

        for markdown, page_num in pages:

            logger.debug(f"Processing page {page_num}")

            # Generate unique page ID
            pg_uri = make_page_uri()
            page_doc_id = pg_uri
            page_content = markdown.encode("utf-8")

            # Save page as child document in librarian
            await flow.librarian.save_child_document(
                doc_id=page_doc_id,
                parent_id=source_doc_id,
                content=page_content,
                document_type="page",
                title=f"Page {page_num}",
            )

            # Emit provenance triples
            prov_triples = derived_entity_triples(
                entity_uri=pg_uri,
                parent_uri=doc_uri,
                component_name=COMPONENT_NAME,
                component_version=COMPONENT_VERSION,
                label=f"Page {page_num}",
                page_number=page_num,
            )

            await flow("triples").send(Triples(
                metadata=Metadata(
                    id=pg_uri,
                    root=v.metadata.root,
                    collection=v.metadata.collection,
                ),
                triples=set_graph(prov_triples, GRAPH_SOURCE),
            ))

            # Forward page document ID to chunker
            # Chunker will fetch content from librarian
            r = TextDocument(
                metadata=Metadata(
                    id=pg_uri,
                    root=v.metadata.root,
                    collection=v.metadata.collection,
                ),
                document_id=page_doc_id,
                text=b"",  # Empty, chunker will fetch from librarian
            )

            await flow("output").send(r)

        logger.debug("PDF decoding complete")

    @staticmethod
    def add_args(parser):
        """
        Add command-line arguments for this processor.

        Adds the base FlowProcessor arguments plus -k/--api-key, which
        defaults to default_api_key.

        Args:
            parser: Argument parser to extend
        """

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '-k', '--api-key',
            default=default_api_key,
            help='Mistral API Key'
        )
def run():
    """
    Entry point: launch the Mistral OCR processor.

    Passes the default processor identity and this module's docstring to
    Processor.launch.
    """
    ident = default_ident
    description = __doc__
    Processor.launch(ident, description)