Mirror of https://github.com/trustgraph-ai/trustgraph.git
Flow temperature parameter (#533)
* Add temperature parameter to LlmService and roll out to all LLMs
parent aa8e422e8c
commit 6f4f7ce6b4
15 changed files with 164 additions and 72 deletions
@@ -152,29 +152,35 @@ class Processor(LlmService):
         return self.anthropic_client

-    def _get_gemini_model(self, model_name):
+    def _get_gemini_model(self, model_name, temperature=None):
         """Get or create a Gemini model instance"""
         if model_name not in self.model_clients:
             logger.info(f"Creating GenerativeModel instance for '{model_name}'")
             self.model_clients[model_name] = GenerativeModel(model_name)

-            # Create generation config for this model
-            self.generation_configs[model_name] = GenerationConfig(
-                temperature=self.temperature,
-                top_p=1.0,
-                top_k=10,
-                candidate_count=1,
-                max_output_tokens=self.max_output,
-            )
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature

-        return self.model_clients[model_name], self.generation_configs[model_name]
+        # Create generation config with the effective temperature
+        generation_config = GenerationConfig(
+            temperature=effective_temperature,
+            top_p=1.0,
+            top_k=10,
+            candidate_count=1,
+            max_output_tokens=self.max_output,
+        )

-    async def generate_content(self, system, prompt, model=None):
+        return self.model_clients[model_name], generation_config
+
+    async def generate_content(self, system, prompt, model=None, temperature=None):

         # Use provided model or fall back to default
         model_name = model or self.default_model
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature

         logger.debug(f"Using model: {model_name}")
+        logger.debug(f"Using temperature: {effective_temperature}")

         try:
             if 'claude' in model_name.lower():

@@ -187,7 +193,7 @@ class Processor(LlmService):
                 system=system,
                 messages=[{"role": "user", "content": prompt}],
                 max_tokens=self.api_params['max_output_tokens'],
-                temperature=self.api_params['temperature'],
+                temperature=effective_temperature,
                 top_p=self.api_params['top_p'],
                 top_k=self.api_params['top_k'],
             )

@@ -203,7 +209,7 @@ class Processor(LlmService):
             logger.debug(f"Sending request to Gemini model '{model_name}'...")
             full_prompt = system + "\n\n" + prompt

-            llm, generation_config = self._get_gemini_model(model_name)
+            llm, generation_config = self._get_gemini_model(model_name, effective_temperature)

             response = llm.generate_content(
                 full_prompt, generation_config = generation_config,
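Below is a minimal usage sketch of the new per-request temperature override. The caller setup, the system/prompt strings, and the model name are hypothetical; only the `generate_content(system, prompt, model=None, temperature=None)` signature and its fallback behaviour come from the diff above.

```python
# Hypothetical caller sketch; 'processor' is assumed to be an
# initialised Processor(LlmService) instance from this service.

async def demo(processor):
    # No temperature argument: falls back to the service-wide default
    # (self.temperature), preserving the behaviour before this commit.
    default_resp = await processor.generate_content(
        system="You are a concise assistant.",
        prompt="Summarise the flow in one sentence.",
    )

    # Per-request override: applies only to this call. The model name
    # here is a made-up example, not one configured by the commit.
    creative_resp = await processor.generate_content(
        system="You are a concise assistant.",
        prompt="Summarise the flow in one sentence.",
        model="gemini-1.5-flash",
        temperature=0.9,
    )

    return default_resp, creative_resp
```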