Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest {
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
@ -80,6 +80,8 @@ pub struct FunctionParameter {
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
}
impl Serialize for FunctionParameter {
@ -96,6 +98,9 @@ impl Serialize for FunctionParameter {
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
if let Some(format) = &self.format {
map.serialize_entry("format", format)?;
}
map.end()
}
}
@ -165,8 +170,8 @@ pub struct Message {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub finish_reason: Option<String>,
pub index: Option<usize>,
pub message: Message,
}
@ -197,6 +202,18 @@ pub struct ToolCallState {
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum ModelServerResponse {
ChatCompletionsResponse(ChatCompletionsResponse),
ModelServerErrorResponse(ModelServerErrorResponse),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelServerErrorResponse {
pub result: String,
pub intent_latency: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
@ -217,8 +234,8 @@ impl ChatCompletionsResponse {
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
index: Some(0),
finish_reason: Some("done".to_string()),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
@ -408,6 +425,7 @@ mod test {
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
format: None,
},
);
@ -462,6 +480,7 @@ mod test {
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
format: None,
},
)]);