Some fixes on model server (#362)

* Some fixes on model server

* Remove prompt_prefilling message

* Fix logging

* Fix poetry issues

* Improve logging and update the support for text truncation

* Fix tests

* Fix tests

* Fix tests

* Fix modelserver tests

* Update modelserver tests
This commit is contained in:
Shuguang Chen 2025-01-10 16:45:36 -08:00 committed by GitHub
parent ebda682b30
commit 88a02dc478
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1090 additions and 1666 deletions

View file

@ -1,25 +1,25 @@
import json
import logging
import os
import time
import logging
import src.commons.utils as utils
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, GuardRequest
from src.core.utils.model_utils import (
ChatMessage,
ChatCompletionResponse,
GuardRequest,
GuardResponse,
)
from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
# --- Logging and OpenTelemetry tracing bootstrap ---
# NOTE(review): this span is a rendered diff with old and new revisions
# interleaved (indentation stripped, hunk markers embedded), so some
# statements appear twice; it is not runnable as-is.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger; rebound further below to the model-server logger.
logger = logging.getLogger(__name__)
resource = Resource.create(
{
@ -31,11 +31,6 @@ resource = Resource.create(
# Install the tracer provider tagged with the service resource above.
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
# DEFAULT_OTLP_HOST = "http://localhost:4317"
# "none" presumably disables span export to a collector — TODO confirm
# against how DEFAULT_OTLP_HOST is consumed by the exporter setup.
DEFAULT_OTLP_HOST = "none"
@ -47,6 +42,16 @@ otlp_exporter = OTLPSpanExporter(
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
# NOTE(review): `logger` is reassigned here, replacing the basicConfig
# logger created above — confirm both configurations are intended.
logger = utils.get_model_server_logger()
# Silence noisy third-party loggers below ERROR level.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
logging.ERROR
)
# NOTE(review): `app = FastAPI()` and the instrumentation call appear twice
# in this span (old vs new side of the diff); only one should survive.
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
@app.get("/healthz")
async def healthz():
    """Liveness probe: report that the server process is up and responsive."""
    health_payload = {"status": "ok"}
    return health_payload
@ -62,73 +67,78 @@ async def models():
# POST /function_calling — run the request through the Arch-Intent handler
# and, when an intent is detected, through the Arch-Function handler,
# attaching latency/hallucination metadata to the response.
# NOTE(review): this span is a rendered diff with old and new revisions of
# the handler interleaved (e.g. both `function_calling_response` and
# `final_response` assignments appear); it is not valid standalone Python.
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
logger.info("[Endpoint: /function_calling]")
# NOTE(review): logs the full request body — confirm this cannot leak
# sensitive user content into server logs.
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
final_response: ChatCompletionResponse = None
error_messages = None
try:
# Time the intent-detection call separately so its latency can be
# reported in the response metadata.
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
if handler_map["Arch-Intent"].detect_intent(intent_response):
# [TODO] measure agreement between intent detection and function calling
# TODO: measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
function_calling_response = await handler_map[
"Arch-Function"
].chat_completion(req)
final_response = await handler_map["Arch-Function"].chat_completion(req)
function_latency = time.perf_counter() - function_start_time
# Latencies are reported as strings, in milliseconds rounded
# to 3 decimal places.
function_calling_response.metadata = {
final_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
"function_latency": str(round(function_latency * 1000, 3)),
"hallucination": str(handler_map["Arch-Function"].hallucination),
"tokens_uncertainty": json.dumps(
handler_map["Arch-Function"].hallu_handler.token_probs_map
),
"prompt_prefilling": str(
handler_map["Arch-Function"].prompt_prefilling
"hallucination": str(
handler_map["Arch-Function"].hallucination_state.hallucination
),
}
return function_calling_response
except ValueError as e:
# NOTE(review): `statuscode` (no underscore) is not a FastAPI
# Response attribute — this assignment is a silent no-op;
# it should be `res.status_code`.
res.statuscode = 503
error_message = "Tool call extraction error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
except StopIteration as e:
# NOTE(review): same `statuscode` typo as above — silent no-op.
res.statuscode = 500
error_message = "Hallucination iterator error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
else:
# No intent detected: return the intent handler's response with
# only the intent latency attached.
return {
"result": "No intent matched",
"intent_latency": round(intent_latency * 1000, 3),
intent_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
}
final_response = intent_response
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
logger.error(f"Error in chat_completion /function_calling: {e}")
res.status_code = 500
return {"error": f"[Arch-Intent] - {e}"}
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
# Any recorded error is logged once and converted into a response whose
# metadata carries the error text.
# NOTE(review): on this path no HTTP error status appears to be set on
# `res` — confirm whether the caller relies on a non-200 status.
if error_messages is not None:
logger.error(error_messages)
final_response = ChatCompletionResponse(metadata={"error": error_messages})
return final_response
# POST /guardrails — run the request through the Arch-Guard handler and
# attach the guard-prediction latency to the response metadata.
# NOTE(review): this span is a rendered diff with old and new revisions
# interleaved (e.g. both `guard_result` and `final_response` assignments
# appear); it is not valid standalone Python.
@app.post("/guardrails")
# NOTE(review): `max_num_words=300` is unused in the visible span — as a
# bare endpoint parameter FastAPI would expose it as a query parameter;
# confirm that is intended.
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
logger.info("[Endpoint: /guardrails] - Gateway")
# NOTE(review): logs the full request body — confirm this cannot leak
# sensitive user content into server logs.
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
final_response: GuardResponse = None
error_messages = None
try:
# Time the guard prediction so its latency can be reported.
guard_start_time = time.perf_counter()
guard_result = handler_map["Arch-Guard"].predict(req)
final_response = handler_map["Arch-Guard"].predict(req)
guard_latency = time.perf_counter() - guard_start_time
return {
"response": guard_result,
# NOTE(review): latency here is a number, while /function_calling
# stringifies its latency metadata — confirm the inconsistency.
final_response.metadata = {
"guard_latency": round(guard_latency * 1000, 3),
}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}
error_messages = f"[Arch-Guard]: {e}"
# Any recorded error is logged once and converted into a response whose
# metadata carries the error text.
# NOTE(review): on this path no HTTP error status appears to be set on
# `res` — confirm whether the caller relies on a non-200 status.
if error_messages is not None:
logger.error(error_messages)
final_response = GuardResponse(metadata={"error": error_messages})
return final_response