mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more fixes
This commit is contained in:
parent
e35560623c
commit
608ef35af7
5 changed files with 69 additions and 38 deletions
|
|
@ -279,9 +279,9 @@ pub mod open_ai {
|
|||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(response: Option<String>, role: Option<String>) -> Self {
|
||||
pub fn new(response: Option<String>, role: Option<String>, model: Option<String>) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
model,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use crate::hallucination::extract_messages_for_hallucination;
|
|||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
|
|
@ -334,8 +334,16 @@ impl StreamContext {
|
|||
|
||||
let response_str = if self.streaming_response {
|
||||
let chunks = [
|
||||
ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())),
|
||||
ChatCompletionStreamResponse::new(Some(response), None),
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(response),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
),
|
||||
];
|
||||
|
||||
let mut response_str = String::new();
|
||||
|
|
@ -961,37 +969,28 @@ impl StreamContext {
|
|||
let chat_completion_response =
|
||||
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
|
||||
|
||||
let chunk_role_message = ChatCompletionStreamResponse {
|
||||
model: Some(chat_completion_response.model.clone()),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: Some(USER_ROLE.to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let chunks = [
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(chat_completion_response.model.clone()),
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
chat_completion_response.choices[0].message.content.clone(),
|
||||
None,
|
||||
Some(chat_completion_response.model.clone()),
|
||||
),
|
||||
];
|
||||
|
||||
let chat_completion_stream_response = ChatCompletionStreamResponse {
|
||||
model: Some(chat_completion_response.model),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: chat_completion_response.choices[0].message.content.clone(),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let chunk_role = serde_json::to_string(&chunk_role_message).unwrap();
|
||||
let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap();
|
||||
format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data)
|
||||
} else {
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
}
|
||||
else {
|
||||
String::from_utf8(body).unwrap()
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -12,14 +12,12 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
|
|||
def get_data_chunks(stream, n=1):
|
||||
chunks = []
|
||||
for chunk in stream.iter_lines():
|
||||
print(chunk)
|
||||
if chunk:
|
||||
chunk = chunk.decode("utf-8")
|
||||
chunk_data_id = chunk[0:6]
|
||||
assert chunk_data_id == "data: "
|
||||
chunk_data = chunk[6:]
|
||||
chunk_data = chunk_data.strip()
|
||||
# chunk_data = chunk_data.replace("null", "None")
|
||||
chunks.append(chunk_data)
|
||||
if len(chunks) >= n:
|
||||
break
|
||||
|
|
|
|||
35
e2e_tests/poetry.lock
generated
35
e2e_tests/poetry.lock
generated
|
|
@ -464,6 +464,25 @@ pytest = ">=4.6"
|
|||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-sugar"
|
||||
version = "1.0.0"
|
||||
description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"},
|
||||
{file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = ">=21.3"
|
||||
pytest = ">=6.2.0"
|
||||
termcolor = ">=2.1.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "pre-commit"]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
|
|
@ -526,6 +545,20 @@ files = [
|
|||
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.5.0"
|
||||
description = "ANSI color formatting for output in terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
|
||||
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.2"
|
||||
|
|
@ -637,4 +670,4 @@ h11 = ">=0.9.0,<1"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "9599d8855914d13d1762c6b30ed9526240baceda274af2b49d64835063fc9b29"
|
||||
content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851"
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ python = "^3.10"
|
|||
pytest = "^7.3.1"
|
||||
requests = "^2.29.0"
|
||||
selenium = "^4.11.2"
|
||||
pytest-sugar = "^1.0.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest-cov = "^4.1.0"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue