change nli model (#167)

* change nli model

* Fix bug in hallucination

---------

Co-authored-by: Shuguang Chen <54548843+nehcgs@users.noreply.github.com>
This commit is contained in:
Co Tran 2024-10-09 19:10:08 -07:00 committed by GitHub
parent 3b7c58698f
commit f9e3a052fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 6 additions and 5 deletions

View file

@ -189,17 +189,18 @@ async def hallucination(req: HallucinationRequest, res: Response):
if "arch_messages" in req.parameters:
req.parameters.pop("arch_messages")
candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()]
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
predictions = classifier(
req.prompt,
candidate_labels=candidate_labels,
candidate_labels=list(candidate_labels.keys()),
hypothesis_template="{}",
multi_label=True,
)
params_scores = {
k[0]: s for k, s in zip(req.parameters.items(), predictions["scores"])
candidate_labels[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}
logger.info(