more fixes

This commit is contained in:
Adil Hafeez 2024-10-27 15:18:20 -07:00
parent e35560623c
commit 608ef35af7
5 changed files with 69 additions and 38 deletions

View file

@ -279,9 +279,9 @@ pub mod open_ai {
}
impl ChatCompletionStreamResponse {
pub fn new(response: Option<String>, role: Option<String>) -> Self {
pub fn new(response: Option<String>, role: Option<String>, model: Option<String>) -> Self {
ChatCompletionStreamResponse {
model: Some(ARCH_FC_MODEL_NAME.to_string()),
model,
choices: vec![ChunkChoice {
delta: Delta {
role,

View file

@ -3,7 +3,7 @@ use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
};
use common::common_types::{
@ -334,8 +334,16 @@ impl StreamContext {
let response_str = if self.streaming_response {
let chunks = [
ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())),
ChatCompletionStreamResponse::new(Some(response), None),
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
),
ChatCompletionStreamResponse::new(
Some(response),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
),
];
let mut response_str = String::new();
@ -961,37 +969,28 @@ impl StreamContext {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
let chunk_role_message = ChatCompletionStreamResponse {
model: Some(chat_completion_response.model.clone()),
choices: vec![ChunkChoice {
delta: Delta {
role: Some(USER_ROLE.to_string()),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
};
let chunks = [
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(chat_completion_response.model.clone()),
),
ChatCompletionStreamResponse::new(
chat_completion_response.choices[0].message.content.clone(),
None,
Some(chat_completion_response.model.clone()),
),
];
let chat_completion_stream_response = ChatCompletionStreamResponse {
model: Some(chat_completion_response.model),
choices: vec![ChunkChoice {
delta: Delta {
role: None,
content: chat_completion_response.choices[0].message.content.clone(),
tool_calls: None,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
};
let chunk_role = serde_json::to_string(&chunk_role_message).unwrap();
let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap();
format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data)
} else {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
else {
String::from_utf8(body).unwrap()
};

View file

@ -12,14 +12,12 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
def get_data_chunks(stream, n=1):
chunks = []
for chunk in stream.iter_lines():
print(chunk)
if chunk:
chunk = chunk.decode("utf-8")
chunk_data_id = chunk[0:6]
assert chunk_data_id == "data: "
chunk_data = chunk[6:]
chunk_data = chunk_data.strip()
# chunk_data = chunk_data.replace("null", "None")
chunks.append(chunk_data)
if len(chunks) >= n:
break

35
e2e_tests/poetry.lock generated
View file

@ -464,6 +464,25 @@ pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
[[package]]
name = "pytest-sugar"
version = "1.0.0"
description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)."
optional = false
python-versions = "*"
files = [
{file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"},
{file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"},
]
[package.dependencies]
packaging = ">=21.3"
pytest = ">=6.2.0"
termcolor = ">=2.1.0"
[package.extras]
dev = ["black", "flake8", "pre-commit"]
[[package]]
name = "requests"
version = "2.32.3"
@ -526,6 +545,20 @@ files = [
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
]
[[package]]
name = "termcolor"
version = "2.5.0"
description = "ANSI color formatting for output in terminal"
optional = false
python-versions = ">=3.9"
files = [
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
]
[package.extras]
tests = ["pytest", "pytest-cov"]
[[package]]
name = "tomli"
version = "2.0.2"
@ -637,4 +670,4 @@ h11 = ">=0.9.0,<1"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "9599d8855914d13d1762c6b30ed9526240baceda274af2b49d64835063fc9b29"
content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851"

View file

@ -12,6 +12,7 @@ python = "^3.10"
pytest = "^7.3.1"
requests = "^2.29.0"
selenium = "^4.11.2"
pytest-sugar = "^1.0.0"
[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"