Mirror of https://github.com/katanemo/plano.git (synced 2026-04-27 09:46:28 +02:00)
Some fixes on model server (#362)
* Some fixes on model server
* Remove prompt_prefilling message
* Fix logging
* Fix poetry issues
* Improve logging and update the support for text truncation
* Fix tests
* Fix tests
* Fix tests
* Fix modelserver tests
* Update modelserver tests
This commit is contained in:
parent ebda682b30
commit 88a02dc478
25 changed files with 1090 additions and 1666 deletions
```diff
@@ -1,25 +1,25 @@
 import json
+import logging
 import os
 import time
-import logging
+
+import src.commons.utils as utils

 from src.commons.globals import handler_map
-from src.core.model_utils import ChatMessage, GuardRequest
+from src.core.utils.model_utils import (
+    ChatMessage,
+    ChatCompletionResponse,
+    GuardRequest,
+    GuardResponse,
+)

 from fastapi import FastAPI, Response
 from opentelemetry import trace
 from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.sdk.resources import Resource
-from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
-from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-)
-
-logger = logging.getLogger(__name__)

 resource = Resource.create(
     {
```
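The new import block pulls in two response types that the endpoints below now return directly. Their definitions are not part of this diff; a minimal hypothetical sketch of what such Pydantic models could look like, assuming only the `metadata` field the endpoints actually populate:

```python
# Hypothetical sketch only; the real definitions live in
# src/core/utils/model_utils.py and are likely richer than this.
from typing import Optional

from pydantic import BaseModel


class ChatCompletionResponse(BaseModel):
    # Populated below with latency, hallucination, and error details.
    metadata: Optional[dict] = None


class GuardResponse(BaseModel):
    metadata: Optional[dict] = None
```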
```diff
@@ -31,11 +31,6 @@ resource = Resource.create(
 trace.set_tracer_provider(TracerProvider(resource=resource))
 tracer = trace.get_tracer(__name__)

-app = FastAPI()
-
-FastAPIInstrumentor().instrument_app(app)
-
-# DEFAULT_OTLP_HOST = "http://localhost:4317"
 DEFAULT_OTLP_HOST = "none"
```
```diff
@@ -47,6 +42,16 @@ otlp_exporter = OTLPSpanExporter(
 trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))

+logger = utils.get_model_server_logger()
+logging.getLogger("httpx").setLevel(logging.ERROR)
+logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
+    logging.ERROR
+)
+
+app = FastAPI()
+FastAPIInstrumentor().instrument_app(app)
+

 @app.get("/healthz")
 async def healthz():
     return {"status": "ok"}
```
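With `logging.basicConfig` gone from module scope, the logger now comes from a shared helper, and the chatty `httpx` and OTLP exporter loggers are capped at ERROR. The helper's body is not shown in this diff; a plausible minimal implementation of `utils.get_model_server_logger()`, assuming it only standardizes format and level:

```python
import logging


def get_model_server_logger() -> logging.Logger:
    # Assumed implementation: one idempotent setup so every module shares
    # the same handler and format instead of each calling logging.basicConfig.
    logger = logging.getLogger("model_server")
    if not logger.handlers:  # guard against adding duplicate handlers
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```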
```diff
@@ -62,73 +67,78 @@ async def models():

 @app.post("/function_calling")
 async def function_calling(req: ChatMessage, res: Response):
     logger.info("[Endpoint: /function_calling]")
     logger.info(f"[request body]: {json.dumps(req.model_dump())}")

+    final_response: ChatCompletionResponse = None
+    error_messages = None
+
     try:
         intent_start_time = time.perf_counter()
         intent_response = await handler_map["Arch-Intent"].chat_completion(req)
         intent_latency = time.perf_counter() - intent_start_time

         if handler_map["Arch-Intent"].detect_intent(intent_response):
-            # [TODO] measure agreement between intent detection and function calling
+            # TODO: measure agreement between intent detection and function calling
             try:
                 function_start_time = time.perf_counter()
-                function_calling_response = await handler_map[
-                    "Arch-Function"
-                ].chat_completion(req)
+                final_response = await handler_map["Arch-Function"].chat_completion(req)
                 function_latency = time.perf_counter() - function_start_time
-                function_calling_response.metadata = {
+
+                final_response.metadata = {
                     "intent_latency": str(round(intent_latency * 1000, 3)),
                     "function_latency": str(round(function_latency * 1000, 3)),
-                    "hallucination": str(handler_map["Arch-Function"].hallucination),
-                    "tokens_uncertainty": json.dumps(
-                        handler_map["Arch-Function"].hallu_handler.token_probs_map
-                    ),
-                    "prompt_prefilling": str(
-                        handler_map["Arch-Function"].prompt_prefilling
+                    "hallucination": str(
+                        handler_map["Arch-Function"].hallucination_state.hallucination
                     ),
                 }
-
-                return function_calling_response
             except ValueError as e:
-                res.statuscode = 503
-                error_message = "Tool call extraction error"
-                logger.error(f" {error_message}: {e}")
-                return {"error": f"[Arch-Function] - {error_message} - {e}"}
+                error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
             except StopIteration as e:
-                res.statuscode = 500
-                error_message = "Hallucination iterator error"
-                logger.error(f" {error_message}: {e}")
-                return {"error": f"[Arch-Function] - {error_message} - {e}"}
+                error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
             except Exception as e:
-                # [TODO] Review: update how to collect debugging outputs
-                logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
-                res.status_code = 500
-                return {"error": f"[Arch-Function] - {e}"}
-        # [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
+                error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
         else:
-            return {
-                "result": "No intent matched",
-                "intent_latency": round(intent_latency * 1000, 3),
+            intent_response.metadata = {
+                "intent_latency": str(round(intent_latency * 1000, 3)),
             }
+            final_response = intent_response
+
     except Exception as e:
-        # [TODO] Review: update how to collect debugging outputs
-        # logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
-        logger.error(f"Error in chat_completion /function_calling: {e}")
-        res.status_code = 500
-        return {"error": f"[Arch-Intent] - {e}"}
+        error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
+
+    if error_messages is not None:
+        logger.error(error_messages)
+        final_response = ChatCompletionResponse(metadata={"error": error_messages})
+
+    return final_response


 @app.post("/guardrails")
 async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
     logger.info("[Endpoint: /guardrails] - Gateway")
     logger.info(f"[request body]: {json.dumps(req.model_dump())}")

+    final_response: GuardResponse = None
+    error_messages = None
+
     try:
         guard_start_time = time.perf_counter()
-        guard_result = handler_map["Arch-Guard"].predict(req)
+        final_response = handler_map["Arch-Guard"].predict(req)
         guard_latency = time.perf_counter() - guard_start_time
-        return {
-            "response": guard_result,
+
+        final_response.metadata = {
             "guard_latency": round(guard_latency * 1000, 3),
         }
     except Exception as e:
-        # [TODO] Review: update how to collect debugging outputs
-        res.status_code = 500
-        return {"error": f"[Arch-Guard] - {e}"}
+        error_messages = f"[Arch-Guard]: {e}"
+
+    if error_messages is not None:
+        logger.error(error_messages)
+        final_response = GuardResponse(metadata={"error": error_messages})
+
+    return final_response
```
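Both endpoints now follow the same collect-then-exit shape: exceptions no longer trigger early returns with per-branch status codes; each failure instead fills an `error_messages` string, and a single final block logs it and wraps it in the typed response. A minimal sketch of that pattern in isolation (`risky_call` and `handle` are placeholder names, not from the diff):

```python
# Sketch of the single-exit error handling both endpoints use above.
# `risky_call` stands in for the Arch-Intent / Arch-Function / Arch-Guard calls.
import logging

logger = logging.getLogger(__name__)


def handle(risky_call):
    final_response = None
    error_messages = None

    try:
        final_response = risky_call()
    except ValueError as e:
        error_messages = f"Error in tool call extraction: {e}"
    except Exception as e:
        error_messages = f"Error in ChatCompletion: {e}"

    # One exit path: the error is logged once and surfaced in metadata,
    # instead of being returned as an ad-hoc dict with a 5xx status code.
    if error_messages is not None:
        logger.error(error_messages)
        final_response = {"metadata": {"error": error_messages}}

    return final_response
```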
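One consequence for callers of `/function_calling` (and likewise `/guardrails`): a failure now comes back as a well-formed response whose `metadata` carries an `error` key, rather than a 5xx with an ad-hoc `{"error": ...}` body. An illustrative client call; the payload shape and the local address are assumptions, since the real `ChatMessage` schema lives in `src/core/utils/model_utils.py`:

```python
import httpx

# Illustrative payload and address; the actual ChatMessage fields and the
# model server port depend on the deployment.
payload = {"messages": [{"role": "user", "content": "how is the weather in Seattle?"}]}

resp = httpx.post("http://localhost:51000/function_calling", json=payload, timeout=30.0)
metadata = resp.json().get("metadata") or {}

if "error" in metadata:
    # After this change, errors surface here instead of via a 5xx status.
    print("model server error:", metadata["error"])
else:
    print("intent latency (ms):", metadata.get("intent_latency"))
```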