diff --git a/model_server/app/main.py b/model_server/app/main.py index 3c529f7a..b2736e7f 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -190,6 +190,40 @@ async def zeroshot(req: ZeroShotRequest, res: Response): "model": req.model, } + +class HallucinationRequest(BaseModel): + prompt: str + parameters: dict + model: str + + +@app.post("/hallucination") +async def hallucination(req: HallucinationRequest, res: Response): + """ + Hallucination API, take input as text and return the prediction of hallucination for each parameter + parameters: dictionary of parameters and values + example {"name": "John", "age": "25"} + prompt: input prompt from the user + """ + if req.model not in zero_shot_models: + raise HTTPException(status_code=400, detail="unknown model: " + req.model) + + classifier = zero_shot_models[req.model] + candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()] + hypothesis_template = "{}" + result = classifier( + req.prompt, candidate_labels=candidate_labels, hypothesis_template=hypothesis_template, multi_label=True + ) + result_score = result['scores'] + result_params = {k[0]: s for k, s in zip(req.parameters.items(), result_score)} + + return { + "params_scores": result_params, + "raw_result": result, + "model": req.model, + } + + @app.post("/v1/chat/completions") async def chat_completion(req: ChatMessage, res: Response): result = await arch_fc_chat_completion(req, res) diff --git a/model_server/openai_params.yaml b/model_server/openai_params.yaml index 6a5f8b2f..ebaa0cb8 100644 --- a/model_server/openai_params.yaml +++ b/model_server/openai_params.yaml @@ -2,5 +2,5 @@ params: temperature: 0.01 top_p : 0.5 top_k: 50 - max_tokens: 512 + max_tokens: 2024 stop_token_ids: [151645, 151643]