mirror of
https://github.com/katanemo/plano.git
synced 2026-05-05 22:02:43 +02:00
Use the passed-in model name in the chat completion request (#445)
This commit is contained in:
parent
bd8004d1ae
commit
eb48f3d5bb
20 changed files with 364 additions and 89 deletions
|
|
@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
|
|||
)
|
||||
|
||||
|
||||
def call_openai(messages: List[Dict[str, str]], stream: bool):
|
||||
def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
|
||||
logger.info(f"llm agent model: {model}")
|
||||
completion = openai_client.chat.completions.create(
|
||||
model="None", # archgw picks the default LLM configured in the config file
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
)
|
||||
|
|
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
|
|||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, role: str, instructions: str):
|
||||
def __init__(self, role: str, instructions: str, model: str = ""):
|
||||
self.model = model
|
||||
self.system_prompt = f"You are a {role}.\n{instructions}"
|
||||
|
||||
def handle(self, req: ChatCompletionsRequest):
|
||||
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
|
||||
message.model_dump() for message in req.messages
|
||||
]
|
||||
return call_openai(messages, req.stream)
|
||||
|
||||
model = req.model
|
||||
if self.model:
|
||||
model = self.model
|
||||
return call_openai(messages, req.stream, model)
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
return self.system_prompt
|
||||
|
|
@ -77,13 +83,17 @@ AGENTS = {
|
|||
"2. Quote ridiculous price\n"
|
||||
"3. Reveal caveat if user agrees."
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
"issues_and_repairs": Agent(
|
||||
role="issues and repairs agent",
|
||||
instructions="Propose a solution, offer refund if necessary.",
|
||||
model="gpt-4o",
|
||||
),
|
||||
"escalate_to_human": Agent(
|
||||
role="human escalation agent", instructions="Escalate issues to a human."
|
||||
role="human escalation agent",
|
||||
instructions="Escalate issues to a human.",
|
||||
# skipping model name here as arch gateway will pick the default model from the config file
|
||||
),
|
||||
"unknown_agent": Agent(
|
||||
role="general assistant", instructions="Assist the user in general queries."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue