diff --git a/model_server/app/__init__.py b/model_server/app/__init__.py
index f209ec07..59e9ab48 100644
--- a/model_server/app/__init__.py
+++ b/model_server/app/__init__.py
@@ -10,6 +10,7 @@ import tempfile
 # Path to the file where the server process ID will be stored
 PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
 
+
 def run_server():
     """Start, stop, or restart the Uvicorn server based on command-line arguments."""
     if len(sys.argv) > 1:
@@ -45,10 +46,11 @@ def start_server():
             f.write(str(process.pid))
         print(f"ARCH GW Model Server started with PID {process.pid}")
     else:
-        #Add model_server boot-up logs
+        # Add model_server boot-up logs
         print(f"ARCH GW Model Server - Didn't Start In Time. Shutting Down")
         process.terminate()
 
+
 def wait_for_health_check(url, timeout=180):
     """Wait for the Uvicorn server to respond to health-check requests."""
     start_time = time.time()
@@ -92,6 +94,7 @@ def stop_server():
         process.kill()  # Forcefully kill the process
     os.remove(PID_FILE)
 
+
 def restart_server():
     """Restart the Uvicorn server."""
     print("Check: Is Archgw Model Server running?")
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index e90b6531..ae3ad231 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -50,15 +50,16 @@
 logger.info(f"serving mode: {mode}")
 logger.info(f"using model: {chosen_model}")
 logger.info(f"using endpoint: {endpoint}")
 
+
 def process_state(arch_state, history: list[Message]):
     print("state: {}".format(arch_state))
     state_json = json.loads(arch_state)
 
     state_map = {}
     if state_json:
-      for tools_state in state_json:
-          for tool_state in tools_state:
-              state_map[tool_state['key']] = tool_state
+        for tools_state in state_json:
+            for tool_state in tools_state:
+                state_map[tool_state["key"]] = tool_state
 
     print(f"state_map: {json.dumps(state_map)}")
@@ -66,27 +67,38 @@ def process_state(arch_state, history: list[Message]):
     updated_history = []
     for hist in history:
         updated_history.append({"role": hist.role, "content": hist.content})
-        if hist.role == 'user':
+        if hist.role == "user":
             sha_history.append(hist.content)
             sha256_hash = hashlib.sha256()
-            joined_key_str = ('#.#').join(sha_history)
+            joined_key_str = ("#.#").join(sha_history)
             sha256_hash.update(joined_key_str.encode())
             sha_key = sha256_hash.hexdigest()
             print(f"sha_key: {sha_key}")
 
             if sha_key in state_map:
                 tool_call_state = state_map[sha_key]
-                if 'tool_call' in tool_call_state:
-                    tool_call_str = json.dumps(tool_call_state['tool_call'])
-                    updated_history.append({"role": "assistant", "content": f"<tool_call>\n{tool_call_str}\n</tool_call>"})
-                if 'tool_response' in tool_call_state:
-                    tool_resp = tool_call_state['tool_response']
-                    #TODO: try with role = user as well
-                    updated_history.append({"role": "user", "content": f"<tool_response>\n{tool_resp}\n</tool_response>"})
+                if "tool_call" in tool_call_state:
+                    tool_call_str = json.dumps(tool_call_state["tool_call"])
+                    updated_history.append(
+                        {
+                            "role": "assistant",
+                            "content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
+                        }
+                    )
+                if "tool_response" in tool_call_state:
+                    tool_resp = tool_call_state["tool_response"]
+                    # TODO: try with role = user as well
+                    updated_history.append(
+                        {
+                            "role": "user",
+                            "content": f"<tool_response>\n{tool_resp}\n</tool_response>",
+                        }
+                    )
                 # we don't want to match this state with any other messages
-                del(state_map[sha_key])
+                del state_map[sha_key]
     return updated_history
 
+
 async def chat_completion(req: ChatMessage, res: Response):
     logger.info("starting request")
     tools_encoded = handler._format_system(req.tools)
@@ -98,7 +110,9 @@ async def chat_completion(req: ChatMessage, res: Response):
     for message in updated_history:
         messages.append({"role": message["role"], "content": message["content"]})
-    logger.info(f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}")
+    logger.info(
+        f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}"
+    )
 
     completions_params = params["params"]
     resp = client.chat.completions.create(
         messages=messages,
diff --git a/model_server/app/arch_fc/arch_handler.py b/model_server/app/arch_fc/arch_handler.py
index 35507d6f..8facf339 100644
--- a/model_server/app/arch_fc/arch_handler.py
+++ b/model_server/app/arch_fc/arch_handler.py
@@ -52,7 +52,6 @@ class ArchHandler:
         messages: list[dict],
         execution_results: list,
     ) -> dict:
