fix openai

This commit is contained in:
geekan 2023-12-28 17:42:28 +08:00
parent 82071d4774
commit fe697ac095
3 changed files with 8 additions and 14 deletions

View file

@ -143,7 +143,7 @@ class Config(metaclass=Singleton):
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
# self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")

View file

@ -69,7 +69,7 @@ class OpenAILLM(BaseLLM):
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
@ -81,8 +81,8 @@ class OpenAILLM(BaseLLM):
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.OPENAI_BASE_URL:
params["base_url"] = self.config.OPENAI_BASE_URL
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
return params

View file

@ -86,31 +86,25 @@ class TestOpenAI:
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs