diff --git a/.gitignore b/.gitignore
index 357ecf1e..4d089211 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ trustgraph-base/trustgraph/base_version.py
 trustgraph-bedrock/trustgraph/bedrock_version.py
 trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
 trustgraph-flow/trustgraph/flow_version.py
+trustgraph-ocr/trustgraph/ocr_version.py
 trustgraph-parquet/trustgraph/parquet_version.py
 trustgraph-vertexai/trustgraph/vertexai_version.py
 trustgraph-cli/trustgraph/
diff --git a/templates/components.jsonnet b/templates/components.jsonnet
index 121bd6a5..d0df569f 100644
--- a/templates/components.jsonnet
+++ b/templates/components.jsonnet
@@ -37,6 +37,10 @@
     "graph-rag": import "components/graph-rag.jsonnet",
     "document-rag": import "components/document-rag.jsonnet",
 
+    // OCR options
+    "ocr": import "components/ocr.jsonnet",
+    "mistral-ocr": import "components/mistral-ocr.jsonnet",
+
     // Librarian - document management
     "librarian": import "components/librarian.jsonnet",
 
diff --git a/templates/components/mistral-ocr.jsonnet b/templates/components/mistral-ocr.jsonnet
new file mode 100644
index 00000000..8049c514
--- /dev/null
+++ b/templates/components/mistral-ocr.jsonnet
@@ -0,0 +1,51 @@
+// Mistral OCR component: overrides create() of the "pdf-decoder" component
+// with a container that runs the pdf-ocr-mistral command.
+
+local images = import "values/images.jsonnet";
+local url = import "values/url.jsonnet";
+
+{
+
+    with:: function(key, value)
+        self + {
+            ["mistral-" + key]:: value,
+        },
+
+    "pdf-decoder" +: {
+
+        create:: function(engine)
+
+            // MISTRAL_TOKEN is where pdf-ocr-mistral reads its API key from
+            local envSecrets = engine.envSecrets("mistral-credentials")
+                .with_env_var("MISTRAL_TOKEN", "mistral-token");
+
+            local container =
+                engine.container("mistral-ocr")
+                    .with_image(images.trustgraph_flow)
+                    .with_command([
+                        "pdf-ocr-mistral",
+                        "-p",
+                        url.pulsar,
+                    ])
+                    .with_env_var_secrets(envSecrets)
+                    .with_limits("0.5", "128M")
+                    .with_reservations("0.1", "128M");
+
+            local containerSet = engine.containers(
+                "mistral-ocr", [ container ]
+            );
+
+            local service =
+                engine.internalService(containerSet)
+                .with_port(8080, 8080, "metrics");
+
+            engine.resources([
+                envSecrets,
+                containerSet,
+                service,
+            ])
+
+    },
+
+}
+
diff --git a/templates/components/ocr.jsonnet b/templates/components/ocr.jsonnet
new file mode 100644
index 00000000..4353b7f9
--- /dev/null
+++ b/templates/components/ocr.jsonnet
@@ -0,0 +1,40 @@
+// OCR component: overrides create() of the "pdf-decoder" component with a
+// container that runs the pdf-ocr command.  No secrets are involved here.
+
+local images = import "values/images.jsonnet";
+local url = import "values/url.jsonnet";
+
+{
+
+    "pdf-decoder" +: {
+
+        create:: function(engine)
+
+            local container =
+                engine.container("pdf-ocr")
+                    .with_image(images.trustgraph_ocr)
+                    .with_command([
+                        "pdf-ocr",
+                        "-p",
+                        url.pulsar,
+                    ])
+                    .with_limits("1.0", "512M")
+                    .with_reservations("0.1", "512M");
+
+            local containerSet = engine.containers(
+                "pdf-ocr", [ container ]
+            );
+
+            local service =
+                engine.internalService(containerSet)
+                .with_port(8080, 8080, "metrics");
+
+            engine.resources([
+                containerSet,
+                service,
+            ])
+
+    },
+
+}
+
diff --git a/templates/values/images.jsonnet b/templates/values/images.jsonnet
index dde235ce..54dbd016 100644
--- a/templates/values/images.jsonnet
+++ b/templates/values/images.jsonnet
@@ -11,6 +11,7 @@ local version = import "version.jsonnet";
     grafana: "docker.io/grafana/grafana:11.1.4",
     trustgraph_base: "docker.io/trustgraph/trustgraph-base:" + version,
     trustgraph_flow: "docker.io/trustgraph/trustgraph-flow:" + version,
+    trustgraph_ocr: "docker.io/trustgraph/trustgraph-ocr:" + version,
     trustgraph_bedrock: "docker.io/trustgraph/trustgraph-bedrock:" + version,
     trustgraph_vertexai: "docker.io/trustgraph/trustgraph-vertexai:" + version,
     trustgraph_hf: "docker.io/trustgraph/trustgraph-hf:" + version,
diff --git a/trustgraph-flow/scripts/pdf-ocr-mistral b/trustgraph-flow/scripts/pdf-ocr-mistral
new file mode 100755
index 00000000..fb086767
--- /dev/null
+++ b/trustgraph-flow/scripts/pdf-ocr-mistral
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+
+from trustgraph.decoding.mistral_ocr import run
+
+run()
+
diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py
index 3e8a65fa..4b6179b6 100644
--- a/trustgraph-flow/setup.py
+++ b/trustgraph-flow/setup.py
@@ -59,6 +59,7 @@ setuptools.setup(
         "pulsar-client",
         "pymilvus",
         "pypdf",
+        "mistralai",
         "pyyaml",
         "qdrant-client",
         "rdflib",
@@ -98,6 +99,7 @@ setuptools.setup(
         "scripts/object-extract-row",
         "scripts/oe-write-milvus",
         "scripts/pdf-decoder",
+        "scripts/pdf-ocr-mistral",
         "scripts/prompt-generic",
         "scripts/prompt-template",
         "scripts/rows-write-cassandra",
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py
new file mode 100644
index 00000000..9d16af90
--- /dev/null
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py
@@ -0,0 +1,3 @@
+
+from . processor import *
+
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py
new file mode 100755
index 00000000..986c0257
--- /dev/null
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+
+from . processor import run
+
+if __name__ == '__main__':
+    run()
+
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
new file mode 100755
index 00000000..f5100244
--- /dev/null
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
@@ -0,0 +1,218 @@
+
+"""
+Mistral OCR decoder, accepts PDF documents on input, runs them through the
+Mistral OCR API and outputs the recognised text as a markdown document.
+"""
+
+from pypdf import PdfWriter, PdfReader
+from io import BytesIO
+import base64
+import uuid
+import os
+
+from mistralai import Mistral
+from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
+from mistralai.models import OCRResponse
+
+from ... schema import Document, TextDocument, Metadata
+from ... schema import document_ingest_queue, text_ingest_queue
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
+
+module = ".".join(__name__.split(".")[1:-1])
+
+default_input_queue = document_ingest_queue
+default_output_queue = text_ingest_queue
+default_subscriber = module
+default_api_key = os.getenv("MISTRAL_TOKEN")
+
+# Number of PDF pages sent to Mistral in one OCR request
+pages_per_chunk = 5
+
+def chunks(lst, n):
+    "Yield successive n-sized chunks from lst."
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
+
+def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
+    """
+    Replace image placeholders in markdown with base64-encoded images.
+
+    Args:
+        markdown_str: Markdown text containing image placeholders
+        images_dict: Dictionary mapping image IDs to base64 strings
+
+    Returns:
+        Markdown text with images replaced by base64 data
+    """
+    for img_name, base64_str in images_dict.items():
+        markdown_str = markdown_str.replace(
+            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
+        )
+    return markdown_str
+
+def get_combined_markdown(ocr_response: OCRResponse) -> str:
+    """
+    Combine OCR text and images into a single markdown document.
+
+    Args:
+        ocr_response: Response from OCR processing containing text and images
+
+    Returns:
+        Combined markdown string with embedded images
+    """
+    markdowns: list[str] = []
+    # Extract images from page
+    for page in ocr_response.pages:
+        image_data = {}
+        for img in page.images:
+            image_data[img.id] = img.image_base64
+        # Replace image placeholders with actual images
+        markdowns.append(replace_images_in_markdown(page.markdown, image_data))
+
+    return "\n\n".join(markdowns)
+
+class Processor(ConsumerProducer):
+    """
+    Consumes Document messages holding PDFs, OCRs them with Mistral and
+    produces the resulting markdown as TextDocument messages.
+    """
+
+    def __init__(self, **params):
+
+        input_queue = params.get("input_queue", default_input_queue)
+        output_queue = params.get("output_queue", default_output_queue)
+        subscriber = params.get("subscriber", default_subscriber)
+        api_key = params.get("api_key", default_api_key)
+
+        super(Processor, self).__init__(
+            **params | {
+                "input_queue": input_queue,
+                "output_queue": output_queue,
+                "subscriber": subscriber,
+                "input_schema": Document,
+                "output_schema": TextDocument,
+            }
+        )
+
+        if api_key is None:
+            raise RuntimeError("Mistral API key not specified")
+
+        self.mistral = Mistral(api_key=api_key)
+
+        # Used with Mistral doc upload
+        self.unique_id = str(uuid.uuid4())
+
+        print("PDF inited")
+
+    def ocr(self, blob):
+        """
+        OCR a PDF document with the Mistral OCR API.
+
+        The PDF is split into parts of at most pages_per_chunk pages.  Each
+        part is uploaded and OCR'd separately, and the markdown of all
+        parts is joined in page order.
+
+        Args:
+            blob: Raw bytes of the PDF document
+
+        Returns:
+            Markdown for the whole document, "" if the PDF has no pages
+        """
+
+        print("Parse PDF...", flush=True)
+
+        pdfbuf = BytesIO(blob)
+        pdf = PdfReader(pdfbuf)
+
+        # One markdown string per uploaded part, in page order
+        markdowns = []
+
+        for chunk in chunks(pdf.pages, pages_per_chunk):
+
+            print("Get next pages...", flush=True)
+
+            part = PdfWriter()
+            for page in chunk:
+                part.add_page(page)
+
+            buf = BytesIO()
+            part.write_stream(buf)
+
+            print("Upload chunk...", flush=True)
+
+            uploaded_file = self.mistral.files.upload(
+                file={
+                    "file_name": self.unique_id,
+                    "content": buf.getvalue(),
+                },
+                purpose="ocr",
+            )
+
+            signed_url = self.mistral.files.get_signed_url(
+                file_id=uploaded_file.id, expiry=1
+            )
+
+            print("OCR...", flush=True)
+
+            processed = self.mistral.ocr.process(
+                model="mistral-ocr-latest",
+                include_image_base64=True,
+                document={
+                    "type": "document_url",
+                    "document_url": signed_url.url,
+                }
+            )
+
+            print("Extract markdown...", flush=True)
+
+            markdowns.append(get_combined_markdown(processed))
+
+        print("OCR complete.", flush=True)
+
+        # Same separator get_combined_markdown uses between pages
+        return "\n\n".join(markdowns)
+
+    async def handle(self, msg):
+        """
+        OCR the base64-encoded PDF carried by msg and send the markdown
+        on as a TextDocument with the same metadata.
+        """
+
+        print("PDF message received")
+
+        v = msg.value()
+
+        print(f"Decoding {v.metadata.id}...", flush=True)
+
+        markdown = self.ocr(base64.b64decode(v.data))
+
+        r = TextDocument(
+            metadata=v.metadata,
+            text=markdown.encode("utf-8"),
+        )
+
+        await self.send(r)
+
+        print("Done.", flush=True)
+
+    @staticmethod
+    def add_args(parser):
+        "Register ConsumerProducer's arguments plus the -k/--api-key option."
+
+        ConsumerProducer.add_args(
+            parser, default_input_queue, default_subscriber,
+            default_output_queue,
+        )
+
+        parser.add_argument(
+            '-k', '--api-key',
+            default=default_api_key,
+            help='Mistral API Key'
+        )
+
+def run():
+    "Entry point: hand off to Processor.launch with this module's name."
+
+    Processor.launch(module, __doc__)
+