more changes

This commit is contained in:
Adil Hafeez 2025-06-04 10:58:09 -07:00
parent b0c1e97dc5
commit 670907145a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
4 changed files with 20 additions and 118 deletions

View file

@ -34,7 +34,10 @@ pub async fn chat_completions(
match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) {
Ok(request) => request,
Err(err) => {
warn!("arch-router request body string: {}", String::from_utf8_lossy(&chat_request_bytes));
warn!(
"arch-router request body string: {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let err_msg = format!("Failed to parse request body: {}", err);
warn!("{}", err_msg);
let mut bad_request = Response::new(full(err_msg));

View file

@ -143,7 +143,6 @@ impl RouterModel for RouterModelV1 {
messages: vec![Message {
content: Some(ContentType::Text(messages_content)),
role: USER_ROLE.to_string(),
}],
..Default::default()
}

View file

@ -1,9 +1,12 @@
pub mod types;
pub mod builder;
use thiserror::Error;
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub type OpenAIRequestBuilder = builder::OpenAIRequestBuilder;
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]

View file

@ -64,7 +64,7 @@ pub struct StreamOptions {
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatCompletionsRequest {
pub model: String,
pub messages: Vec<Message>,
@ -79,23 +79,6 @@ pub struct ChatCompletionsRequest {
pub stream_options: Option<StreamOptions>,
}
impl Default for ChatCompletionsRequest {
fn default() -> Self {
ChatCompletionsRequest {
model: String::new(),
messages: Vec::new(),
temperature: None,
top_p: None,
n: None,
max_tokens: None,
stream: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
stream_options: None,
}
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -123,106 +106,20 @@ pub struct Usage {
pub total_tokens: usize,
}
#[derive(Debug, Clone)]
pub struct OpenAIRequestBuilder {
model: String,
messages: Vec<Message>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
max_tokens: Option<u32>,
stream: Option<bool>,
stop: Option<Vec<String>>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
stream_options: Option<StreamOptions>,
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamChoice {
pub index: u32,
pub delta: Message,
pub finish_reason: Option<String>,
}
impl OpenAIRequestBuilder {
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
temperature: None,
top_p: None,
n: None,
max_tokens: None,
stream: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
stream_options: None,
}
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = Some(n);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.presence_penalty = Some(presence_penalty);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.frequency_penalty = Some(frequency_penalty);
self
}
pub fn stream_options(mut self, include_usage: bool) -> Self {
self.stream = Some(true);
self.stream_options = Some(StreamOptions { include_usage });
self
}
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
let request = ChatCompletionsRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
n: self.n,
max_tokens: self.max_tokens,
stream: self.stream,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
stream_options: self.stream_options,
};
Ok(request)
}
}
impl ChatCompletionsRequest {
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model, messages)
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
}
#[cfg(test)]