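Fix some bugs: rename `conversation`/`arch_messages` to `messages`, skip tool-call messages when forwarding to upstream LLMs, and keep tool output out of the chat UI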

This commit is contained in:
Adil Hafeez 2024-10-28 20:55:58 -07:00
parent 662a840ac5
commit eaa99259ad
4 changed files with 20 additions and 15 deletions

View file

@ -87,7 +87,7 @@ def get_prompt_targets():
return None
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state):
if "history" not in state:
state["history"] = []
@ -119,19 +119,19 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
if STREAM_RESPONSE:
response = raw_response.parse()
history.append({"role": "assistant", "content": "", "model": ""})
messages.append((query, ""))
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
for chunk in response:
print("chunk: " + str(chunk.to_dict()))
if len(chunk.choices) > 0:
if chunk.choices[0].delta.role:
print("role (hist): " + chunk.choices[0].delta.role)
print("role (resp): " + chunk.choices[0].delta.role)
if history[-1]["role"] != chunk.choices[0].delta.role:
print("creating new history item: " + str(chunk.choices[0]))
history.append(
{
"role": chunk.choices[0].delta.role,
@ -151,11 +151,12 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
if chunk.choices[0].delta.tool_calls:
history[-1]["tool_calls"] = chunk.choices[0].delta.tool_calls
if chunk.model and chunk.choices[0].delta.content:
messages[-1] = (
messages[-1][0],
messages[-1][1] + chunk.choices[0].delta.content,
)
if history[-1]["role"] != "tool":
if chunk.model and chunk.choices[0].delta.content:
messages[-1] = (
messages[-1][0],
messages[-1][1] + chunk.choices[0].delta.content,
)
yield "", messages, state
else:
log.error(f"raw_response: {raw_response.text}")

View file

@ -900,7 +900,11 @@ impl StreamContext {
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
if m.role == TOOL_ROLE || m.content.is_none() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());

View file

@ -71,7 +71,7 @@ class DefaultTargetRequest(BaseModel):
@app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response):
logger.info(f"Received arch_messages: {req.messages}")
logger.info(f"Received messages: {req.messages}")
resp = {
"choices": [
{

View file

@ -186,8 +186,8 @@ async def hallucination(req: HallucinationRequest, res: Response):
start_time = time.perf_counter()
classifier = zero_shot_model["pipeline"]
if "arch_messages" in req.parameters:
req.parameters.pop("arch_messages")
if "messages" in req.parameters:
req.parameters.pop("messages")
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}