mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix some bugs
This commit is contained in:
parent
662a840ac5
commit
eaa99259ad
4 changed files with 20 additions and 15 deletions
|
|
@ -87,7 +87,7 @@ def get_prompt_targets():
|
|||
return None
|
||||
|
||||
|
||||
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
|
||||
def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state):
|
||||
if "history" not in state:
|
||||
state["history"] = []
|
||||
|
||||
|
|
@ -119,19 +119,19 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
|
|||
if STREAM_RESPONSE:
|
||||
response = raw_response.parse()
|
||||
history.append({"role": "assistant", "content": "", "model": ""})
|
||||
messages.append((query, ""))
|
||||
# for gradio UI we don't want to show raw tool calls and messages from developer application
|
||||
# so we're filtering those out
|
||||
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
|
||||
|
||||
messages = [
|
||||
(history_view[i]["content"], history_view[i + 1]["content"])
|
||||
for i in range(0, len(history_view) - 1, 2)
|
||||
]
|
||||
|
||||
for chunk in response:
|
||||
print("chunk: " + str(chunk.to_dict()))
|
||||
if len(chunk.choices) > 0:
|
||||
if chunk.choices[0].delta.role:
|
||||
print("role (hist): " + chunk.choices[0].delta.role)
|
||||
print("role (resp): " + chunk.choices[0].delta.role)
|
||||
if history[-1]["role"] != chunk.choices[0].delta.role:
|
||||
print("creating new history item: " + str(chunk.choices[0]))
|
||||
history.append(
|
||||
{
|
||||
"role": chunk.choices[0].delta.role,
|
||||
|
|
@ -151,11 +151,12 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
|
|||
if chunk.choices[0].delta.tool_calls:
|
||||
history[-1]["tool_calls"] = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if chunk.model and chunk.choices[0].delta.content:
|
||||
messages[-1] = (
|
||||
messages[-1][0],
|
||||
messages[-1][1] + chunk.choices[0].delta.content,
|
||||
)
|
||||
if history[-1]["role"] != "tool":
|
||||
if chunk.model and chunk.choices[0].delta.content:
|
||||
messages[-1] = (
|
||||
messages[-1][0],
|
||||
messages[-1][1] + chunk.choices[0].delta.content,
|
||||
)
|
||||
yield "", messages, state
|
||||
else:
|
||||
log.error(f"raw_response: {raw_response.text}")
|
||||
|
|
|
|||
|
|
@ -900,7 +900,11 @@ impl StreamContext {
|
|||
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
if m.role == TOOL_ROLE || m.content.is_none() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class DefaultTargetRequest(BaseModel):
|
|||
|
||||
@app.post("/default_target")
|
||||
async def default_target(req: DefaultTargetRequest, res: Response):
|
||||
logger.info(f"Received arch_messages: {req.messages}")
|
||||
logger.info(f"Received messages: {req.messages}")
|
||||
resp = {
|
||||
"choices": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -186,8 +186,8 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
start_time = time.perf_counter()
|
||||
classifier = zero_shot_model["pipeline"]
|
||||
|
||||
if "arch_messages" in req.parameters:
|
||||
req.parameters.pop("arch_messages")
|
||||
if "messages" in req.parameters:
|
||||
req.parameters.pop("messages")
|
||||
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue