mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Use model in Azure LLM integration
This commit is contained in:
parent
e19ea8667d
commit
ae27eccbed
1 changed file with 16 additions and 5 deletions
|
|
@ -55,11 +55,13 @@ class Processor(LlmService):
|
||||||
self.max_output = max_output
|
self.max_output = max_output
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
|
|
||||||
def build_prompt(self, system, content, temperature=None, stream=False):
|
def build_prompt(self, system, content, temperature=None, stream=False, model=None):
|
||||||
# Use provided temperature or fall back to default
|
# Use provided temperature or fall back to default
|
||||||
effective_temperature = temperature if temperature is not None else self.temperature
|
effective_temperature = temperature if temperature is not None else self.temperature
|
||||||
|
model_name = model or self.default_model
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
"model": model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "system", "content": system
|
"role": "system", "content": system
|
||||||
|
|
@ -100,7 +102,8 @@ class Processor(LlmService):
|
||||||
raise TooManyRequests()
|
raise TooManyRequests()
|
||||||
|
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
raise RuntimeError("LLM failure")
|
logger.error(f"Azure API error: status={resp.status_code}, body={resp.text}")
|
||||||
|
raise RuntimeError(f"LLM failure: HTTP {resp.status_code}")
|
||||||
|
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
|
|
@ -121,7 +124,8 @@ class Processor(LlmService):
|
||||||
prompt = self.build_prompt(
|
prompt = self.build_prompt(
|
||||||
system,
|
system,
|
||||||
prompt,
|
prompt,
|
||||||
effective_temperature
|
effective_temperature,
|
||||||
|
model=model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.call_llm(prompt)
|
response = self.call_llm(prompt)
|
||||||
|
|
@ -174,7 +178,7 @@ class Processor(LlmService):
|
||||||
logger.debug(f"Using temperature: {effective_temperature}")
|
logger.debug(f"Using temperature: {effective_temperature}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = self.build_prompt(system, prompt, effective_temperature, stream=True)
|
body = self.build_prompt(system, prompt, effective_temperature, stream=True, model=model_name)
|
||||||
|
|
||||||
url = self.endpoint
|
url = self.endpoint
|
||||||
api_key = self.token
|
api_key = self.token
|
||||||
|
|
@ -190,7 +194,8 @@ class Processor(LlmService):
|
||||||
raise TooManyRequests()
|
raise TooManyRequests()
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise RuntimeError("LLM failure")
|
logger.error(f"Azure API error: status={response.status_code}, body={response.text}")
|
||||||
|
raise RuntimeError(f"LLM failure: HTTP {response.status_code}")
|
||||||
|
|
||||||
total_input_tokens = 0
|
total_input_tokens = 0
|
||||||
total_output_tokens = 0
|
total_output_tokens = 0
|
||||||
|
|
@ -279,6 +284,12 @@ class Processor(LlmService):
|
||||||
help=f'LLM max output tokens (default: {default_max_output})'
|
help=f'LLM max output tokens (default: {default_max_output})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-m', '--model',
|
||||||
|
default=default_model,
|
||||||
|
help=f'LLM model name (default: {default_model})'
|
||||||
|
)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
Processor.launch(default_ident, __doc__)
|
Processor.launch(default_ident, __doc__)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue