diff --git a/Containerfile b/Containerfile index ee4e9fff..45a152c0 100644 --- a/Containerfile +++ b/Containerfile @@ -15,7 +15,7 @@ RUN pip3 install torch==2.5.1+cpu \ --index-url https://download.pytorch.org/whl/cpu RUN pip3 install \ - anthropic boto3 cohere openai google-cloud-aiplatform \ + anthropic boto3 cohere mistralai openai google-cloud-aiplatform \ ollama google-generativeai \ langchain==0.3.13 langchain-core==0.3.28 langchain-huggingface==0.1.2 \ langchain-text-splitters==0.3.4 \ diff --git a/Makefile b/Makefile index 80cdefc7..0defca58 100644 --- a/Makefile +++ b/Makefile @@ -68,13 +68,13 @@ clean: set-version: echo '"${VERSION}"' > templates/values/version.jsonnet -TEMPLATES=azure bedrock claude cohere mix llamafile ollama openai vertexai \ +TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \ openai-neo4j storage DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml}) -MODELS=azure bedrock claude cohere llamafile ollama openai vertexai -GRAPHS=cassandra neo4j falkordb +MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai +GRAPHS=cassandra neo4j falkordb memgraph # tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet # jsonnet -Jtemplates \ diff --git a/containers/Containerfile.flow b/containers/Containerfile.flow index 8d47effe..352e5ac5 100644 --- a/containers/Containerfile.flow +++ b/containers/Containerfile.flow @@ -12,7 +12,7 @@ RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ python3-rdflib RUN pip3 install --no-cache-dir \ - anthropic cohere openai google-generativeai \ + anthropic cohere mistralai openai google-generativeai \ ollama \ langchain==0.3.13 langchain-core==0.3.28 \ langchain-text-splitters==0.3.4 \ diff --git a/templates/all-patterns.jsonnet b/templates/all-patterns.jsonnet index f68f307d..3282be53 100644 --- a/templates/all-patterns.jsonnet +++ b/templates/all-patterns.jsonnet @@ -13,6 +13,7 @@ import 
"patterns/llm-claude.jsonnet", import "patterns/llm-cohere.jsonnet", import "patterns/llm-llamafile.jsonnet", + import "patterns/llm-mistral.jsonnet", import "patterns/llm-ollama.jsonnet", import "patterns/llm-openai.jsonnet", import "patterns/llm-vertexai.jsonnet", diff --git a/templates/components.jsonnet b/templates/components.jsonnet index 19a52206..ee2ae881 100644 --- a/templates/components.jsonnet +++ b/templates/components.jsonnet @@ -11,6 +11,7 @@ "claude": import "components/claude.jsonnet", "cohere": import "components/cohere.jsonnet", "googleaistudio": import "components/googleaistudio.jsonnet", + "mistral": import "components/mistral.jsonnet", "ollama": import "components/ollama.jsonnet", "openai": import "components/openai.jsonnet", "vertexai": import "components/vertexai.jsonnet", @@ -22,6 +23,7 @@ "claude-rag": import "components/claude-rag.jsonnet", "cohere-rag": import "components/cohere-rag.jsonnet", "googleaistudio-rag": import "components/googleaistudio-rag.jsonnet", + "mistral-rag": import "components/mistral-rag.jsonnet", "ollama-rag": import "components/ollama-rag.jsonnet", "openai-rag": import "components/openai-rag.jsonnet", "vertexai-rag": import "components/vertexai-rag.jsonnet", diff --git a/templates/components/mistral-rag.jsonnet b/templates/components/mistral-rag.jsonnet new file mode 100644 index 00000000..12fbe8a5 --- /dev/null +++ b/templates/components/mistral-rag.jsonnet @@ -0,0 +1,63 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["mistral-rag-" + key]:: value, + }, + + "mistral-rag-max-output-tokens":: 4096, + "mistral-rag-temperature":: 0.0, + "mistral-rag-model":: "ministral-8b-latest", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("mistral-credentials") + 
.with_env_var("MISTRAL_TOKEN", "mistral-token"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-mistral", + "-p", + url.pulsar, + "-x", + std.toString($["mistral-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["mistral-rag-temperature"], + "-m", + $["mistral-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/mistral.jsonnet b/templates/components/mistral.jsonnet new file mode 100644 index 00000000..4de332c9 --- /dev/null +++ b/templates/components/mistral.jsonnet @@ -0,0 +1,59 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["mistral-" + key]:: value, + }, + + "mistral-max-output-tokens":: 4096, + "mistral-temperature":: 0.0, + "mistral-model":: "ministral-8b-latest", + + "text-completion" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("mistral-credentials") + .with_env_var("MISTRAL_TOKEN", "mistral-token"); + + local container = + engine.container("text-completion") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-mistral", + "-p", + url.pulsar, + "-x", + std.toString($["mistral-max-output-tokens"]), + "-t", + "%0.3f" % $["mistral-temperature"], + "-m", + $["mistral-model"], + ]) + .with_env_var_secrets(envSecrets) + 
.with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "text-completion", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + service, + ]) + + }, + +} + prompts + diff --git a/templates/generate-all b/templates/generate-all index 22c9a5b0..fb1fe917 100755 --- a/templates/generate-all +++ b/templates/generate-all @@ -134,7 +134,7 @@ def generate_all(output, version): ]: for model in [ # "azure", "azure-openai", "bedrock", "claude", "cohere", - # "googleaistudio", "llamafile", + # "googleaistudio", "llamafile", "mistral", "ollama", # "openai", "vertexai", ]: diff --git a/templates/patterns/llm-mistral.jsonnet b/templates/patterns/llm-mistral.jsonnet new file mode 100644 index 00000000..11f6de22 --- /dev/null +++ b/templates/patterns/llm-mistral.jsonnet @@ -0,0 +1,32 @@ +{ + pattern: { + name: "mistral", + icon: "🤖💬", + title: "Add Mistral LLM endpoint for text completion", + description: "This pattern integrates a Mistral LLM service for text completion operations. 
You need a Mistral subscription and have an API key to be able to use this service.", + requires: ["pulsar", "trustgraph"], + features: ["llm"], + args: [ + { + name: "mistral-max-output-tokens", + label: "Maximum output tokens", + type: "integer", + description: "Limit on number tokens to generate", + default: 4096, + required: true, + }, + { + name: "mistral-temperature", + label: "Temperature", + type: "slider", + description: "Controlling predictability / creativity balance", + min: 0, + max: 1, + step: 0.05, + default: 0.5, + }, + ], + category: [ "llm" ], + }, + module: "components/mistral.jsonnet", +} diff --git a/trustgraph-flow/scripts/text-completion-mistral b/trustgraph-flow/scripts/text-completion-mistral new file mode 100755 index 00000000..91ef2279 --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-mistral @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.mistral import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index fe167e90..504499f1 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -50,6 +50,7 @@ setuptools.setup( "langchain-core", "langchain-text-splitters", "minio", + "mistralai", "neo4j", "ollama", "openai", @@ -107,6 +108,7 @@ setuptools.setup( "scripts/text-completion-cohere", "scripts/text-completion-googleaistudio", "scripts/text-completion-llamafile", + "scripts/text-completion-mistral", "scripts/text-completion-ollama", "scripts/text-completion-openai", "scripts/triples-query-cassandra", diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py @@ -0,0 +1,3 @@ + +from . 
llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py new file mode 100755 index 00000000..8130cf8a --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py @@ -0,0 +1,201 @@ + +""" +Simple LLM service, performs text prompt completion using Mistral. +Input is prompt, output is response. +""" + +from mistralai import Mistral, RateLimitError +from prometheus_client import Histogram +import os + +from .... schema import TextCompletionRequest, TextCompletionResponse, Error +from .... schema import text_completion_request_queue +from .... schema import text_completion_response_queue +from .... log_level import LogLevel +from .... base import ConsumerProducer +from .... 
exceptions import TooManyRequests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = text_completion_request_queue +default_output_queue = text_completion_response_queue +default_subscriber = module +default_model = 'ministral-8b-latest' +default_temperature = 0.0 +default_max_output = 4096 +default_api_key = os.getenv("MISTRAL_TOKEN") + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) + api_key = params.get("api_key", default_api_key) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + + if api_key is None: + raise RuntimeError("Mistral API key not specified") + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": TextCompletionRequest, + "output_schema": TextCompletionResponse, + "model": model, + "temperature": temperature, + "max_output": max_output, + } + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + self.model = model + self.temperature = temperature + self.max_output = max_output + self.mistral = Mistral(api_key=api_key) + + print("Initialised", flush=True) + + async def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling prompt {id}...", flush=True) + + prompt = v.system + "\n\n" 
+ v.prompt + + try: + + with __class__.text_completion_metric.time(): + + resp = self.mistral.chat.complete( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ], + temperature=self.temperature, + max_tokens=self.max_output, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + response_format={ + "type": "text" + } + ) + + inputtokens = resp.usage.prompt_tokens + outputtokens = resp.usage.completion_tokens + print(resp.choices[0].message.content, flush=True) + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + + print("Send response...", flush=True) + r = TextCompletionResponse( + response=resp.choices[0].message.content, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model + ) + await self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # Catch the RateLimitError name imported from mistralai above. + # NOTE(review): confirm the SDK raises it for rate limits (HTTP 429) + except RateLimitError: + + # Leave rate limit retries to the base handler + raise TooManyRequests() + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {e}", flush=True) + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + await self.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model (default: {default_model})' + ) + + parser.add_argument( + '-k', '--api-key', + default=default_api_key, + help='Mistral API Key' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + 
default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + + Processor.launch(module, __doc__) + +