mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
adding support for wildcard model providers
This commit is contained in:
parent
86cf8ccdaa
commit
34711c6f9d
14 changed files with 1027 additions and 1823 deletions
|
|
@@ -187,11 +187,21 @@ def validate_and_render_schema():
|
|||
|
||||
model_name = model_provider.get("model")
|
||||
print("Processing model_provider: ", model_provider)
|
||||
if model_name in model_name_keys:
|
||||
|
||||
# Check if this is a wildcard model (provider/*)
|
||||
is_wildcard = False
|
||||
if "/" in model_name:
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) >= 2 and model_name_tokens[-1] == "*":
|
||||
is_wildcard = True
|
||||
|
||||
if model_name in model_name_keys and not is_wildcard:
|
||||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_name)
|
||||
|
||||
if not is_wildcard:
|
||||
model_name_keys.add(model_name)
|
||||
if model_provider.get("name") is None:
|
||||
model_provider["name"] = model_name
|
||||
|
||||
|
|
@@ -202,7 +212,21 @@ def validate_and_render_schema():
|
|||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
provider = model_name_tokens[0].strip()
|
||||
|
||||
# Check if this is a wildcard (provider/*)
|
||||
is_wildcard = model_name_tokens[-1].strip() == "*"
|
||||
|
||||
# Validate wildcard constraints
|
||||
if is_wildcard:
|
||||
if model_provider.get("default", False):
|
||||
raise Exception(
|
||||
f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards."
|
||||
)
|
||||
if model_provider.get("routing_preferences"):
|
||||
raise Exception(
|
||||
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
# Validate that the azure_openai and ollama providers specify the required base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
|
|
@@ -213,7 +237,9 @@ def validate_and_render_schema():
|
|||
)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
|
||||
# For wildcard providers, allow any provider name
|
||||
if not is_wildcard and provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
model_provider.get("base_url", None) is None
|
||||
or model_provider.get("provider_interface", None) is None
|
||||
|
|
@@ -227,11 +253,13 @@ def validate_and_render_schema():
|
|||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
# For wildcard models, don't add model_id to the keys since it's "*"
|
||||
if not is_wildcard:
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in model_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue