mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 23:02:43 +02:00
Add the ability to use LLM Providers from the Arch config (#112)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
1b57a49c9d
commit
8ea917aae5
16 changed files with 295 additions and 210 deletions
|
|
@ -1,47 +1,69 @@
|
|||
#[non_exhaustive]
|
||||
pub struct LlmProviders;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "openai",
|
||||
api_key_header: "x-arch-openai-api-key",
|
||||
model: "gpt-3.5-turbo",
|
||||
};
|
||||
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "mistral",
|
||||
api_key_header: "x-arch-mistral-api-key",
|
||||
model: "mistral-large-latest",
|
||||
};
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub const VARIANTS: &'static [LlmProvider<'static>] =
|
||||
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
|
||||
}
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
self.default.as_ref().map(|rc| rc.clone())
|
||||
}
|
||||
|
||||
pub struct LlmProvider<'prov> {
|
||||
name: &'prov str,
|
||||
api_key_header: &'prov str,
|
||||
model: &'prov str,
|
||||
}
|
||||
|
||||
impl AsRef<str> for LlmProvider<'_> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.name
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).map(|rc| rc.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmProvidersNewError {
|
||||
#[error("There must be at least one LLM Provider")]
|
||||
EmptySource,
|
||||
#[error("There must be at most one default LLM Provider")]
|
||||
MoreThanOneDefault,
|
||||
#[error("\'{0}\' is not a unique name")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
|
||||
pub fn api_key_header(&self) -> &str {
|
||||
self.api_key_header
|
||||
}
|
||||
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
||||
type Error = LlmProvidersNewError;
|
||||
|
||||
pub fn choose_model(&self) -> &str {
|
||||
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
|
||||
self.model
|
||||
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
|
||||
if llm_providers_config.is_empty() {
|
||||
return Err(LlmProvidersNewError::EmptySource);
|
||||
}
|
||||
|
||||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue