mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
change nli model (#167)
* change nli model
* Fix bug in hallucination

---------

Co-authored-by: Shuguang Chen <54548843+nehcgs@users.noreply.github.com>
This commit is contained in:
parent
3b7c58698f
commit
f9e3a052fc
3 changed files with 6 additions and 5 deletions
|
|
@@ -1,5 +1,5 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/deberta-base-nli";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
|
|
|
|||
|
|
@@ -30,7 +30,7 @@ def get_embedding_model(
|
|||
|
||||
|
||||
def get_zero_shot_model(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli"),
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
|
||||
):
|
||||
print("Loading Zero-shot Model...")
|
||||
|
||||
|
|
|
|||
|
|
@@ -189,17 +189,18 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
if "arch_messages" in req.parameters:
|
||||
req.parameters.pop("arch_messages")
|
||||
|
||||
candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()]
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
predictions = classifier(
|
||||
req.prompt,
|
||||
candidate_labels=candidate_labels,
|
||||
candidate_labels=list(candidate_labels.keys()),
|
||||
hypothesis_template="{}",
|
||||
multi_label=True,
|
||||
)
|
||||
|
||||
params_scores = {
|
||||
k[0]: s for k, s in zip(req.parameters.items(), predictions["scores"])
|
||||
candidate_labels[label]: score
|
||||
for label, score in zip(predictions["labels"], predictions["scores"])
|
||||
}
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue