Merge branch 'main' of https://github.com/katanemo/arch into cotran/hallucination-fix

This commit is contained in:
cotran 2025-02-13 11:20:37 -08:00
commit d0496e7d91
116 changed files with 645 additions and 1092 deletions

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.2.0"
version = "0.2.1"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
license = "Apache 2.0"

View file

@ -15,7 +15,7 @@ logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://api.fc.archgw.com/v1")
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)

View file

@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler):
req.messages, req.tools, self.extra_instruction
)
logger.info(f"[request]: {json.dumps(messages)}")
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
model_response = self.client.chat.completions.create(
messages=messages,
@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler):
"""
logger.info("[Arch-Function] - ChatCompletion")
messages = self._process_messages(req.messages, req.tools)
messages = self._process_messages(
req.messages, req.tools, metadata=req.metadata
)
logger.info(f"[request]: {json.dumps(messages)}")
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(

View file

@ -105,7 +105,7 @@ class ArchGuardHanlder:
raise NotImplementedError(f"{req.task} is not supported!")
logger.info("[Arch-Guard] - Prediction")
logger.info(f"[request]: {req.input}")
logger.info(f"[request arch-guard]: {req.input}")
if len(req.input.split()) < max_num_words:
result = self._predict_text(req.task, req.input)

View file

@ -16,6 +16,7 @@ class Message(BaseModel):
class ChatMessage(BaseModel):
messages: List[Message] = []
tools: List[Dict[str, Any]] = []
metadata: Optional[Dict[str, str]] = {}
class Choice(BaseModel):
@ -123,6 +124,7 @@ class ArchBaseHandler:
tools: List[Dict[str, Any]] = None,
extra_instruction: str = None,
max_tokens=4096,
metadata: Dict[str, str] = {},
):
"""
Processes a list of messages and formats them appropriately.
@ -157,7 +159,12 @@ class ArchBaseHandler:
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif role == "tool":
role = "user"
content = f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
if metadata.get("optimize_context_window", "false").lower() == "true":
content = f"<tool_response>\n\n</tool_response>"
else:
content = (
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})

View file

@ -4,7 +4,7 @@ import time
import logging
import src.commons.utils as utils
from src.commons.globals import handler_map
from src.commons.globals import ARCH_ENDPOINT, handler_map
from src.core.utils.model_utils import (
ChatMessage,
ChatCompletionResponse,
@ -51,6 +51,8 @@ logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
logger.info(f"using archfc endpoint: {ARCH_ENDPOINT}")
@app.get("/healthz")
async def healthz():