Flow temperature parameter (#533)

* Add temperature parameter to LlmService and roll out to all LLMs
cybermaggedon, 2025-09-25 21:26:11 +01:00 (committed via GitHub)
parent aa8e422e8c
commit 6f4f7ce6b4
15 changed files with 164 additions and 72 deletions


@@ -152,29 +152,35 @@ class Processor(LlmService):
         return self.anthropic_client
 
-    def _get_gemini_model(self, model_name):
+    def _get_gemini_model(self, model_name, temperature=None):
         """Get or create a Gemini model instance"""
         if model_name not in self.model_clients:
             logger.info(f"Creating GenerativeModel instance for '{model_name}'")
             self.model_clients[model_name] = GenerativeModel(model_name)
-            # Create generation config for this model
-            self.generation_configs[model_name] = GenerationConfig(
-                temperature=self.temperature,
-                top_p=1.0,
-                top_k=10,
-                candidate_count=1,
-                max_output_tokens=self.max_output,
-            )
 
-        return self.model_clients[model_name], self.generation_configs[model_name]
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature
+
+        # Create generation config with the effective temperature
+        generation_config = GenerationConfig(
+            temperature=effective_temperature,
+            top_p=1.0,
+            top_k=10,
+            candidate_count=1,
+            max_output_tokens=self.max_output,
+        )
+
+        return self.model_clients[model_name], generation_config
 
-    async def generate_content(self, system, prompt, model=None):
+    async def generate_content(self, system, prompt, model=None, temperature=None):
         # Use provided model or fall back to default
         model_name = model or self.default_model
 
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature
+
         logger.debug(f"Using model: {model_name}")
+        logger.debug(f"Using temperature: {effective_temperature}")
 
         try:
             if 'claude' in model_name.lower():
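
A detail worth noting in the hunk above: the fallback is written as "temperature if temperature is not None else self.temperature" rather than "temperature or self.temperature", because 0.0 is a valid temperature and is falsy in Python. A minimal standalone sketch of the difference (the default value here is made up purely for illustration):

    DEFAULT_TEMPERATURE = 0.7

    def effective(temperature=None):
        # Fall back only when the caller passed nothing at all
        return temperature if temperature is not None else DEFAULT_TEMPERATURE

    def broken(temperature=None):
        # Buggy alternative: "or" treats a legitimate 0.0 as unset
        return temperature or DEFAULT_TEMPERATURE

    assert effective(0.0) == 0.0    # explicit 0.0 is honoured
    assert broken(0.0) == 0.7       # explicit 0.0 is silently overridden

The same hunk also stops caching one GenerationConfig per model and instead builds a fresh config on every call, so a single cached model client can serve requests at different temperatures without one request's config clobbering another's.
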
@@ -187,7 +193,7 @@ class Processor(LlmService):
                     system=system,
                     messages=[{"role": "user", "content": prompt}],
                     max_tokens=self.api_params['max_output_tokens'],
-                    temperature=self.api_params['temperature'],
+                    temperature=effective_temperature,
                     top_p=self.api_params['top_p'],
                     top_k=self.api_params['top_k'],
                 )
@@ -203,7 +209,7 @@ class Processor(LlmService):
 
                 logger.debug(f"Sending request to Gemini model '{model_name}'...")
                 full_prompt = system + "\n\n" + prompt
 
-                llm, generation_config = self._get_gemini_model(model_name)
+                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
                 response = llm.generate_content(
                     full_prompt, generation_config = generation_config,
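
For reference, a hedged usage sketch of the new parameter: generate_content now takes an optional temperature that overrides the service default for a single request. Only the generate_content signature comes from this diff; the processor instance and the model name below are assumptions for illustration:

    # processor: an instance of the Processor class shown above; how it is
    # constructed and configured is outside this diff (assumption).
    response = await processor.generate_content(
        system="You are a terse assistant.",
        prompt="Summarise this document in one line.",
        model="gemini-1.5-flash",   # hypothetical model name, illustration only
        temperature=0.2,            # per-request override; None falls back to the default
    )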