mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 12:52:56 +02:00
Remove top level container and start snake-case for config files (#27)
t
This commit is contained in:
parent
b8ea65d858
commit
9774148c75
6 changed files with 71 additions and 133 deletions
|
|
@ -1,28 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
//TODO: possibly use protobuf to enforce schema
|
||||
|
||||
//FIX: it is unnecessary to place yaml config inside katanemo-prompt-config
|
||||
//GH Issue: https://github.com/katanemo/intelligent-prompt-gateway/issues/7
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct Configuration {
|
||||
#[serde(rename = "katanemo-prompt-config")]
|
||||
pub prompt_config: PromptConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LoadBalancing {
|
||||
#[serde(rename = "round-robin")]
|
||||
RoundRobin,
|
||||
#[serde(rename = "random")]
|
||||
Random,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct PromptConfig {
|
||||
pub default_prompt_endpoint: String,
|
||||
pub load_balancing: LoadBalancing,
|
||||
pub timeout_ms: u64,
|
||||
|
|
@ -33,7 +12,14 @@ pub struct PromptConfig {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum LoadBalancing {
|
||||
#[serde(rename = "round_robin")]
|
||||
RoundRobin,
|
||||
#[serde(rename = "random")]
|
||||
Random,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct EmbeddingProviver {
|
||||
pub name: String,
|
||||
|
|
@ -41,7 +27,6 @@ pub struct EmbeddingProviver {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
|
|
@ -50,7 +35,6 @@ pub struct LlmProvider {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
|
|
@ -58,7 +42,6 @@ pub struct Endpoint {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct EntityDetail {
|
||||
pub name: String,
|
||||
pub required: Option<bool>,
|
||||
|
|
@ -73,7 +56,6 @@ pub enum EntityType {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct PromptTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub prompt_type: String,
|
||||
|
|
@ -87,49 +69,48 @@ pub struct PromptTarget {
|
|||
#[cfg(test)]
|
||||
mod test {
|
||||
pub const CONFIGURATION: &str = r#"
|
||||
katanemo-prompt-config:
|
||||
default-prompt-endpoint: "127.0.0.1"
|
||||
load-balancing: "round-robin"
|
||||
timeout-ms: 5000
|
||||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
|
||||
embedding-provider:
|
||||
name: "SentenceTransformer"
|
||||
model: "all-MiniLM-L6-v2"
|
||||
embedding_provider:
|
||||
name: "SentenceTransformer"
|
||||
model: "all-MiniLM-L6-v2"
|
||||
|
||||
llm-providers:
|
||||
llm_providers:
|
||||
|
||||
- name: "open-ai-gpt-4"
|
||||
api-key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
- name: "open-ai-gpt-4"
|
||||
api_key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
|
||||
system-prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
prompt-targets:
|
||||
prompt_targets:
|
||||
|
||||
- type: context-resolver
|
||||
name: weather-forecast
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
- type: context_resolver
|
||||
name: weather_forecast
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: context-resolver
|
||||
name: weather-forecast-2
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- city
|
||||
- type: context_resolver
|
||||
name: weather_forecast_2
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- city
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ impl FilterContext {
|
|||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
|
||||
panic!("Error dispatching HTTP call for context_resolver: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::ContextResolver;
|
||||
|
|
@ -291,7 +291,7 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
info!("response received for context-resolver");
|
||||
info!("response received for context_resolver");
|
||||
let body_string = String::from_utf8(body);
|
||||
let prompt_target = callout_context.prompt_target.unwrap();
|
||||
let mut request_body = callout_context.request_body;
|
||||
|
|
@ -307,7 +307,7 @@ impl StreamContext {
|
|||
}
|
||||
match body_string {
|
||||
Ok(body_string) => {
|
||||
info!("context-resolver response: {}", body_string);
|
||||
info!("context_resolver response: {}", body_string);
|
||||
let context_resolver_response = Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_string),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue