diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml index 22d20095..e7448c7e 100644 --- a/demos/function_calling/arch_config.yaml +++ b/demos/function_calling/arch_config.yaml @@ -17,23 +17,23 @@ overrides: llm_providers: - name: gpt-4o-mini - access_key: OPENAI_API_KEY + access_key: $OPENAI_API_KEY provider: openai model: gpt-4o-mini default: true - name: gpt-3.5-turbo-0125 - access_key: OPENAI_API_KEY + access_key: $OPENAI_API_KEY provider: openai model: gpt-3.5-turbo-0125 - name: gpt-4o - access_key: OPENAI_API_KEY + access_key: $OPENAI_API_KEY provider: openai model: gpt-4o - name: ministral-3b - access_key: MISTRAL_API_KEY + access_key: $MISTRAL_API_KEY provider: mistral model: ministral-3b-latest diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 4c55816c..8fc49940 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -25,10 +25,16 @@ class ChatMessage(BaseModel): class Choice(BaseModel): message: Message + finish_reason: Optional[str] = "stop" + index: Optional[int] = 0 class ChatCompletionResponse(BaseModel): choices: List[Choice] + model: Optional[str] = "Arch-Function" + created: Optional[str] = "" + id: Optional[str] = "" + object: Optional[str] = "chat_completion" def process_messages(history: list[Message]): @@ -132,7 +138,9 @@ async def chat_completion(req: ChatMessage, res: Response): else: message = Message(content=full_response, tool_calls=[]) choice = Choice(message=message) - chat_completion_response = ChatCompletionResponse(choices=[choice]) + chat_completion_response = ChatCompletionResponse( + choices=[choice], model=client_model_name + ) logger.info( f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}" diff --git a/model_server/app/main.py b/model_server/app/main.py index a8d312d7..fdf091f0 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -215,5 +215,10 @@ async def hallucination(req: HallucinationRequest, res: Response): @app.post("/v1/chat/completions") async def chat_completion(req: ChatMessage, res: Response): - result = await arch_function_chat_completion(req, res) - return result + try: + result = await arch_function_chat_completion(req, res) + return result + except Exception as e: + logger.error(f"Error in chat_completion: {e}") + res.status_code = 500 + return {"error": "Internal server error"}