diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index b9eb8fc7..10c8c603 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -1,109 +1,52 @@
+import os
+from openai import OpenAI
 import gradio as gr
-import asyncio
-import httpx
-import async_timeout
 
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1")
+MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
 
-from loguru import logger
-from typing import Optional, List
-from pydantic import BaseModel
-from dotenv import load_dotenv
+client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
 
-import os
-load_dotenv()
 
+def predict(message, history):
+    history.append({"role": "user", "content": message})
 
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1/chat/completions")
-
-
-class Message(BaseModel):
-
-    role: str
-    content: str
-    # model is additional state we maintin on client side so that bolt gateway can know which model responded to user prompt
-    model: str
-    resolver: str
-
-async def make_completion(messages:List[Message], nb_retries:int=3, delay:int=120) -> Optional[str]:
-    """
-    Sends a request to the ChatGPT API to retrieve a response based on a list of previous messages.
-    """
-    header = {
-        "Content-Type": "application/json",
-    }
-
-    if OPENAI_API_KEY is not None and OPENAI_API_KEY != "":
-        header["Authorization"] = f"Bearer {OPENAI_API_KEY}"
-
-    if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
-        if CHAT_COMPLETION_ENDPOINT.startswith("https://api.openai.com"):
-            logger.error("No OpenAI API Key found. Please create .env file and set OPENAI_API_KEY env var !")
-            return None
     try:
-        async with async_timeout.timeout(delay=delay):
-            async with httpx.AsyncClient(headers=header) as aio_client:
-                counter = 0
-                keep_loop = True
-                while keep_loop:
-                    logger.debug(f"Chat/Completions Nb Retries : {counter}")
-                    try:
-                        resp = await aio_client.post(
-                            url=CHAT_COMPLETION_ENDPOINT,
-                            json={
-                                "model": "gpt-3.5-turbo",
-                                "messages": messages
-                            },
-                            timeout=delay
-                        )
-                        logger.debug(f"Status Code : {resp.status_code}")
-                        if resp.status_code == 200:
-                            resp_json = resp.json()
-                            model = resp_json["model"]
-                            msg = {}
-                            msg["role"] = "assistant"
-                            msg["model"] = model
-                            if "resolver_name" in resp_json:
-                                msg["resolver"] = resp_json["resolver_name"]
-                            if "choices" in resp_json:
-                                msg["content"] = resp_json["choices"][0]["message"]["content"]
-                                return msg
-                            elif "message" in resp_json:
-                                msg["content"] = resp_json["message"]["content"]
-                                return msg
-                            keep_loop = False
-                        else:
-                            logger.warning(resp.content)
-                            keep_loop = False
-                    except Exception as e:
-                        logger.error(e)
-                        counter = counter + 1
-                        keep_loop = counter < nb_retries
-    except asyncio.TimeoutError as e:
-        logger.error(f"Timeout {delay} seconds !")
-    return None
+        response = client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=history,
+            temperature=1.0,
+        )
+    except Exception as e:
+        print(e)
+        # drop the just-added user message so a retry does not duplicate it
+        history.pop()
+        raise gr.Error("Error with OpenAI API: {}".format(e))
+
+    # record the assistant reply, tagging it with the model that answered
+    content = response.choices[0].message.content
+    history.append({"role": "assistant", "content": content})
+    history[-1]["model"] = response.model
 
-async def predict(input, history):
-    """
-    Predict the response of the chatbot and complete a running list of chat history.
-    """
-    history.append({"role": "user", "content": input})
-    response = await make_completion(history)
-    print(response)
-    if response is not None:
-        history.append(response)
     messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
     return messages, history
 
-"""
-Gradio Blocks low-level API that allows to create custom web applications (here our chat app)
-"""
-# with fill_height=true the chatbot to fill the height of the page
+
 with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
-    logger.info("Starting Demo...")
+    print("Starting Demo...")
     chatbot = gr.Chatbot(label="Bolt Chatbot", scale=1)
     state = gr.State([])
     with gr.Row():
-        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1)
+        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
+
     txt.submit(predict, [txt, state], [chatbot, state])
 
-demo.launch(server_name="0.0.0.0", server_port=8080)
+demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
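A note on the pairing step at the end of predict(): gr.Chatbot renders (user, assistant) tuples, while the OpenAI API takes role-tagged dicts, so the role-based history is zipped into pairs before being returned. A standalone sketch of that conversion, with illustrative messages:

    history = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello!", "model": "gpt-3.5-turbo"},
    ]
    # Same comprehension as in predict(): step over the history two entries
    # at a time, pairing each user turn with the assistant turn after it.
    pairs = [
        (history[i]["content"], history[i + 1]["content"])
        for i in range(0, len(history) - 1, 2)
    ]
    assert pairs == [("hi", "hello!")]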
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index 880c1498..6e85180a 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -21,7 +21,7 @@ use public_types::common_types::open_ai::{
     StreamOptions,
 };
 use public_types::common_types::{
-    BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
+    BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
 use public_types::configuration::{Overrides, PromptTarget, PromptType};
@@ -427,9 +427,9 @@ impl StreamContext {
         let body_str = String::from_utf8(body).unwrap();
         debug!("function_resolver response str: {:?}", body_str);
 
-        let mut boltfc_response: BoltFCResponse = serde_json::from_str(&body_str).unwrap();
+        let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
 
-        let boltfc_response_str = boltfc_response.message.content.as_ref().unwrap();
+        let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
 
         let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
             Ok(fc_resp) => fc_resp,
@@ -439,7 +439,6 @@
                 // Let's send the response back to the user to initalize lightweight dialog for parameter collection
                 // add resolver name to the response so the client can send the response back to the correct resolver
-                boltfc_response.resolver_name = Some(callout_context.prompt_target_name.unwrap());
                 info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
                 let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
                 self.send_http_response(
@@ -826,7 +825,7 @@ impl HttpContext for StreamContext {
             }
         };
 
-        self.response_tokens += chat_completions_response.usage.completions_tokens;
+        self.response_tokens += chat_completions_response.usage.completion_tokens;
     }
 
     debug!(
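For context on the two lookups changed above: the function resolver now replies with the standard OpenAI chat/completions shape, so the assistant message sits at choices[0].message rather than in an Ollama-style top-level message field, and the usage counter is spelled completion_tokens. A sketch of the payload the filter now parses; field values are illustrative:

    resp = {
        "model": "test",  # illustrative model name
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                # content carries the serialized BoltFCToolsCall JSON
                "message": {"role": "assistant", "content": "{...}"},
            }
        ],
        "usage": {"completion_tokens": 42},
    }
    content = resp["choices"][0]["message"]["content"]  # what the filter reads
    tokens = resp["usage"]["completion_tokens"]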
- """ - history.append({"role": "user", "content": input}) - response = await make_completion(history) - print(response) - if response is not None: - history.append(response) messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)] return messages, history -""" -Gradio Blocks low-level API that allows to create custom web applications (here our chat app) -""" -# with fill_height=true the chatbot to fill the height of the page + with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo: - logger.info("Starting Demo...") + print("Starting Demo...") chatbot = gr.Chatbot(label="Bolt Chatbot", scale=1) state = gr.State([]) with gr.Row(): - txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1) + txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True) + txt.submit(predict, [txt, state], [chatbot, state]) -demo.launch(server_name="0.0.0.0", server_port=8080) +demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True) diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 880c1498..6e85180a 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -21,7 +21,7 @@ use public_types::common_types::open_ai::{ StreamOptions, }; use public_types::common_types::{ - BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition, + BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; use public_types::configuration::{Overrides, PromptTarget, PromptType}; @@ -427,9 +427,9 @@ impl StreamContext { let body_str = String::from_utf8(body).unwrap(); debug!("function_resolver response str: {:?}", body_str); - let mut boltfc_response: BoltFCResponse = serde_json::from_str(&body_str).unwrap(); + let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap(); - let boltfc_response_str = boltfc_response.message.content.as_ref().unwrap(); + let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap(); let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) { Ok(fc_resp) => fc_resp, @@ -439,7 +439,6 @@ impl StreamContext { // Let's send the response back to the user to initalize lightweight dialog for parameter collection // add resolver name to the response so the client can send the response back to the correct resolver - boltfc_response.resolver_name = Some(callout_context.prompt_target_name.unwrap()); info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e); let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap(); self.send_http_response( @@ -826,7 +825,7 @@ impl HttpContext for StreamContext { } }; - self.response_tokens += chat_completions_response.usage.completions_tokens; + self.response_tokens += chat_completions_response.usage.completion_tokens; } debug!( diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs index 6b3e80cd..afcac01d 100644 --- a/envoyfilter/tests/integration.rs +++ b/envoyfilter/tests/integration.rs @@ -9,7 +9,8 @@ use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; use public_types::common_types::{ - open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail, + 
diff --git a/function_resolver/app/main.py b/function_resolver/app/main.py
index 8296b128..b3625db2 100644
--- a/function_resolver/app/main.py
+++ b/function_resolver/app/main.py
@@ -2,7 +2,7 @@ from fastapi import FastAPI, Response
 from bolt_handler import BoltHandler
 from common import ChatMessage
 import logging
-from ollama import Client
+from openai import OpenAI
 import os
 
 ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
@@ -15,8 +15,12 @@ logger.info(f"using ollama endpoint: {ollama_endpoint}")
 app = FastAPI()
 handler = BoltHandler()
 
-ollama_client = Client(host=ollama_endpoint)
+client = OpenAI(
+    base_url='http://{}:11434/v1/'.format(ollama_endpoint),
+    # required but ignored
+    api_key='ollama',
+)
 
 
 @app.get("/healthz")
 async def healthz():
@@ -35,6 +39,6 @@ async def chat_completion(req: ChatMessage, res: Response):
     )
     messages.append({"role": "user", "content": req.messages[-1].content})
 
-    resp = ollama_client.chat(messages=messages, model=ollama_model, stream=False)
+    resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
     logger.info(f"response: {resp}")
     return resp
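The client swap above leans on Ollama's OpenAI-compatible API, which Ollama serves under /v1 on its usual port 11434; the SDK insists on an api_key, but Ollama ignores its value. A minimal sketch of calling that endpoint directly, with an illustrative model name:

    from openai import OpenAI

    # Point the OpenAI SDK at a local Ollama instance.
    client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")
    resp = client.chat.completions.create(
        model="llama2",  # illustrative; any locally pulled model works
        messages=[{"role": "user", "content": "Say hello."}],
        stream=False,
    )
    print(resp.choices[0].message.content)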
diff --git a/function_resolver/requirements.txt b/function_resolver/requirements.txt
index 240ae61d..fc70ebb5 100644
--- a/function_resolver/requirements.txt
+++ b/function_resolver/requirements.txt
@@ -1,4 +1,4 @@
 fastapi
 uvicorn
-ollama
 PyYAML
+openai
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 05a17fdc..ead35d10 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -61,15 +61,6 @@ pub struct ToolsDefinition {
     pub parameters: ToolParameters,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct BoltFCResponse {
-    pub model: String,
-    pub message: open_ai::Message,
-    pub done_reason: String,
-    pub done: bool,
-    pub resolver_name: Option<String>,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum IntOrString {
@@ -118,24 +109,34 @@ pub mod open_ai {
         #[serde(skip_serializing_if = "Option::is_none")]
         pub model: Option<String>,
     }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct Choice {
+        pub finish_reason: String,
+        pub index: usize,
+        pub message: Message,
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChatCompletionsResponse {
         pub usage: Usage,
+        pub choices: Vec<Choice>,
+        pub model: String,
     }
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct Usage {
-        pub completions_tokens: usize,
+        pub completion_tokens: usize,
     }
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChatCompletionChunkResponse {
         pub model: String,
-        pub choices: Vec<Choice>,
+        pub choices: Vec<ChunkChoice>,
     }
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct Choice {
+    pub struct ChunkChoice {
         pub delta: Delta,
         // TODO: could this be an enum?
         pub finish_reason: Option<String>,
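The ChatCompletionsResponse, Choice, and Usage structs above declare only the subset of the OpenAI chat/completions schema that the filter actually reads; since serde's derived Deserialize ignores unknown JSON fields by default, responses that also carry fields such as id, object, or usage.prompt_tokens still parse. Keeping Choice (full messages) separate from ChunkChoice (streaming deltas) mirrors the API's own distinction between chat.completion and chat.completion.chunk objects.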