-
         content = []
         for result in execution_results:
             content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
diff --git a/model_server/app/arch_fc/test_arch_fc.py b/model_server/app/arch_fc/test_arch_fc.py
index caf5550c..fb94ad2b 100644
--- a/model_server/app/arch_fc/test_arch_fc.py
+++ b/model_server/app/arch_fc/test_arch_fc.py
@@ -2,16 +2,18 @@ import json
 import pytest
 from app.arch_fc.arch_fc import process_state
 from app.arch_fc.common import ChatMessage, Message
 
+
 # test process_state
 arch_state = '[[{"key":"02ea8ec721b130dc30ec836b79ec675116cd5889bca7d63720bc64baed994fc1","message":{"role":"user","content":"how is the weather in new york?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"new york"}},"tool_response":"{\\"city\\":\\"new york\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":68,\\"max\\":79}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":70,\\"max\\":76}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":71,\\"max\\":84}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":61,\\"max\\":79}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":86,\\"max\\":91}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":85,\\"max\\":90}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":72,\\"max\\":89}}],\\"unit\\":\\"F\\"}"}],[{"key":"566b9a2197cba89f35c1e3fbeee55882772ae7627fcf4411dae90282f98a1067","message":{"role":"user","content":"how is the weather in chicago?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"chicago"}},"tool_response":"{\\"city\\":\\"chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":54,\\"max\\":64}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":84,\\"max\\":99}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":85,\\"max\\":100}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":50,\\"max\\":62}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":79,\\"max\\":85}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":88,\\"max\\":100}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":56,\\"max\\":61}}],\\"unit\\":\\"F\\"}"}]]'
 
+
 def test_process_state():
-  history = []
-  history.append(Message(role="user", content="how is the weather in new york?"))
-  history.append(Message(role="user", content="how is the weather in chicago?"))
-  updated_history = process_state(arch_state, history)
-  print(json.dumps(updated_history, indent=2))
+    history = []
+    history.append(Message(role="user", content="how is the weather in new york?"))
+    history.append(Message(role="user", content="how is the weather in chicago?"))
+    updated_history = process_state(arch_state, history)
+    print(json.dumps(updated_history, indent=2))
 
 
 if __name__ == "__main__":
diff --git a/model_server/app/load_models.py b/model_server/app/load_models.py
index 0947c427..628b155f 100644
--- a/model_server/app/load_models.py
+++ b/model_server/app/load_models.py
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, pipeline
 import sqlite3
 import torch
 
+
 def get_device():
     if torch.cuda.is_available():
         device = "cuda"
@@ -14,14 +15,18 @@ def get_device():
     return device
 
+
 def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
     transformers = {}
     device = get_device()
 
     for model in models.split(","):
-        transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
+        transformers[model] = sentence_transformers.SentenceTransformer(
+            model, device=device
+        )
 
     return transformers
 
+
 def load_guard_model(
     model_name,
     hardware_config="cpu",
@@ -57,9 +62,12 @@ def load_zero_shot_models(
     zero_shot_models = {}
     device = get_device()
     for model in models.split(","):
-        zero_shot_models[model] = pipeline("zero-shot-classification", model=model, device=device)
+        zero_shot_models[model] = pipeline(
+            "zero-shot-classification", model=model, device=device
+        )
 
     return zero_shot_models
 
-if __name__ =="__main__":
+
+if __name__ == "__main__":
     print(get_device())
diff --git a/model_server/app/main.py b/model_server/app/main.py
index 3161557d..90066e7e 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -5,7 +5,7 @@ from app.load_models import (
     load_transformers,
     load_guard_model,
     load_zero_shot_models,
-    get_device
+    get_device,
 )
 import os
 from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config
@@ -21,17 +21,17 @@ logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
-
+logger.info("Device used: " + get_device())
 transformers = load_transformers()
 zero_shot_models = load_zero_shot_models()
 guard_model_config = load_yaml_config("guard_model_config.yaml")
 
 mode = os.getenv("MODE", "cloud")
 logger.info(f"Serving model mode: {mode}")
-if mode not in ['cloud', 'local-gpu', 'local-cpu']:
+if mode not in ["cloud", "local-gpu", "local-cpu"]:
     raise ValueError(f"Invalid mode: {mode}")
 
-if mode == 'local-cpu':
-    hardware = 'cpu'
+if mode == "local-cpu":
+    hardware = "cpu"
 else:
     hardware = "gpu" if torch.cuda.is_available() else "cpu"
@@ -40,6 +40,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
 
 app = FastAPI()
 
+
 class EmbeddingRequest(BaseModel):
     input: str
     model: str
@@ -49,6 +50,7 @@ class EmbeddingRequest(BaseModel):
 async def healthz():
     return {"status": "ok"}
 
+
 @app.get("/models")
 async def models():
     models = []
@@ -61,12 +63,11 @@ async def models():
 
 
 @app.post("/embeddings")
 async def embedding(req: EmbeddingRequest, res: Response):
-    print(f"Embedding Call Start Time: {time.time()}")
     if req.model not in transformers:
         raise HTTPException(status_code=400, detail="unknown model: " + req.model)
-
+    start = time.time()
     embeddings = transformers[req.model].encode([req.input])
-
+    print(f"Embedding Call Complete Time: {time.time()-start}")
     data = []
     for embedding in embeddings.tolist():
@@ -76,7 +77,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
         "prompt_tokens": 0,
         "total_tokens": 0,
     }
-    print(f"Embedding Call Complete Time: {time.time()}")
+
     return {"data": data, "model": req.model, "object": "list", "usage": usage}
 
 
@@ -197,10 +198,10 @@ class HallucinationRequest(BaseModel):
 @app.post("/hallucination")
 async def hallucination(req: HallucinationRequest, res: Response):
     """
-        Hallucination API, takes input as text and returns the prediction of hallucination for each parameter
-            parameters: dictionary of parameters and values
-                example {"name": "John", "age": "25"}
-            prompt: input prompt from the user
+    Hallucination API, takes input as text and returns the prediction of hallucination for each parameter
+    parameters: dictionary of parameters and values
+        example {"name": "John", "age": "25"}
+    prompt: input prompt from the user
     """
     if req.model not in zero_shot_models:
         raise HTTPException(status_code=400, detail="unknown model: " + req.model)
@@ -209,9 +210,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
     candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()]
     hypothesis_template = "{}"
     result = classifier(
-        req.prompt, candidate_labels=candidate_labels, hypothesis_template=hypothesis_template, multi_label=True
+        req.prompt,
+        candidate_labels=candidate_labels,
+        hypothesis_template=hypothesis_template,
+        multi_label=True,
     )
-    result_score = result['scores']
+    result_score = result["scores"]
     result_params = {k[0]: s for k, s in zip(req.parameters.items(), result_score)}
 
     return {
diff --git a/model_server/app/utils.py b/model_server/app/utils.py
index 49fd4592..d7d9d8e0 100644
--- a/model_server/app/utils.py
+++ b/model_server/app/utils.py
@@ -5,10 +5,11 @@ import torch
 import pkg_resources
 import yaml
 
+
 def load_yaml_config(file_name):
     # Load the YAML file from the package
-    yaml_path = pkg_resources.resource_filename('app', file_name)
-    with open(yaml_path, 'r') as yaml_file:
+    yaml_path = pkg_resources.resource_filename("app", file_name)
+    with open(yaml_path, "r") as yaml_file:
         return yaml.safe_load(yaml_file)
 
 
@@ -29,6 +30,7 @@ def split_text_into_chunks(text, max_words=300):
 def softmax(x):
     return np.exp(x) / np.exp(x).sum(axis=0)
 
+
 class PredictionHandler:
     def __init__(self, model, tokenizer, device, task="toxic", hardware_config="cpu"):
         self.model = model
diff --git a/model_server/setup.py b/model_server/setup.py
index ec6d4fe7..b6883015 100644
--- a/model_server/setup.py
+++ b/model_server/setup.py
@@ -1,12 +1,16 @@
 from setuptools import setup, find_packages
 
+
 # Function to read requirements.txt
 def parse_requirements(filename):
-    with open(filename, 'r') as file:
-        return [line.strip() for line in file if line.strip() and not line.startswith("#")]
+    with open(filename, "r") as file:
+        return [
+            line.strip() for line in file if line.strip() and not line.startswith("#")
+        ]
 
+
 # Call the parse_requirements function to get the list of dependencies
-requirements = parse_requirements('requirements.txt')
+requirements = parse_requirements("requirements.txt")
 
 print(f"packages to install: {find_packages()}")
 
@@ -16,11 +20,11 @@ setup(
     install_requires=requirements,
     package_data={
         # Specify the package and the data files you want to include
-        'app': ['/*.yaml'],  # Includes all .yaml files in the config/ folder
+        "app": ["/*.yaml"],  # Includes all .yaml files in the config/ folder
     },
     entry_points={
-        'console_scripts': [
-            'model_server=app:run_server',
+        "console_scripts": [
+            "model_server=app:run_server",
         ],
     },
 )