diff --git a/arch/src/consts.rs b/arch/src/consts.rs index a3e8e428..32172002 100644 --- a/arch/src/consts.rs +++ b/arch/src/consts.rs @@ -1,5 +1,5 @@ pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5"; -pub const DEFAULT_INTENT_MODEL: &str = "katanemo/deberta-base-nli"; +pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli"; pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8; pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; diff --git a/model_server/app/loader.py b/model_server/app/loader.py index 0712fce1..0994bf54 100644 --- a/model_server/app/loader.py +++ b/model_server/app/loader.py @@ -30,7 +30,7 @@ def get_embedding_model( def get_zero_shot_model( - model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli"), + model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"), ): print("Loading Zero-shot Model...") diff --git a/model_server/app/main.py b/model_server/app/main.py index 4d2faafd..82bbeb50 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -189,17 +189,18 @@ async def hallucination(req: HallucinationRequest, res: Response): if "arch_messages" in req.parameters: req.parameters.pop("arch_messages") - candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()] + candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()} predictions = classifier( req.prompt, - candidate_labels=candidate_labels, + candidate_labels=list(candidate_labels.keys()), hypothesis_template="{}", multi_label=True, ) params_scores = { - k[0]: s for k, s in zip(req.parameters.items(), predictions["scores"]) + candidate_labels[label]: score + for label, score in zip(predictions["labels"], predictions["scores"]) } logger.info(