mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Added Mistral OCR client (#326)
- Added Mistral OCR client - Template updates for pdf-ocr - Template updates for pdf-ocr-mistral
This commit is contained in:
parent
fe422b2b95
commit
482592b976
10 changed files with 299 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -9,6 +9,7 @@ trustgraph-base/trustgraph/base_version.py
|
|||
trustgraph-bedrock/trustgraph/bedrock_version.py
|
||||
trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
|
||||
trustgraph-flow/trustgraph/flow_version.py
|
||||
trustgraph-ocr/trustgraph/ocr_version.py
|
||||
trustgraph-parquet/trustgraph/parquet_version.py
|
||||
trustgraph-vertexai/trustgraph/vertexai_version.py
|
||||
trustgraph-cli/trustgraph/
|
||||
|
|
|
|||
|
|
@ -37,6 +37,10 @@
|
|||
"graph-rag": import "components/graph-rag.jsonnet",
|
||||
"document-rag": import "components/document-rag.jsonnet",
|
||||
|
||||
// OCR options
|
||||
"ocr": import "components/ocr.jsonnet",
|
||||
"mistral-ocr": import "components/mistral-ocr.jsonnet",
|
||||
|
||||
// Librarian - document management
|
||||
"librarian": import "components/librarian.jsonnet",
|
||||
|
||||
|
|
|
|||
47
templates/components/mistral-ocr.jsonnet
Normal file
47
templates/components/mistral-ocr.jsonnet
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
local images = import "values/images.jsonnet";
|
||||
local url = import "values/url.jsonnet";
|
||||
|
||||
{
|
||||
|
||||
with:: function(key, value)
|
||||
self + {
|
||||
["mistral-" + key]:: value,
|
||||
},
|
||||
|
||||
"pdf-decoder" +: {
|
||||
|
||||
create:: function(engine)
|
||||
|
||||
local envSecrets = engine.envSecrets("mistral-credentials")
|
||||
.with_env_var("MISTRAL_TOKEN", "mistral-token");
|
||||
|
||||
local container =
|
||||
engine.container("mistral-ocr")
|
||||
.with_image(images.trustgraph_flow)
|
||||
.with_command([
|
||||
"pdf-ocr-mistral",
|
||||
"-p",
|
||||
url.pulsar,
|
||||
])
|
||||
.with_env_var_secrets(envSecrets)
|
||||
.with_limits("0.5", "128M")
|
||||
.with_reservations("0.1", "128M");
|
||||
|
||||
local containerSet = engine.containers(
|
||||
"mistral-ocr", [ container ]
|
||||
);
|
||||
|
||||
local service =
|
||||
engine.internalService(containerSet)
|
||||
.with_port(8080, 8080, "metrics");
|
||||
|
||||
engine.resources([
|
||||
envSecrets,
|
||||
containerSet,
|
||||
service,
|
||||
])
|
||||
|
||||
},
|
||||
|
||||
} + prompts
|
||||
|
||||
38
templates/components/ocr.jsonnet
Normal file
38
templates/components/ocr.jsonnet
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
local images = import "values/images.jsonnet";
|
||||
local url = import "values/url.jsonnet";
|
||||
|
||||
{
|
||||
|
||||
"pdf-decoder" +: {
|
||||
|
||||
create:: function(engine)
|
||||
|
||||
local container =
|
||||
engine.container("pdf-ocr")
|
||||
.with_image(images.trustgraph_ocr)
|
||||
.with_command([
|
||||
"pdf-ocr",
|
||||
"-p",
|
||||
url.pulsar,
|
||||
])
|
||||
.with_limits("1.0", "512M")
|
||||
.with_reservations("0.1", "512M");
|
||||
|
||||
local containerSet = engine.containers(
|
||||
"pdf-ocr", [ container ]
|
||||
);
|
||||
|
||||
local service =
|
||||
engine.internalService(containerSet)
|
||||
.with_port(8080, 8080, "metrics");
|
||||
|
||||
engine.resources([
|
||||
envSecrets,
|
||||
containerSet,
|
||||
service,
|
||||
])
|
||||
|
||||
},
|
||||
|
||||
} + prompts
|
||||
|
||||
|
|
@ -11,6 +11,7 @@ local version = import "version.jsonnet";
|
|||
grafana: "docker.io/grafana/grafana:11.1.4",
|
||||
trustgraph_base: "docker.io/trustgraph/trustgraph-base:" + version,
|
||||
trustgraph_flow: "docker.io/trustgraph/trustgraph-flow:" + version,
|
||||
trustgraph_ocr: "docker.io/trustgraph/trustgraph-ocr:" + version,
|
||||
trustgraph_bedrock: "docker.io/trustgraph/trustgraph-bedrock:" + version,
|
||||
trustgraph_vertexai: "docker.io/trustgraph/trustgraph-vertexai:" + version,
|
||||
trustgraph_hf: "docker.io/trustgraph/trustgraph-hf:" + version,
|
||||
|
|
|
|||
6
trustgraph-flow/scripts/pdf-ocr-mistral
Executable file
6
trustgraph-flow/scripts/pdf-ocr-mistral
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.decoding.mistral_ocr import run
|
||||
|
||||
run()
|
||||
|
||||
|
|
@ -59,6 +59,7 @@ setuptools.setup(
|
|||
"pulsar-client",
|
||||
"pymilvus",
|
||||
"pypdf",
|
||||
"mistralai",
|
||||
"pyyaml",
|
||||
"qdrant-client",
|
||||
"rdflib",
|
||||
|
|
@ -98,6 +99,7 @@ setuptools.setup(
|
|||
"scripts/object-extract-row",
|
||||
"scripts/oe-write-milvus",
|
||||
"scripts/pdf-decoder",
|
||||
"scripts/pdf-ocr-mistral",
|
||||
"scripts/prompt-generic",
|
||||
"scripts/prompt-template",
|
||||
"scripts/rows-write-cassandra",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . processor import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py
Executable file
7
trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
190
trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
Executable file
190
trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
Executable file
|
|
@ -0,0 +1,190 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
"""
|
||||
|
||||
from pypdf import PdfWriter, PdfReader
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from mistralai import Mistral
|
||||
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
|
||||
from mistralai.models import OCRResponse
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
default_api_key = os.getenv("MISTRAL_TOKEN")
|
||||
|
||||
pages_per_chunk = 5
|
||||
|
||||
def chunks(lst, n):
|
||||
"Yield successive n-sized chunks from lst."
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i:i + n]
|
||||
|
||||
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
|
||||
"""
|
||||
Replace image placeholders in markdown with base64-encoded images.
|
||||
|
||||
Args:
|
||||
markdown_str: Markdown text containing image placeholders
|
||||
images_dict: Dictionary mapping image IDs to base64 strings
|
||||
|
||||
Returns:
|
||||
Markdown text with images replaced by base64 data
|
||||
"""
|
||||
for img_name, base64_str in images_dict.items():
|
||||
markdown_str = markdown_str.replace(
|
||||
f"", f""
|
||||
)
|
||||
return markdown_str
|
||||
|
||||
def get_combined_markdown(ocr_response: OCRResponse) -> str:
|
||||
"""
|
||||
Combine OCR text and images into a single markdown document.
|
||||
|
||||
Args:
|
||||
ocr_response: Response from OCR processing containing text and images
|
||||
|
||||
Returns:
|
||||
Combined markdown string with embedded images
|
||||
"""
|
||||
markdowns: list[str] = []
|
||||
# Extract images from page
|
||||
for page in ocr_response.pages:
|
||||
image_data = {}
|
||||
for img in page.images:
|
||||
image_data[img.id] = img.image_base64
|
||||
# Replace image placeholders with actual images
|
||||
markdowns.append(replace_images_in_markdown(page.markdown, image_data))
|
||||
|
||||
return "\n\n".join(markdowns)
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
}
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
raise RuntimeError("Mistral API key not specified")
|
||||
|
||||
self.mistral = Mistral(api_key=api_key)
|
||||
|
||||
# Used with Mistral doc upload
|
||||
self.unique_id = str(uuid.uuid4())
|
||||
|
||||
print("PDF inited")
|
||||
|
||||
def ocr(self, blob):
|
||||
|
||||
print("Parse PDF...", flush=True)
|
||||
|
||||
pdfbuf = BytesIO(blob)
|
||||
pdf = PdfReader(pdfbuf)
|
||||
|
||||
for chunk in chunks(pdf.pages, pages_per_chunk):
|
||||
|
||||
print("Get next pages...", flush=True)
|
||||
|
||||
part = PdfWriter()
|
||||
for page in chunk:
|
||||
part.add_page(page)
|
||||
|
||||
buf = BytesIO()
|
||||
part.write_stream(buf)
|
||||
|
||||
print("Upload chunk...", flush=True)
|
||||
|
||||
uploaded_file = self.mistral.files.upload(
|
||||
file={
|
||||
"file_name": self.unique_id,
|
||||
"content": buf.getvalue(),
|
||||
},
|
||||
purpose="ocr",
|
||||
)
|
||||
|
||||
signed_url = self.mistral.files.get_signed_url(
|
||||
file_id=uploaded_file.id, expiry=1
|
||||
)
|
||||
|
||||
print("OCR...", flush=True)
|
||||
|
||||
processed = self.mistral.ocr.process(
|
||||
model="mistral-ocr-latest",
|
||||
include_image_base64=True,
|
||||
document={
|
||||
"type": "document_url",
|
||||
"document_url": signed_url.url,
|
||||
}
|
||||
)
|
||||
|
||||
print("Extract markdown...", flush=True)
|
||||
|
||||
markdown = get_combined_markdown(processed)
|
||||
|
||||
print("OCR complete.", flush=True)
|
||||
|
||||
return markdown
|
||||
|
||||
async def handle(self, msg):
|
||||
|
||||
print("PDF message received")
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Decoding {v.metadata.id}...", flush=True)
|
||||
|
||||
markdown = self.ocr(base64.b64decode(v.data))
|
||||
|
||||
r = TextDocument(
|
||||
metadata=v.metadata,
|
||||
text=markdown.encode("utf-8"),
|
||||
)
|
||||
|
||||
await self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
default=default_api_key,
|
||||
help=f'Mistral API Key'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue