mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Use large github action machine to run e2e tests (#230)
This commit is contained in:
parent
bb882fb59b
commit
e462e393b1
30 changed files with 4725 additions and 441 deletions
6
.github/workflows/e2e_tests.yml
vendored
6
.github/workflows/e2e_tests.yml
vendored
|
|
@ -8,7 +8,8 @@ on:
|
|||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-latest-m
|
||||
# runs-on: gh-large-150gb-ssd
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
|
@ -29,4 +30,5 @@ jobs:
|
|||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
cd e2e_tests && bash run_e2e_tests.sh
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd e2e_tests && bash run_e2e_tests.sh
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ FROM envoyproxy/envoy:v1.31-latest as envoy
|
|||
#Build config generator, so that we have a single build image for both Rust and Python
|
||||
FROM python:3-slim as arch
|
||||
|
||||
RUN apt-get update && apt-get install -y gettext-base && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ services:
|
|||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
- ~/archgw_logs:/var/log/
|
||||
env_file:
|
||||
- stage.env
|
||||
- env.list
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:10000/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
# Define paths
|
||||
source_schema="../arch_config_schema.yaml"
|
||||
source_compose="../docker-compose.yaml"
|
||||
source_stage_env="../stage.env"
|
||||
destination_dir="config"
|
||||
|
||||
# Ensure the destination directory exists (create it only if it is missing)
|
||||
|
|
@ -15,7 +14,7 @@ fi
|
|||
# Copy the files
|
||||
cp "$source_schema" "$destination_dir/arch_config_schema.yaml"
|
||||
cp "$source_compose" "$destination_dir/docker-compose.yaml"
|
||||
cp "$source_stage_env" "$destination_dir/stage.env"
|
||||
touch "$destination_dir/env.list"
|
||||
|
||||
# Print success message
|
||||
echo "Files copied successfully!"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def validate_and_render_schema():
|
|||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(str(e))
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
|
||||
with open(ARCH_CONFIG_FILE, "r") as file:
|
||||
|
|
@ -73,7 +73,6 @@ def validate_and_render_schema():
|
|||
|
||||
print("updated clusters", inferred_clusters)
|
||||
|
||||
config_yaml = add_secret_key_to_llm_providers(config_yaml)
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ def main(ctx, version):
|
|||
click.echo(f"archgw cli version: {get_version()}")
|
||||
ctx.exit()
|
||||
|
||||
log.info(f"Starting archgw cli version: {get_version()}")
|
||||
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
|
||||
click.echo(logo)
|
||||
|
|
@ -68,7 +70,7 @@ def main(ctx, version):
|
|||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
|
||||
help="Optional parameter to specify which service to build. Options are model_server, archgw",
|
||||
)
|
||||
def build(service):
|
||||
"""Build Arch from source. Must be in root of cloned repo."""
|
||||
|
|
@ -168,7 +170,7 @@ def up(file, path, service):
|
|||
arch_config_schema_file=arch_schema_config,
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(f"Exiting archgw up: {e}")
|
||||
log.info(f"Exiting archgw up: validation failed")
|
||||
sys.exit(1)
|
||||
|
||||
log.info("Starging arch model server and arch gateway")
|
||||
|
|
@ -178,6 +180,12 @@ def up(file, path, service):
|
|||
env = os.environ.copy()
|
||||
# check if access_keys are present in the config file
|
||||
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
|
||||
|
||||
# remove duplicates
|
||||
access_keys = set(access_keys)
|
||||
# remove the $ from the access_keys
|
||||
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
|
||||
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(
|
||||
|
|
@ -186,6 +194,7 @@ def up(file, path, service):
|
|||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
|
||||
print(f"app_env_file: {app_env_file}")
|
||||
if not os.path.exists(
|
||||
app_env_file
|
||||
): # check to see if the environment variables in the current environment or not
|
||||
|
|
@ -205,7 +214,7 @@ def up(file, path, service):
|
|||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
with open(
|
||||
pkg_resources.resource_filename(__name__, "../config/stage.env"), "w"
|
||||
pkg_resources.resource_filename(__name__, "../config/env.list"), "w"
|
||||
) as file:
|
||||
for key, value in env_stage.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
|
|
|||
3513
arch/tools/poetry.lock
generated
3513
arch/tools/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -12,7 +12,6 @@ include = [
|
|||
# Include package data (docker-compose.yaml and other files)
|
||||
"config/docker-compose.yaml",
|
||||
"config/arch_config_schema.yaml",
|
||||
"config/stage.env"
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
|
|
@ -22,8 +21,8 @@ pydantic = "^2.9.2"
|
|||
click = "^8.1.7"
|
||||
jinja2 = "^3.1.4"
|
||||
jsonschema = "^4.23.0"
|
||||
setuptools = "75.2.0"
|
||||
archgw_modelserver = "0.0.4"
|
||||
setuptools = "75.3.0"
|
||||
archgw_modelserver = "0.0.5"
|
||||
huggingface_hub = "^0.26.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
|
|
|||
|
|
@ -16,10 +16,6 @@
|
|||
"name": "model_server",
|
||||
"path": "model_server"
|
||||
},
|
||||
{
|
||||
"name": "chatbot_ui",
|
||||
"path": "chatbot_ui"
|
||||
},
|
||||
{
|
||||
"name": "e2e_tests",
|
||||
"path": "e2e_tests"
|
||||
|
|
|
|||
|
|
@ -316,7 +316,14 @@ impl HttpContext for StreamContext {
|
|||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut data = serde_json::from_str(&body_utf8).unwrap();
|
||||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!("could not deserialize response: {}", e);
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
|
|
|
|||
|
|
@ -446,61 +446,89 @@ impl StreamContext {
|
|||
// it may be that arch fc is handling the conversation for parameter collection
|
||||
if arch_assistant {
|
||||
info!("arch fc is engaged in parameter collection");
|
||||
} else {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!(
|
||||
"default prompt target found, forwarding request to default prompt target"
|
||||
);
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
} else if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"archgw (low similarity score) => llm request: {}",
|
||||
llm_request_str
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
|
@ -873,42 +901,8 @@ impl StreamContext {
|
|||
"archgw <= api call response: {}",
|
||||
self.tool_call_response.as_ref().unwrap()
|
||||
);
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
|
||||
// add system prompt
|
||||
let system_prompt = match prompt_target.system_prompt.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt.clone()),
|
||||
};
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
let mut messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -960,6 +954,51 @@ impl StreamContext {
|
|||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
// add system prompt
|
||||
|
||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(prompt_target_name) => {
|
||||
let prompt_system_prompt = self
|
||||
.prompt_targets
|
||||
.get(prompt_target_name)
|
||||
.unwrap()
|
||||
.clone()
|
||||
.system_prompt;
|
||||
match prompt_system_prompt {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt),
|
||||
}
|
||||
}
|
||||
};
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
|
|
@ -24,7 +24,9 @@ endpoints:
|
|||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
|
||||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding,
|
||||
- when you get data in json format, offer some summary but don't be too verbose
|
||||
- be concise, to the point and do not over analyze the data
|
||||
|
||||
prompt_targets:
|
||||
- name: workforce
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ with open("workforce_data.json") as file:
|
|||
|
||||
|
||||
# Define the request model
|
||||
class WorkforceRequset(BaseModel):
|
||||
class WorkforceRequest(BaseModel):
|
||||
region: str
|
||||
staffing_type: str
|
||||
data_snapshot_days_ago: Optional[int] = None
|
||||
|
|
@ -74,7 +74,7 @@ def send_slack_message(request: SlackRequest):
|
|||
|
||||
# Post method for workforce data
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequset):
|
||||
def get_workforce(request: WorkforceRequest):
|
||||
"""
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
|
|
@ -90,7 +90,7 @@ def get_workforce(request: WorkforceRequset):
|
|||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
"satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ system_prompt: |
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listen:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ endpoints:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
@ -47,7 +47,7 @@ llm_providers:
|
|||
|
||||
- name: Mistral8x7b
|
||||
provider: mistral
|
||||
access_key: MISTRAL_API_KEY
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: mistral-8x7b
|
||||
|
||||
- name: MistralLocal7b
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ log() {
|
|||
echo "$timestamp: $message"
|
||||
}
|
||||
|
||||
# Print a heading followed by the `df -h` row(s) whose line ends in "/",
# i.e. the filesystem mounted at the root.
print_disk_usage() {
  echo "free disk space"
  df -h | grep '/$'
}
|
||||
|
||||
wait_for_healthz() {
|
||||
local healthz_url="$1"
|
||||
local timeout_seconds="${2:-30}" # Default timeout of 30 seconds
|
||||
|
|
@ -28,6 +33,8 @@ wait_for_healthz() {
|
|||
return 1
|
||||
fi
|
||||
|
||||
print_disk_usage
|
||||
|
||||
sleep $sleep_between
|
||||
done
|
||||
}
|
||||
|
|
|
|||
39
e2e_tests/poetry.lock
generated
39
e2e_tests/poetry.lock
generated
|
|
@ -455,13 +455,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.4"
|
||||
version = "8.3.3"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"},
|
||||
{file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"},
|
||||
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
|
||||
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -469,11 +469,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
|||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<2.0"
|
||||
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
pluggy = ">=1.5,<2"
|
||||
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
|
|
@ -493,6 +493,23 @@ pytest = ">=4.6"
|
|||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-retry"
|
||||
version = "1.6.3"
|
||||
description = "Adds the ability to retry flaky tests in CI environments"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pytest_retry-1.6.3-py3-none-any.whl", hash = "sha256:e96f7df77ee70b0838d1085f9c3b8b5b7d74bf8947a0baf32e2b8c71b27683c8"},
|
||||
{file = "pytest_retry-1.6.3.tar.gz", hash = "sha256:36ccfa11c8c8f9ddad5e20375182146d040c20c4a791745139c5a99ddf1b557d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "isort", "mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-sugar"
|
||||
version = "1.0.0"
|
||||
|
|
@ -535,13 +552,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
|||
|
||||
[[package]]
|
||||
name = "selenium"
|
||||
version = "4.25.0"
|
||||
version = "4.26.0"
|
||||
description = "Official Python bindings for Selenium WebDriver"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "selenium-4.25.0-py3-none-any.whl", hash = "sha256:3798d2d12b4a570bc5790163ba57fef10b2afee958bf1d80f2a3cf07c4141f33"},
|
||||
{file = "selenium-4.25.0.tar.gz", hash = "sha256:95d08d3b82fb353f3c474895154516604c7f0e6a9a565ae6498ef36c9bac6921"},
|
||||
{file = "selenium-4.26.0-py3-none-any.whl", hash = "sha256:48013f36e812de5b3948ef53d04e73f77bc923ee3e1d7d99eaf0618179081b99"},
|
||||
{file = "selenium-4.26.0.tar.gz", hash = "sha256:f0780f85f10310aa5d085b81e79d73d3c93b83d8de121d0400d543a50ee963e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -699,4 +716,4 @@ h11 = ">=0.9.0,<1"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce"
|
||||
content-hash = "a40015b90325879e50f82cca6a26a730d763cad26589671df798832d41c42db3"
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ package-mode = false
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
pytest = "^7.3.1"
|
||||
pytest = "^8.3.3"
|
||||
requests = "^2.29.0"
|
||||
selenium = "^4.11.2"
|
||||
pytest-sugar = "^1.0.0"
|
||||
deepdiff = "^8.0.1"
|
||||
pytest-retry = "^1.6.3"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest-cov = "^4.1.0"
|
||||
|
|
@ -21,3 +22,6 @@ pytest-cov = "^4.1.0"
|
|||
[tool.pytest.ini_options]
|
||||
python_files = ["test*.py"]
|
||||
addopts = ["-v", "-s"]
|
||||
retries = 2
|
||||
retry_delay = 0.5
|
||||
cumulative_timing = false
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ set -e
|
|||
|
||||
. ./common_scripts.sh
|
||||
|
||||
print_disk_usage
|
||||
|
||||
mkdir -p ~/archgw_logs
|
||||
touch ~/archgw_logs/modelserver.log
|
||||
|
||||
print_debug() {
|
||||
log "Received signal to stop"
|
||||
log "Printing debug logs for model_server"
|
||||
|
|
@ -18,63 +23,62 @@ trap 'print_debug' INT TERM ERR
|
|||
|
||||
log starting > ../build.log
|
||||
|
||||
log building function_callling demo
|
||||
log ===============================
|
||||
log building and running function_callling demo
|
||||
log ===========================================
|
||||
cd ../demos/function_calling
|
||||
docker compose build -q
|
||||
|
||||
log starting the function_calling demo
|
||||
docker compose up -d
|
||||
docker compose up api_server --build -d
|
||||
cd -
|
||||
|
||||
log building model server
|
||||
log =====================
|
||||
print_disk_usage
|
||||
|
||||
log building and install model server
|
||||
log =================================
|
||||
cd ../model_server
|
||||
poetry install 2>&1 >> ../build.log
|
||||
log starting model server
|
||||
log =====================
|
||||
mkdir -p ~/archgw_logs
|
||||
touch ~/archgw_logs/modelserver.log
|
||||
poetry run archgw_modelserver restart &
|
||||
poetry install
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log building and installing archgw cli
|
||||
log ==================================
|
||||
cd ../arch/tools
|
||||
sh build_cli.sh
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log building docker image for arch gateway
|
||||
log ======================================
|
||||
cd ../
|
||||
archgw build
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log startup arch gateway with function calling demo
|
||||
cd ..
|
||||
tail -F ~/archgw_logs/modelserver.log &
|
||||
model_server_tail_pid=$!
|
||||
archgw down
|
||||
archgw up demos/function_calling/arch_config.yaml
|
||||
kill $model_server_tail_pid
|
||||
cd -
|
||||
|
||||
log building llm and prompt gateway rust modules
|
||||
log ============================================
|
||||
cd ../arch
|
||||
docker build -f Dockerfile .. -t katanemo/archgw -q
|
||||
log starting the arch gateway service
|
||||
log =================================
|
||||
docker compose -f docker-compose.e2e.yaml down
|
||||
log waiting for model service to be healthy
|
||||
wait_for_healthz "http://localhost:51000/healthz" 300
|
||||
kill $model_server_tail_pid
|
||||
docker compose -f docker-compose.e2e.yaml up -d
|
||||
log waiting for arch gateway service to be healthy
|
||||
wait_for_healthz "http://localhost:10000/healthz" 60
|
||||
log waiting for arch gateway service to be healthy
|
||||
cd -
|
||||
print_disk_usage
|
||||
|
||||
log running e2e tests
|
||||
log =================
|
||||
poetry install 2>&1 >> ../build.log
|
||||
poetry install
|
||||
poetry run pytest
|
||||
|
||||
log shutting down the arch gateway service
|
||||
log ======================================
|
||||
cd ../arch
|
||||
docker compose -f docker-compose.e2e.yaml stop 2>&1 >> ../build.log
|
||||
cd ../
|
||||
archgw down
|
||||
cd -
|
||||
|
||||
log shutting down the function_calling demo
|
||||
log =======================================
|
||||
cd ../demos/function_calling
|
||||
docker compose down 2>&1 >> ../build.log
|
||||
cd -
|
||||
|
||||
log shutting down the model server
|
||||
log ==============================
|
||||
cd ../model_server
|
||||
poetry run archgw_modelserver stop 2>&1 >> ../build.log
|
||||
docker compose down
|
||||
cd -
|
||||
|
|
|
|||
2
model_server/.vscode/launch.json
vendored
2
model_server/.vscode/launch.json
vendored
|
|
@ -9,7 +9,7 @@
|
|||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["app.main:app","--reload", "--port", "8000"]
|
||||
"args": ["app.main:app","--reload", "--port", "51000"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
|
@ -7,6 +8,15 @@ import tempfile
|
|||
import subprocess
|
||||
import logging
|
||||
|
||||
|
||||
def get_version():
    """Return the installed version of the ``archgw_modelserver`` distribution.

    Returns:
        The version string reported by the package metadata, or the literal
        ``"version not found"`` when the distribution is not installed.
    """
    # A bare ``import importlib`` does not load the ``metadata`` submodule, so
    # import the needed names explicitly instead of relying on another module
    # having imported it first.
    from importlib.metadata import PackageNotFoundError, version

    try:
        return version("archgw_modelserver")
    except PackageNotFoundError:
        return "version not found"
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
|
|
@ -14,6 +24,7 @@ logging.basicConfig(
|
|||
|
||||
log = logging.getLogger("model_server.cli")
|
||||
log.setLevel(logging.INFO)
|
||||
log.info(f"model server version: {get_version()}")
|
||||
|
||||
|
||||
def run_server(port=51000):
|
||||
|
|
@ -37,8 +48,9 @@ def run_server(port=51000):
|
|||
def start_server(port=51000):
|
||||
"""Start the Uvicorn server"""
|
||||
log.info(
|
||||
"Starting model server - loading some awesomeness, this may take some time :)"
|
||||
"starting model server - loading some awesomeness, this may take some time :)"
|
||||
)
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
|
|
@ -61,7 +73,7 @@ def start_server(port=51000):
|
|||
log.info(f"Model server started with PID {process.pid}")
|
||||
else:
|
||||
# Add model_server boot-up logs
|
||||
log.info("Model server - Didn't Sart In Time. Shutting Down")
|
||||
log.info("model server - didn't start in time, shutting down")
|
||||
process.terminate()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -67,12 +67,16 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(
|
||||
resp.choices[0].message.content
|
||||
|
|
|
|||
1192
model_server/poetry.lock
generated
1192
model_server/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
description = "A model server for serving models"
|
||||
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue