diff --git a/trustgraph-base/trustgraph/base/llm_service.py b/trustgraph-base/trustgraph/base/llm_service.py
index 61e48d2f..9fb31a66 100644
--- a/trustgraph-base/trustgraph/base/llm_service.py
+++ b/trustgraph-base/trustgraph/base/llm_service.py
@@ -80,11 +80,6 @@ class LlmService(FlowProcessor):
 
         try:
 
-            try:
-                logger.debug(f"MODEL IS {flow('model')}")
-            except:
-                logger.debug(f"CAN'T GET MODEL")
-
             request = msg.value()
 
             # Sender-produced ID
@@ -96,8 +91,10 @@ class LlmService(FlowProcessor):
                 flow=f"{flow.name}-{consumer.name}",
             ).time():
 
+                model = flow("model")
+
                 response = await self.generate_content(
-                    request.system, request.prompt
+                    request.system, request.prompt, model
                 )
 
                 await flow("response").send(
diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
index 292a2282..25584780 100755
--- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
@@ -183,13 +183,13 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        # Store default configuration
+        self.default_model = model
         self.temperature = temperature
        self.max_output = max_output
 
-        self.variant = self.determine_variant(self.model)()
-        self.variant.set_temperature(temperature)
-        self.variant.set_max_output(max_output)
+        # Cache for model variants to avoid re-initialization
+        self.model_variants = {}
 
         self.session = boto3.Session(
             aws_access_key_id=aws_access_key_id,
@@ -208,47 +208,66 @@ class Processor(LlmService):
         # FIXME: Missing, Amazon models, Deepseek
 
         # This set of conditions deals with normal bedrock on-demand usage
-        if self.model.startswith("mistral"):
+        if model.startswith("mistral"):
             return Mistral
-        elif self.model.startswith("meta"):
+        elif model.startswith("meta"):
             return Meta
-        elif self.model.startswith("anthropic"):
+        elif model.startswith("anthropic"):
             return Anthropic
-        elif self.model.startswith("ai21"):
+        elif model.startswith("ai21"):
             return Ai21
-        elif self.model.startswith("cohere"):
+        elif model.startswith("cohere"):
             return Cohere
 
         # The inference profiles
-        if self.model.startswith("us.meta"):
+        if model.startswith("us.meta"):
             return Meta
-        elif self.model.startswith("us.anthropic"):
+        elif model.startswith("us.anthropic"):
             return Anthropic
-        elif self.model.startswith("eu.meta"):
+        elif model.startswith("eu.meta"):
             return Meta
-        elif self.model.startswith("eu.anthropic"):
+        elif model.startswith("eu.anthropic"):
             return Anthropic
 
         return Default
 
-    async def generate_content(self, system, prompt):
+    def _get_or_create_variant(self, model_name):
+        """Get cached model variant or create new one"""
+        if model_name not in self.model_variants:
+            logger.info(f"Creating model variant for '{model_name}'")
+            variant_class = self.determine_variant(model_name)
+            variant = variant_class()
+            variant.set_temperature(self.temperature)
+            variant.set_max_output(self.max_output)
+            self.model_variants[model_name] = variant
+
+        return self.model_variants[model_name]
+
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         try:
 
+            # Get the appropriate variant for this model
+            variant = self._get_or_create_variant(model_name)
 
-            promptbody = self.variant.encode_request(system, prompt)
+            promptbody = variant.encode_request(system, prompt)
 
             accept = 'application/json'
             contentType = 'application/json'
 
             response = self.bedrock.invoke_model(
                 body=promptbody,
-                modelId=self.model,
+                modelId=model_name,
                 accept=accept,
                 contentType=contentType
             )
 
             # Response structure decode
-            outputtext = self.variant.decode_response(response)
+            outputtext = variant.decode_response(response)
 
             metadata = response['ResponseMetadata']['HTTPHeaders']
             inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
@@ -262,7 +281,7 @@ class Processor(LlmService):
                 text = outputtext,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
index 388ac7c1..af3dfcaf 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
@@ -32,7 +32,7 @@ class Processor(LlmService):
         token = params.get("token", default_token)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
-        model = default_model
+        model = params.get("model", default_model)
 
         if endpoint is None:
             raise RuntimeError("Azure endpoint not specified")
@@ -53,7 +53,7 @@ class Processor(LlmService):
         self.token = token
         self.temperature = temperature
         self.max_output = max_output
-        self.model = model
+        self.default_model = model
 
     def build_prompt(self, system, content):
 
@@ -100,7 +100,12 @@ class Processor(LlmService):
 
         return result
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         try:
 
@@ -125,7 +130,7 @@ class Processor(LlmService):
                 text = resp,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
index 11376426..f0588e5a 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
@@ -54,7 +54,7 @@ class Processor(LlmService):
 
         self.temperature = temperature
         self.max_output = max_output
-        self.model = model
+        self.default_model = model
 
         self.openai = AzureOpenAI(
             api_key=token,
@@ -62,14 +62,19 @@ class Processor(LlmService):
             azure_endpoint = endpoint,
         )
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
         try:
 
             resp = self.openai.chat.completions.create(
-                model=self.model,
+                model=model_name,
                 messages=[
                     {
                         "role": "user",
@@ -97,7 +102,7 @@ class Processor(LlmService):
                 text = resp.choices[0].message.content,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return r
diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
index 87b611f4..b6180038 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
@@ -41,19 +41,24 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.claude = anthropic.Anthropic(api_key=api_key)
         self.temperature = temperature
         self.max_output = max_output
 
         logger.info("Claude LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         try:
 
             response = message = self.claude.messages.create(
-                model=self.model,
+                model=model_name,
                 max_tokens=self.max_output,
                 temperature=self.temperature,
                 system = system,
@@ -81,7 +86,7 @@ class Processor(LlmService):
                 text = resp,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
index df2c1143..a5b1deda 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
@@ -39,18 +39,23 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.temperature = temperature
         self.cohere = cohere.Client(api_key=api_key)
 
         logger.info("Cohere LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         try:
 
-            output = self.cohere.chat(
-                model=self.model,
+            output = self.cohere.chat(
+                model=model_name,
                 message=prompt,
                 preamble = system,
                 temperature=self.temperature,
@@ -71,7 +76,7 @@ class Processor(LlmService):
                 text = resp,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
index 6170490a..c1814129 100644
--- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
@@ -53,10 +53,13 @@ class Processor(LlmService):
         )
 
         self.client = genai.Client(api_key=api_key)
-        self.model = model
+        self.default_model = model
         self.temperature = temperature
         self.max_output = max_output
 
+        # Cache for generation configs per model
+        self.generation_configs = {}
+
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
 
         self.safety_settings = [
@@ -83,22 +86,36 @@ class Processor(LlmService):
 
         logger.info("GoogleAIStudio LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    def _get_or_create_config(self, model_name):
+        """Get cached generation config or create new one"""
+        if model_name not in self.generation_configs:
+            logger.info(f"Creating generation config for '{model_name}'")
+            self.generation_configs[model_name] = types.GenerateContentConfig(
+                temperature = self.temperature,
+                top_p = 1,
+                top_k = 40,
+                max_output_tokens = self.max_output,
+                response_mime_type = "text/plain",
+                safety_settings = self.safety_settings,
+            )
 
-        generation_config = types.GenerateContentConfig(
-            temperature = self.temperature,
-            top_p = 1,
-            top_k = 40,
-            max_output_tokens = self.max_output,
-            response_mime_type = "text/plain",
-            system_instruction = system,
-            safety_settings = self.safety_settings,
-        )
+        return self.generation_configs[model_name]
+
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
+
+        generation_config = self._get_or_create_config(model_name)
+        # Set system instruction per request (can't be cached)
+        generation_config.system_instruction = system
 
         try:
             response = self.client.models.generate_content(
-                model=self.model,
+                model=model_name,
                 config=generation_config,
                 contents=prompt,
             )
@@ -114,7 +131,7 @@ class Processor(LlmService):
                 text = resp,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
index d769248c..1571e3e7 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
@@ -39,7 +39,7 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.llamafile=llamafile
         self.temperature = temperature
         self.max_output = max_output
@@ -50,14 +50,19 @@ class Processor(LlmService):
 
         logger.info("Llamafile LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
         try:
 
             resp = self.openai.chat.completions.create(
-                model=self.model,
+                model=model_name,
                 messages=[
                     {"role": "user", "content": prompt}
                 ]
@@ -82,7 +87,7 @@ class Processor(LlmService):
                 text = resp.choices[0].message.content,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = "llama.cpp",
+                model = model_name,
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
index 16dcfdda..0ed47517 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
@@ -39,7 +39,7 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.url = url + "v1/"
         self.temperature = temperature
         self.max_output = max_output
@@ -50,7 +50,12 @@ class Processor(LlmService):
 
         logger.info("LMStudio LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
@@ -59,7 +64,7 @@ class Processor(LlmService):
             logger.debug(f"Prompt: {prompt}")
 
             resp = self.openai.chat.completions.create(
-                model=self.model,
+                model=model_name,
                 messages=[
                     {"role": "user", "content": prompt}
                 ]
@@ -86,7 +91,7 @@ class Processor(LlmService):
                 text = resp.choices[0].message.content,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
index 6dfd2656..c4ce26d5 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
@@ -41,21 +41,26 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.temperature = temperature
         self.max_output = max_output
 
         self.mistral = Mistral(api_key=api_key)
 
         logger.info("Mistral LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
         try:
 
             resp = self.mistral.chat.complete(
-                model=self.model,
+                model=model_name,
                 messages=[
                     {
                         "role": "user",
@@ -87,7 +92,7 @@ class Processor(LlmService):
                 text = resp.choices[0].message.content,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
index 97ed7d15..fc19ace3 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
@@ -33,16 +33,21 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.llm = Client(host=ollama)
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
         try:
 
-            response = self.llm.generate(self.model, prompt)
+            response = self.llm.generate(model_name, prompt)
             response_text = response['response']
 
             logger.debug("Sending response...")
@@ -55,7 +60,7 @@ class Processor(LlmService):
                 text = response_text,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
index 8aa8c6b9..79e0c86f 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
@@ -47,7 +47,7 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        self.default_model = model
         self.temperature = temperature
         self.max_output = max_output
 
@@ -58,14 +58,19 @@ class Processor(LlmService):
 
         logger.info("OpenAI LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         prompt = system + "\n\n" + prompt
 
         try:
 
             resp = self.openai.chat.completions.create(
-                model=self.model,
+                model=model_name,
                 messages=[
                     {
                         "role": "user",
@@ -97,7 +102,7 @@ class Processor(LlmService):
                 text = resp.choices[0].message.content,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py b/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py
index 09286405..c8622c85 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/tgi/llm.py
@@ -30,32 +30,40 @@ class Processor(LlmService):
         base_url = params.get("url", default_base_url)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
+        model = params.get("model", "tgi")
 
         super(Processor, self).__init__(
             **params | {
                 "temperature": temperature,
                 "max_output": max_output,
                 "url": base_url,
+                "model": model,
             }
         )
 
         self.base_url = base_url
         self.temperature = temperature
         self.max_output = max_output
+        self.default_model = model
 
         self.session = aiohttp.ClientSession()
 
         logger.info(f"Using TGI service at {base_url}")
         logger.info("TGI LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         headers = {
             "Content-Type": "application/json",
         }
 
         request = {
-            "model": "tgi",
+            "model": model_name,
             "messages": [
                 {
                     "role": "system",
@@ -96,7 +104,7 @@ class Processor(LlmService):
                 text = ans,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = "tgi",
+                model = model_name,
             )
 
             return resp
diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
index f194dc86..71b77d74 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
@@ -45,21 +45,26 @@ class Processor(LlmService):
         self.base_url = base_url
         self.temperature = temperature
         self.max_output = max_output
-        self.model = model
+        self.default_model = model
 
         self.session = aiohttp.ClientSession()
 
         logger.info(f"Using vLLM service at {base_url}")
         logger.info("vLLM LLM service initialized")
 
-    async def generate_content(self, system, prompt):
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         headers = {
             "Content-Type": "application/json",
         }
 
         request = {
-            "model": self.model,
+            "model": model_name,
             "prompt": system + "\n\n" + prompt,
             "max_tokens": self.max_output,
             "temperature": self.temperature,
@@ -91,7 +96,7 @@ class Processor(LlmService):
                 text = ans,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model,
+                model = model_name,
             )
 
             return resp
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index a1ab4717..0ec9aca1 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -18,6 +18,7 @@ Supports both Google's Gemini models and Anthropic's Claude models.
 
 from google.oauth2 import service_account
 import google.auth
+import google.api_core.exceptions
 import vertexai
 
 import logging
@@ -59,8 +60,17 @@ class Processor(LlmService):
 
         super(Processor, self).__init__(**params)
 
-        self.model = model
-        self.is_anthropic = 'claude' in self.model.lower()
+        # Store default model and configuration parameters
+        self.default_model = model
+        self.region = region
+        self.temperature = temperature
+        self.max_output = max_output
+        self.private_key = private_key
+
+        # Model client caches
+        self.model_clients = {}  # Cache for model instances
+        self.generation_configs = {}  # Cache for generation configs (Gemini only)
+        self.anthropic_client = None  # Single Anthropic client (handles multiple models)
 
         # Shared parameters for both model types
         self.api_params = {
@@ -89,71 +99,91 @@ class Processor(LlmService):
                 "Ensure it's set in your environment or service account."
             )
 
-        # Initialize the appropriate client based on the model type
-        if self.is_anthropic:
-            logger.info(f"Initializing Anthropic model '{model}' via AnthropicVertex SDK")
-            # Initialize AnthropicVertex with credentials if provided, otherwise use ADC
-            anthropic_kwargs = {'region': region, 'project_id': project_id}
-            if credentials and private_key:  # Pass credentials only if from a file
-                anthropic_kwargs['credentials'] = credentials
-                logger.debug(f"Using service account credentials for Anthropic model")
-            else:
-                logger.debug(f"Using Application Default Credentials for Anthropic model")
-
-            self.llm = AnthropicVertex(**anthropic_kwargs)
-        else:
-            # For Gemini models, initialize the Vertex AI SDK
-            logger.info(f"Initializing Google model '{model}' via Vertex AI SDK")
-            init_kwargs = {'location': region, 'project': project_id}
-            if credentials and private_key:  # Pass credentials only if from a file
-                init_kwargs['credentials'] = credentials
-
-            vertexai.init(**init_kwargs)
+        # Store credentials and project info for later use
+        self.credentials = credentials
+        self.project_id = project_id
 
-            self.llm = GenerativeModel(model)
+        # Initialize Vertex AI SDK for Gemini models
+        init_kwargs = {'location': region, 'project': project_id}
+        if credentials and private_key:  # Pass credentials only if from a file
+            init_kwargs['credentials'] = credentials
 
-            self.generation_config = GenerationConfig(
-                temperature=temperature,
-                top_p=1.0,
-                top_k=10,
-                candidate_count=1,
-                max_output_tokens=max_output,
-            )
+        vertexai.init(**init_kwargs)
 
-            # Block none doesn't seem to work
-            block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
-            # block_level = HarmBlockThreshold.BLOCK_NONE
-
-            self.safety_settings = [
-                SafetySetting(
-                    category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                    threshold = block_level,
-                ),
-                SafetySetting(
-                    category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                    threshold = block_level,
-                ),
-                SafetySetting(
-                    category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                    threshold = block_level,
-                ),
-                SafetySetting(
-                    category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                    threshold = block_level,
-                ),
-            ]
+        # Pre-initialize Anthropic client if needed (single client handles all Claude models)
+        if 'claude' in self.default_model.lower():
+            self._get_anthropic_client()
 
+        # Safety settings for Gemini models
+        block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
+        self.safety_settings = [
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold = block_level,
+            ),
+        ]
+
         logger.info("VertexAI initialization complete")
 
-    async def generate_content(self, system, prompt):
+    def _get_anthropic_client(self):
+        """Get or create the Anthropic client (single client for all Claude models)"""
+        if self.anthropic_client is None:
+            logger.info(f"Initializing AnthropicVertex client")
+            anthropic_kwargs = {'region': self.region, 'project_id': self.project_id}
+            if self.credentials and self.private_key:  # Pass credentials only if from a file
+                anthropic_kwargs['credentials'] = self.credentials
+                logger.debug(f"Using service account credentials for Anthropic models")
+            else:
+                logger.debug(f"Using Application Default Credentials for Anthropic models")
+
+            self.anthropic_client = AnthropicVertex(**anthropic_kwargs)
+
+        return self.anthropic_client
+
+    def _get_gemini_model(self, model_name):
+        """Get or create a Gemini model instance"""
+        if model_name not in self.model_clients:
+            logger.info(f"Creating GenerativeModel instance for '{model_name}'")
+            self.model_clients[model_name] = GenerativeModel(model_name)
+
+            # Create generation config for this model
+            self.generation_configs[model_name] = GenerationConfig(
+                temperature=self.temperature,
+                top_p=1.0,
+                top_k=10,
+                candidate_count=1,
+                max_output_tokens=self.max_output,
+            )
+
+        return self.model_clients[model_name], self.generation_configs[model_name]
+
+    async def generate_content(self, system, prompt, model=None):
+
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+
+        logger.debug(f"Using model: {model_name}")
 
         try:
-            if self.is_anthropic:
+            if 'claude' in model_name.lower():
                 # Anthropic API uses a dedicated system prompt
-                logger.debug("Sending request to Anthropic model...")
-                response = self.llm.messages.create(
-                    model=self.model,
+                logger.debug(f"Sending request to Anthropic model '{model_name}'...")
+                client = self._get_anthropic_client()
+
+                response = client.messages.create(
+                    model=model_name,
                     system=system,
                     messages=[{"role": "user", "content": prompt}],
                     max_tokens=self.api_params['max_output_tokens'],
@@ -166,15 +196,17 @@ class Processor(LlmService):
                     text=response.content[0].text,
                     in_token=response.usage.input_tokens,
                     out_token=response.usage.output_tokens,
-                    model=self.model
+                    model=model_name
                 )
             else:
                 # Gemini API combines system and user prompts
-                logger.debug("Sending request to Gemini model...")
+                logger.debug(f"Sending request to Gemini model '{model_name}'...")
                 full_prompt = system + "\n\n" + prompt
-                response = self.llm.generate_content(
-                    full_prompt, generation_config = self.generation_config,
+                llm, generation_config = self._get_gemini_model(model_name)
+
+                response = llm.generate_content(
+                    full_prompt, generation_config = generation_config,
                     safety_settings = self.safety_settings,
                 )
 
@@ -182,7 +214,7 @@ class Processor(LlmService):
                     text = response.text,
                     in_token = response.usage_metadata.prompt_token_count,
                     out_token = response.usage_metadata.candidates_token_count,
-                    model = self.model
+                    model = model_name
                 )
 
                 logger.info(f"Input Tokens: {resp.in_token}")
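
Illustrative sketch only, not part of the patch: it mirrors the convention the diff applies to every provider, namely an optional per-request model argument on generate_content() that falls back to the processor's configured default, with per-model clients created once and cached. The class, method, and model names below are hypothetical stand-ins, not TrustGraph APIs.

    import asyncio

    class ExampleProcessor:
        """Hypothetical processor following the per-request model convention in this patch."""

        def __init__(self, default_model="model-a"):
            self.default_model = default_model
            self.clients = {}  # per-model client cache, as in the Bedrock/VertexAI changes

        def _get_or_create_client(self, model_name):
            # Create the client for a model once, then reuse it on later requests
            if model_name not in self.clients:
                self.clients[model_name] = f"client-for-{model_name}"  # stand-in for a real SDK client
            return self.clients[model_name]

        async def generate_content(self, system, prompt, model=None):
            # Use provided model or fall back to default
            model_name = model or self.default_model
            client = self._get_or_create_client(model_name)
            return f"{client}: {system} / {prompt}"

    async def main():
        p = ExampleProcessor(default_model="model-a")
        print(await p.generate_content("Be terse.", "Hello"))                    # uses the default model
        print(await p.generate_content("Be terse.", "Hello", model="model-b"))   # per-request override

    asyncio.run(main())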