Use passed-in model name in chat completion request (#445)

This commit is contained in:
Adil Hafeez 2025-03-21 15:56:17 -07:00 committed by GitHub
parent bd8004d1ae
commit eb48f3d5bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 364 additions and 89 deletions

View file

@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
)
def call_openai(messages: List[Dict[str, str]], stream: bool):
def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
logger.info(f"llm agent model: {model}")
completion = openai_client.chat.completions.create(
model="None", # archgw picks the default LLM configured in the config file
model=model,
messages=messages,
stream=stream,
)
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
class Agent:
def __init__(self, role: str, instructions: str):
def __init__(self, role: str, instructions: str, model: str = ""):
self.model = model
self.system_prompt = f"You are a {role}.\n{instructions}"
def handle(self, req: ChatCompletionsRequest):
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
message.model_dump() for message in req.messages
]
return call_openai(messages, req.stream)
model = req.model
if self.model:
model = self.model
return call_openai(messages, req.stream, model)
def get_system_prompt(self) -> str:
return self.system_prompt
@ -77,13 +83,17 @@ AGENTS = {
"2. Quote ridiculous price\n"
"3. Reveal caveat if user agrees."
),
model="gpt-4o-mini",
),
"issues_and_repairs": Agent(
role="issues and repairs agent",
instructions="Propose a solution, offer refund if necessary.",
model="gpt-4o",
),
"escalate_to_human": Agent(
role="human escalation agent", instructions="Escalate issues to a human."
role="human escalation agent",
instructions="Escalate issues to a human.",
# skipping model name here as arch gateway will pick the default model from the config file
),
"unknown_agent": Agent(
role="general assistant", instructions="Assist the user in general queries."