plano/model_server/src/main.py
2025-04-15 14:39:12 -07:00

158 lines
5 KiB
Python

import json
import os
import time
import logging
import src.commons.utils as utils
from src.commons.globals import ARCH_ENDPOINT, handler_map
from src.core.function_calling import ArchFunctionHandler
from src.core.utils.model_utils import (
ChatMessage,
ChatCompletionResponse,
GuardRequest,
GuardResponse,
)
from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
resource = Resource.create(
    {
        "service.name": "model-server",
    }
)

# Initialize the tracer provider; the FastAPI instrumentation below records
# request spans through it.
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)

# OTLP_HOST selects the OTLP collector endpoint (Jaeger, Zipkin, etc.), for
# example "http://localhost:4317". The default sentinel "none" disables export.
DEFAULT_OTLP_HOST = "none"
otlp_endpoint = os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST)

# Only attach an exporter for a real endpoint. Handing the "none" sentinel to
# OTLPSpanExporter would point it at a host literally named "none", and the
# batch processor would keep trying to ship spans there. When export is
# disabled, `otlp_exporter` stays None and spans are simply not exported.
otlp_exporter = None
if otlp_endpoint.strip().lower() != "none":
    otlp_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
    trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))

logger = utils.get_model_server_logger()

# Keep only ERROR-level output from the HTTP client and the OTLP gRPC exporter
# so transient connectivity problems do not flood the model-server logs.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
    logging.ERROR
)

app = FastAPI()
FastAPIInstrumentor().instrument_app(app)

logger.info(f"using archfc endpoint: {ARCH_ENDPOINT}")
@app.get("/healthz")
async def healthz():
    """Health-check endpoint that unconditionally reports ``{"status": "ok"}``."""
    health = {"status": "ok"}
    return health
@app.get("/models")
async def models():
    """List the registered model handlers.

    Returns:
        A dict ``{"object": "list", "data": [...]}`` holding one
        ``{"id": <handler name>, "object": "model"}`` entry per key of
        ``handler_map``, in the map's iteration order.
    """
    entries = []
    for handler_id in handler_map:
        entries.append({"id": handler_id, "object": "model"})
    return {"object": "list", "data": entries}
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
    """Run a chat completion through the Arch function-calling or agent handler.

    The handler is chosen by ``req.metadata["use_agent_orchestrator"]``:
    ``"Arch-Agent"`` when truthy, otherwise ``"Arch-Function"``. The handler's
    response metadata is enriched with:

    * ``function_latency`` (milliseconds, as a string) when the first choice
      has message content or tool calls;
    * ``intent_latency`` (milliseconds, as a string) when it has neither;
    * ``hallucination`` (stringified ``hallucination_state.hallucination``)
      when the first choice has no content and the orchestrator is not used.

    Args:
        req: Incoming chat request.
        res: Outgoing response; its status code is set when an error occurs.

    Returns:
        The handler's ``ChatCompletionResponse`` with the metadata above.

    Raises:
        Exception: any error from the handler lookup or the completion
            (including ``ValueError`` and ``StopIteration``) is logged with
            the handler name for context and then re-raised.
    """
    logger.info("[Endpoint: /function_calling]")
    logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")

    # Tolerate a request without metadata instead of failing with AttributeError.
    metadata = req.metadata or {}
    use_agent_orchestrator = metadata.get("use_agent_orchestrator", False)
    logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")

    # Resolved before `try` so every error message below can reference it.
    handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"

    # NOTE(review): because each branch re-raises, FastAPI does not send `res`;
    # the status codes below only take effect if the re-raise is removed.
    try:
        model_handler: ArchFunctionHandler = handler_map[handler_name]

        start_time = time.perf_counter()
        final_response: ChatCompletionResponse = await model_handler.chat_completion(
            req
        )
        latency_ms = str(round((time.perf_counter() - start_time) * 1000, 3))

        if not final_response.metadata:
            final_response.metadata = {}

        message = final_response.choices[0].message
        if message.content:
            # Parameter gathering for detected intents
            final_response.metadata["function_latency"] = latency_ms
        elif message.tool_calls:
            # Function calling
            final_response.metadata["function_latency"] = latency_ms
            if not use_agent_orchestrator:
                final_response.metadata["hallucination"] = str(
                    model_handler.hallucination_state.hallucination
                )
        else:
            # No intent detected
            final_response.metadata["intent_latency"] = latency_ms
            if not use_agent_orchestrator:
                final_response.metadata["hallucination"] = str(
                    model_handler.hallucination_state.hallucination
                )
    except ValueError as e:
        res.status_code = 503
        logger.error(f"[{handler_name}] - Error in tool call extraction: {e}")
        raise
    except StopIteration as e:
        res.status_code = 500
        logger.error(f"[{handler_name}] - Error in hallucination check: {e}")
        raise
    except Exception as e:
        res.status_code = 500
        logger.error(f"[{handler_name}] - Error in ChatCompletion: {e}")
        raise

    return final_response
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
    """Run the ``"Arch-Guard"`` handler's ``predict`` on the request.

    On success the handler's response is returned with its ``metadata``
    replaced by ``{"guard_latency": <ms, rounded to 3 decimals>}``. On any
    exception the status code is set to 500, the error is logged, and a
    ``GuardResponse`` whose metadata carries the error text is returned.

    ``max_num_words`` is accepted (FastAPI exposes it as a query parameter)
    but is not used by this handler.
    """
    logger.info("[Endpoint: /guardrails] - Gateway")
    request_body = json.dumps(req.model_dump(exclude_none=True))
    logger.info(f"[request body]: {request_body}")

    try:
        started_at = time.perf_counter()
        guard_result: GuardResponse = handler_map["Arch-Guard"].predict(req)
        elapsed_ms = round((time.perf_counter() - started_at) * 1000, 3)
        guard_result.metadata = {"guard_latency": elapsed_ms}
    except Exception as exc:
        res.status_code = 500
        failure = f"[Arch-Guard]: {exc}"
        logger.error(failure)
        return GuardResponse(metadata={"error": failure})

    return guard_result