Added Mistral OCR client (#326)

- Added Mistral OCR client
- Template updates for pdf-ocr
- Template updates for pdf-ocr-mistral
This commit is contained in:
cybermaggedon 2025-03-22 00:27:20 +00:00 committed by GitHub
parent fe422b2b95
commit 482592b976
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 299 additions and 0 deletions

1
.gitignore vendored
View file

@ -9,6 +9,7 @@ trustgraph-base/trustgraph/base_version.py
trustgraph-bedrock/trustgraph/bedrock_version.py
trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
trustgraph-flow/trustgraph/flow_version.py
trustgraph-ocr/trustgraph/ocr_version.py
trustgraph-parquet/trustgraph/parquet_version.py
trustgraph-vertexai/trustgraph/vertexai_version.py
trustgraph-cli/trustgraph/

View file

@ -37,6 +37,10 @@
"graph-rag": import "components/graph-rag.jsonnet",
"document-rag": import "components/document-rag.jsonnet",
// OCR options
"ocr": import "components/ocr.jsonnet",
"mistral-ocr": import "components/mistral-ocr.jsonnet",
// Librarian - document management
"librarian": import "components/librarian.jsonnet",

View file

@ -0,0 +1,47 @@
local images = import "values/images.jsonnet";
local url = import "values/url.jsonnet";

// Mistral OCR component.  Overrides the "pdf-decoder" entry with a container
// running the `pdf-ocr-mistral` command, which reads its API key from the
// MISTRAL_TOKEN environment variable supplied through the
// "mistral-credentials" secret.
{

    // Returns a copy of this component with a hidden "mistral-<key>" field
    // set to `value`.
    with:: function(key, value)
        self + {
            ["mistral-" + key]:: value,
        },

    "pdf-decoder" +: {

        create:: function(engine)

            // Secret carrying the Mistral API token, exposed to the
            // container as MISTRAL_TOKEN.
            local envSecrets = engine.envSecrets("mistral-credentials")
                .with_env_var("MISTRAL_TOKEN", "mistral-token");

            local container =
                engine.container("mistral-ocr")
                    .with_image(images.trustgraph_flow)
                    .with_command([
                        "pdf-ocr-mistral",
                        "-p",
                        url.pulsar,
                    ])
                    .with_env_var_secrets(envSecrets)
                    .with_limits("0.5", "128M")
                    .with_reservations("0.1", "128M");

            local containerSet = engine.containers(
                "mistral-ocr", [ container ]
            );

            // Port 8080 is published under the name "metrics".
            local service =
                engine.internalService(containerSet)
                    .with_port(8080, 8080, "metrics");

            engine.resources([
                envSecrets,
                containerSet,
                service,
            ])

    },

    // NOTE: the previous revision appended `+ prompts` here, but no
    // `prompts` variable is bound in this file, which Jsonnet rejects.

}

View file

@ -0,0 +1,38 @@
local images = import "values/images.jsonnet";
local url = import "values/url.jsonnet";

// Local OCR component.  Overrides the "pdf-decoder" entry with a container
// running the `pdf-ocr` command from the trustgraph-ocr image.
{

    "pdf-decoder" +: {

        create:: function(engine)

            local container =
                engine.container("pdf-ocr")
                    .with_image(images.trustgraph_ocr)
                    .with_command([
                        "pdf-ocr",
                        "-p",
                        url.pulsar,
                    ])
                    .with_limits("1.0", "512M")
                    .with_reservations("0.1", "512M");

            local containerSet = engine.containers(
                "pdf-ocr", [ container ]
            );

            // Port 8080 is published under the name "metrics".
            local service =
                engine.internalService(containerSet)
                    .with_port(8080, 8080, "metrics");

            // This component defines no secrets, so only the container set
            // and its service are emitted.  (The previous revision listed an
            // undefined `envSecrets` here and appended an undefined
            // `prompts` object, both of which Jsonnet rejects.)
            engine.resources([
                containerSet,
                service,
            ])

    },

}

View file

@ -11,6 +11,7 @@ local version = import "version.jsonnet";
grafana: "docker.io/grafana/grafana:11.1.4",
trustgraph_base: "docker.io/trustgraph/trustgraph-base:" + version,
trustgraph_flow: "docker.io/trustgraph/trustgraph-flow:" + version,
trustgraph_ocr: "docker.io/trustgraph/trustgraph-ocr:" + version,
trustgraph_bedrock: "docker.io/trustgraph/trustgraph-bedrock:" + version,
trustgraph_vertexai: "docker.io/trustgraph/trustgraph-vertexai:" + version,
trustgraph_hf: "docker.io/trustgraph/trustgraph-hf:" + version,

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Launcher script: imports run() from trustgraph.decoding.mistral_ocr and
# invokes it immediately.
from trustgraph.decoding.mistral_ocr import run
run()

View file

@ -59,6 +59,7 @@ setuptools.setup(
"pulsar-client",
"pymilvus",
"pypdf",
"mistralai",
"pyyaml",
"qdrant-client",
"rdflib",
@ -98,6 +99,7 @@ setuptools.setup(
"scripts/object-extract-row",
"scripts/oe-write-milvus",
"scripts/pdf-decoder",
"scripts/pdf-ocr-mistral",
"scripts/prompt-generic",
"scripts/prompt-template",
"scripts/rows-write-cassandra",

View file

@ -0,0 +1,3 @@
from . processor import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Start the processor only when this file is executed as a script, not when
# it is imported.
from .processor import run

if __name__ == "__main__":
    run()

View file

@ -0,0 +1,190 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
"""
from pypdf import PdfWriter, PdfReader
from io import BytesIO
import base64
import uuid
import os
from mistralai import Mistral
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
# Dotted module path with the first and the last component of __name__
# removed; used as the default subscriber name and passed to
# Processor.launch().
module = ".".join(__name__.split(".")[1:-1])
# Defaults for the queues the processor consumes from and produces to.
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
# Read once at import time; None when MISTRAL_TOKEN is not set.
default_api_key = os.getenv("MISTRAL_TOKEN")
# Number of PDF pages grouped into each chunk handed to chunks() by ocr().
pages_per_chunk = 5
def chunks(lst, n):
    """
    Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    Args:
        lst: Any object supporting ``len()`` and slicing
        n: Maximum number of items per slice

    Yields:
        ``lst[start:start + n]`` for start = 0, n, 2n, ...
    """
    yield from (
        lst[start:start + n]
        for start in range(0, len(lst), n)
    )
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """
    Substitute image references in markdown with their supplied data.

    Every occurrence of ``![<id>](<id>)`` is rewritten to
    ``![<id>](<value>)``, where ``<value>`` is the entry stored under
    ``<id>`` in ``images_dict``.

    Args:
        markdown_str: Markdown text containing image placeholders
        images_dict: Dictionary mapping image IDs to base64 strings

    Returns:
        The markdown text after all substitutions
    """
    result = markdown_str

    for image_id in images_dict:
        placeholder = f"![{image_id}]({image_id})"
        embedded = f"![{image_id}]({images_dict[image_id]})"
        result = result.replace(placeholder, embedded)

    return result
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """
    Build one markdown document from all pages of an OCR response.

    For each page, the page's images are collected into an id -> base64
    mapping and substituted into that page's markdown; the per-page results
    are then joined with a blank line between them.

    Args:
        ocr_response: Response from OCR processing containing text and images

    Returns:
        Combined markdown string with embedded images
    """
    rendered_pages: list[str] = []

    for page in ocr_response.pages:

        # Image id -> base64 payload, for this page only
        images_by_id = {img.id: img.image_base64 for img in page.images}

        rendered_pages.append(
            replace_images_in_markdown(page.markdown, images_by_id)
        )

    return "\n\n".join(rendered_pages)
class Processor(ConsumerProducer):
    """
    Consumes Document messages carrying base64-encoded PDFs, runs them
    through the Mistral OCR API and emits the resulting markdown as a
    TextDocument message.
    """

    def __init__(self, **params):
        """
        Configure the queues and schemas and create the Mistral client.

        Recognised keys in ``params`` (all optional): ``input_queue``,
        ``output_queue``, ``subscriber`` and ``api_key``; every key is also
        forwarded to the ConsumerProducer constructor.

        Raises:
            RuntimeError: if no API key was supplied and MISTRAL_TOKEN was
                not set when the module was imported.
        """

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        api_key = params.get("api_key", default_api_key)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": Document,
                "output_schema": TextDocument,
            }
        )

        if api_key is None:
            raise RuntimeError("Mistral API key not specified")

        self.mistral = Mistral(api_key=api_key)

        # Used with Mistral doc upload
        self.unique_id = str(uuid.uuid4())

        print("PDF inited")

    def ocr(self, blob):
        """
        OCR a PDF and return its contents as markdown.

        The PDF is split into groups of ``pages_per_chunk`` pages; each
        group is written out as its own PDF, uploaded to Mistral and OCR'd
        separately.  The markdown of all groups is joined in page order.

        Args:
            blob: Raw PDF bytes

        Returns:
            Markdown for the whole document; an empty string when the PDF
            has no pages.
        """

        print("Parse PDF...", flush=True)

        pdfbuf = BytesIO(blob)
        pdf = PdfReader(pdfbuf)

        # Markdown of every chunk, in document order.  Collecting these
        # (rather than returning a single chunk's markdown) is what keeps
        # documents longer than pages_per_chunk pages from losing text.
        sections = []

        for chunk in chunks(pdf.pages, pages_per_chunk):

            print("Get next pages...", flush=True)

            # Re-package this group of pages as a standalone PDF
            part = PdfWriter()
            for page in chunk:
                part.add_page(page)

            buf = BytesIO()
            part.write_stream(buf)

            print("Upload chunk...", flush=True)

            uploaded_file = self.mistral.files.upload(
                file={
                    "file_name": self.unique_id,
                    "content": buf.getvalue(),
                },
                purpose="ocr",
            )

            # NOTE(review): expiry unit is presumably hours - confirm
            # against the Mistral files API.
            signed_url = self.mistral.files.get_signed_url(
                file_id=uploaded_file.id, expiry=1
            )

            print("OCR...", flush=True)

            processed = self.mistral.ocr.process(
                model="mistral-ocr-latest",
                include_image_base64=True,
                document={
                    "type": "document_url",
                    "document_url": signed_url.url,
                }
            )

            print("Extract markdown...", flush=True)

            sections.append(get_combined_markdown(processed))

        print("OCR complete.", flush=True)

        # Same separator get_combined_markdown uses between pages
        return "\n\n".join(sections)

    async def handle(self, msg):
        """
        Process one incoming message.

        Base64-decodes the message's ``data`` field, OCRs it, and sends a
        TextDocument carrying the original metadata and the UTF-8 encoded
        markdown.
        """

        print("PDF message received")

        v = msg.value()

        print(f"Decoding {v.metadata.id}...", flush=True)

        markdown = self.ocr(base64.b64decode(v.data))

        r = TextDocument(
            metadata=v.metadata,
            text=markdown.encode("utf-8"),
        )

        await self.send(r)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """
        Register the ConsumerProducer arguments plus ``-k/--api-key`` on
        ``parser``.  The API key defaults to the MISTRAL_TOKEN value
        captured at import time.
        """

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-k', '--api-key',
            default=default_api_key,
            help='Mistral API Key'
        )
def run():
    """Launch the processor, passing the module name and module docstring."""
    description = __doc__
    Processor.launch(module, description)