Add the ability to use LLM Providers from the Arch config (#112)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-03 10:57:01 -07:00 committed by GitHub
parent 1b57a49c9d
commit 8ea917aae5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 295 additions and 210 deletions

View file

@ -1,47 +1,69 @@
#[non_exhaustive]
pub struct LlmProviders;
use public_types::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;
#[derive(Debug)]
pub struct LlmProviders {
providers: HashMap<String, Rc<LlmProvider>>,
default: Option<Rc<LlmProvider>>,
}
impl LlmProviders {
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "openai",
api_key_header: "x-arch-openai-api-key",
model: "gpt-3.5-turbo",
};
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "mistral",
api_key_header: "x-arch-mistral-api-key",
model: "mistral-large-latest",
};
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
self.providers.iter()
}
pub const VARIANTS: &'static [LlmProvider<'static>] =
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
pub fn default(&self) -> Option<Rc<LlmProvider>> {
self.default.as_ref().map(|rc| rc.clone())
}
pub struct LlmProvider<'prov> {
name: &'prov str,
api_key_header: &'prov str,
model: &'prov str,
}
impl AsRef<str> for LlmProvider<'_> {
fn as_ref(&self) -> &str {
self.name
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
self.providers.get(name).map(|rc| rc.clone())
}
}
impl std::fmt::Display for LlmProvider<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
#[derive(thiserror::Error, Debug)]
pub enum LlmProvidersNewError {
#[error("There must be at least one LLM Provider")]
EmptySource,
#[error("There must be at most one default LLM Provider")]
MoreThanOneDefault,
#[error("\'{0}\' is not a unique name")]
DuplicateName(String),
}
impl LlmProvider<'_> {
pub fn api_key_header(&self) -> &str {
self.api_key_header
}
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
type Error = LlmProvidersNewError;
pub fn choose_model(&self) -> &str {
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
self.model
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
if llm_providers_config.is_empty() {
return Err(LlmProvidersNewError::EmptySource);
}
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
Ok(llm_providers)
}
}