diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index 38e1afa8..d152e16e 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -279,9 +279,9 @@ pub mod open_ai { } impl ChatCompletionStreamResponse { - pub fn new(response: Option, role: Option) -> Self { + pub fn new(response: Option, role: Option, model: Option) -> Self { ChatCompletionStreamResponse { - model: Some(ARCH_FC_MODEL_NAME.to_string()), + model, choices: vec![ChunkChoice { delta: Delta { role, diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index bd5c13ff..80b58610 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,7 +3,7 @@ use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter, + ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall, ToolType, }; use common::common_types::{ @@ -334,8 +334,16 @@ impl StreamContext { let response_str = if self.streaming_response { let chunks = [ - ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())), - ChatCompletionStreamResponse::new(Some(response), None), + ChatCompletionStreamResponse::new( + None, + Some(ASSISTANT_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_owned()), + ), + ChatCompletionStreamResponse::new( + Some(response), + None, + Some(ARCH_FC_MODEL_NAME.to_owned()), + ), ]; let mut response_str = String::new(); @@ -961,37 +969,28 @@ impl StreamContext { let chat_completion_response = serde_json::from_slice::(&body).unwrap(); - let chunk_role_message = ChatCompletionStreamResponse { - model: Some(chat_completion_response.model.clone()), - choices: vec![ChunkChoice { - delta: Delta { - role: Some(USER_ROLE.to_string()), - content: None, - tool_calls: None, - model: None, - tool_call_id: None, - }, - finish_reason: None, - }], - }; + let chunks = [ + ChatCompletionStreamResponse::new( + None, + Some(ASSISTANT_ROLE.to_string()), + Some(chat_completion_response.model.clone()), + ), + ChatCompletionStreamResponse::new( + chat_completion_response.choices[0].message.content.clone(), + None, + Some(chat_completion_response.model.clone()), + ), + ]; - let chat_completion_stream_response = ChatCompletionStreamResponse { - model: Some(chat_completion_response.model), - choices: vec![ChunkChoice { - delta: Delta { - role: None, - content: chat_completion_response.choices[0].message.content.clone(), - tool_calls: None, - model: None, - tool_call_id: None, - }, - finish_reason: None, - }], - }; - let chunk_role = serde_json::to_string(&chunk_role_message).unwrap(); - let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap(); - format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data) - } else { + let mut response_str = String::new(); + for chunk in chunks.iter() { + response_str.push_str("data: "); + response_str.push_str(&serde_json::to_string(&chunk).unwrap()); + response_str.push_str("\n\n"); + } + response_str + } + else { String::from_utf8(body).unwrap() }; diff --git a/e2e_tests/common.py b/e2e_tests/common.py index 5449fa89..cd9e53e3 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -12,14 +12,12 @@ LLM_GATEWAY_ENDPOINT = os.getenv( def get_data_chunks(stream, n=1): chunks = [] for chunk in stream.iter_lines(): - print(chunk) if chunk: chunk = chunk.decode("utf-8") chunk_data_id = chunk[0:6] assert chunk_data_id == "data: " chunk_data = chunk[6:] chunk_data = chunk_data.strip() - # chunk_data = chunk_data.replace("null", "None") chunks.append(chunk_data) if len(chunks) >= n: break diff --git a/e2e_tests/poetry.lock b/e2e_tests/poetry.lock index fe5f5eb4..cdacfd08 100644 --- a/e2e_tests/poetry.lock +++ b/e2e_tests/poetry.lock @@ -464,6 +464,25 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "pytest-sugar" +version = "1.0.0" +description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)." +optional = false +python-versions = "*" +files = [ + {file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"}, + {file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"}, +] + +[package.dependencies] +packaging = ">=21.3" +pytest = ">=6.2.0" +termcolor = ">=2.1.0" + +[package.extras] +dev = ["black", "flake8", "pre-commit"] + [[package]] name = "requests" version = "2.32.3" @@ -526,6 +545,20 @@ files = [ {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, ] +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.9" +files = [ + {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "tomli" version = "2.0.2" @@ -637,4 +670,4 @@ h11 = ">=0.9.0,<1" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "9599d8855914d13d1762c6b30ed9526240baceda274af2b49d64835063fc9b29" +content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851" diff --git a/e2e_tests/pyproject.toml b/e2e_tests/pyproject.toml index 583b10e6..45cfbc25 100644 --- a/e2e_tests/pyproject.toml +++ b/e2e_tests/pyproject.toml @@ -12,6 +12,7 @@ python = "^3.10" pytest = "^7.3.1" requests = "^2.29.0" selenium = "^4.11.2" +pytest-sugar = "^1.0.0" [tool.poetry.dev-dependencies] pytest-cov = "^4.1.0"