Remove top level container and start snake-case for config files (#27)

This commit is contained in:
Adil Hafeez 2024-07-31 14:05:52 -07:00 committed by GitHub
parent b8ea65d858
commit 9774148c75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 71 additions and 133 deletions

View file

@ -1,28 +1,7 @@
use serde::{Deserialize, Serialize};
//TODO: possibly use protobuf to enforce schema
//FIX: it is unnecessary to place yaml config inside katanemo-prompt-config
//GH Issue: https://github.com/katanemo/intelligent-prompt-gateway/issues/7
// Top-level configuration wrapper (pre-change state shown by this diff;
// the commit removes this container entirely).
// Field names serialize as kebab-case via `rename_all`; `prompt_config`
// is additionally mapped to the literal key "katanemo-prompt-config"
// (a field-level `rename` overrides the container-level `rename_all`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Configuration {
#[serde(rename = "katanemo-prompt-config")]
pub prompt_config: PromptConfig,
}
// Load-balancing strategy across configured LLM providers.
// Pre-change variant (removed by this commit): wire names are kebab-case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancing {
// serialized/deserialized as "round-robin"
#[serde(rename = "round-robin")]
RoundRobin,
// serialized/deserialized as "random"
#[serde(rename = "random")]
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptConfig {
pub default_prompt_endpoint: String,
pub load_balancing: LoadBalancing,
pub timeout_ms: u64,
@ -33,7 +12,14 @@ pub struct PromptConfig {
}
// Load-balancing strategy across configured LLM providers.
// Post-change variant (added by this commit): wire names are snake_case,
// matching the snake-cased config files this commit introduces.
// NOTE(review): the explicit per-variant `rename`s take precedence over the
// container-level `rename_all = "kebab-case"`, so `rename_all` has no effect
// on these variants — consider removing it to avoid confusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum LoadBalancing {
// serialized/deserialized as "round_robin"
#[serde(rename = "round_robin")]
RoundRobin,
// serialized/deserialized as "random"
#[serde(rename = "random")]
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct EmbeddingProviver {
pub name: String,
@ -41,7 +27,6 @@ pub struct EmbeddingProviver {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
@ -50,7 +35,6 @@ pub struct LlmProvider {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Endpoint {
pub cluster: String,
pub path: Option<String>,
@ -58,7 +42,6 @@ pub struct Endpoint {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct EntityDetail {
pub name: String,
pub required: Option<bool>,
@ -73,7 +56,6 @@ pub enum EntityType {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: String,
@ -87,49 +69,48 @@ pub struct PromptTarget {
#[cfg(test)]
mod test {
pub const CONFIGURATION: &str = r#"
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
embedding_provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
llm_providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
- name: "open-ai-gpt-4"
api_key: "$OPEN_AI_API_KEY"
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
system_prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt-targets:
prompt_targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
required: true
description: "The location for which the weather is requested"
- type: context_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
required: true
description: "The location for which the weather is requested"
- type: context-resolver
name: weather-forecast-2
few-shot-examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- city
- type: context_resolver
name: weather_forecast_2
few_shot_examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- city
"#;
#[test]

View file

@ -46,7 +46,7 @@ impl FilterContext {
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(

View file

@ -281,7 +281,7 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
panic!("Error dispatching HTTP call for context_resolver: {:?}", e);
}
};
callout_context.request_type = RequestType::ContextResolver;
@ -291,7 +291,7 @@ impl StreamContext {
}
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
info!("response received for context-resolver");
info!("response received for context_resolver");
let body_string = String::from_utf8(body);
let prompt_target = callout_context.prompt_target.unwrap();
let mut request_body = callout_context.request_body;
@ -307,7 +307,7 @@ impl StreamContext {
}
match body_string {
Ok(body_string) => {
info!("context-resolver response: {}", body_string);
info!("context_resolver response: {}", body_string);
let context_resolver_response = Message {
role: USER_ROLE.to_string(),
content: Some(body_string),