fixbug: Fix the confusion caused by the merging of _client, client, and async_client in the openai_api.py;Split Azure LLM and MetaGPT LLM from OpenAI LLM to reduce the number of variables defined in the Config class for compatibility.

This commit is contained in:
莘权 马 2023-12-23 17:45:10 +08:00
parent 4b4ccc60fc
commit a90f52d4b6
8 changed files with 143 additions and 161 deletions

View file

@ -80,26 +80,41 @@ class Config(metaclass=Singleton):
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
for k, v in [
(self.openai_api_key, LLMProviderEnum.OPENAI),
(self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
]:
if self._is_valid_llm_key(k):
# logger.debug(f"Use LLMProvider: {v.value}")
if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
if self.openai_api_key and self.openai_api_model:
logger.info(f"OpenAI API Model: {self.openai_api_model}")
return v
mappings = {
LLMProviderEnum.OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
),
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
LLMProviderEnum.METAGPT: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
),
LLMProviderEnum.AZURE_OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY)
and self.OPENAI_API_TYPE == "azure"
and self.DEPLOYMENT_NAME
and self.OPENAI_API_VERSION
),
}
provider = None
for k, v in mappings.items():
if v:
provider = k
break
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
if provider:
logger.info(f"API: {provider}")
return provider
raise NotConfiguredException("You should config a LLM configuration first")
@staticmethod
def _is_valid_llm_key(k: str) -> bool:
return k and k != "YOUR_API_KEY"
return bool(k and k != "YOUR_API_KEY")
def _update(self):
self.global_proxy = self._get("GLOBAL_PROXY")
@ -113,7 +128,7 @@ class Config(metaclass=Singleton):
self.gemini_api_key = self._get("GEMINI_API_KEY")
_ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")
# self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")

View file

@ -11,5 +11,15 @@ from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
from metagpt.provider.metagpt_api import METAGPTAPI
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
__all__ = [
"FireWorksGPTAPI",
"GeminiGPTAPI",
"OpenLLMGPTAPI",
"OpenAIGPTAPI",
"ZhiPuAIGPTAPI",
"AzureOpenAIGPTAPI",
"METAGPTAPI",
]

View file

@ -26,26 +26,22 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.config: Config = CONFIG
self.__init_openai()
self._init_openai()
self.auto_max_tokens = False
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self._client = AsyncAzureOpenAI(
api_key=CONFIG.openai_api_key,
api_version=CONFIG.openai_api_version,
azure_endpoint=CONFIG.openai_api_base,
)
RateLimiter.__init__(self, rpm=self.rpm)
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(
api_key=self.config.openai_api_key,
api_version=self.config.openai_api_version,
azure_endpoint=self.config.openai_base_url,
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
)
async_kwargs = kwargs.copy()
@ -64,7 +60,7 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
"n": 1,
"stop": None,
"temperature": 0.3,
"model": CONFIG.deployment_id,
"model": self.model,
}
if configs:
kwargs.update(configs)

View file

@ -87,31 +87,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def __init__(self):
self.config: Config = CONFIG
self.__init_openai()
self._init_openai()
self.auto_max_tokens = False
# https://github.com/openai/openai-python#async-usage
self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base)
RateLimiter.__init__(self, rpm=self.rpm)
# async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
# kwargs = self._cons_kwargs(messages, timeout=timeout)
# response = await self._client.chat.completions.create(**kwargs, stream=True)
# # iterate through the stream of events
# async for chunk in response:
# chunk_message = chunk.choices[0].delta.content or "" # extract the message
# yield chunk_message
def __init_openai(self):
self.rpm = int(self.config.get("RPM", 10))
def _init_openai(self):
self.rpm = int(self.config.RPM or 10)
self._make_client()
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
# https://github.com/openai/openai-python#async-usage
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
@ -126,33 +118,19 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
if self.config.OPENAI_BASE_URL:
params["base_url"] = self.config.OPENAI_BASE_URL
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
if chunk.choices:
chunk_message = chunk.choices[0].delta # extract the message
collected_messages.append(chunk_message) # save the message
if chunk_message.content:
print(chunk_message.content, end="")
print()
full_reply_content = "".join([m.content for m in collected_messages if m.content])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
chunk_message = chunk.choices[0].delta.content or "" # extract the message
yield chunk_message
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
kwargs = {
@ -161,7 +139,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.config.openai_api_model,
"model": self.model,
}
if configs:
kwargs.update(configs)
@ -175,13 +153,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=timeout)
rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
self._update_costs(rsp.usage)
return rsp
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
loop = self.get_event_loop()
return loop.run_until_complete(self.acompletion(messages, timeout=timeout))
return self._chat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return await self._achat_completion(messages, timeout=timeout)
@ -234,12 +216,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
loop = self.get_event_loop()
return loop.run_until_complete(self._achat_completion_function(messages=messages, timeout=timeout, **kwargs))
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.usage)
return rsp
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
@ -295,25 +278,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
try:
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
except openai.NotFoundError as e:
logger.error(f"API TYPE:{CONFIG.openai_api_type}, err:{e}")
except openai.BadRequestError as e:
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
raise e
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
if CONFIG.calc_usage:
try:
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage = CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return usage
except Exception as e:
logger.error(f"{self.model} usage calculation failed!", e)
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
@ -384,31 +352,20 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
def moderation(self, content: Union[str, list[str]]):
loop = self.get_event_loop()
loop.run_until_complete(self.amoderation(content=content))
return self.client.moderations.create(input=content)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
return await self._client.moderations.create(input=content)
return await self.async_client.moderations.create(input=content)
async def close(self):
"""Close connection"""
if not self._client:
return
await self._client.close()
self._client = None
@staticmethod
def get_event_loop():
try:
return asyncio.get_event_loop()
except RuntimeError as e:
if "There is no current event loop in thread" in str(e):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
else:
raise e
if self.client:
self.client.close()
self.client = None
if self.async_client:
await self.async_client.close()
self.async_client = None
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs) -> str:
max_token_count = DEFAULT_MAX_TOKENS

View file

@ -18,15 +18,15 @@ from metagpt.config import CONFIG
def make_sk_kernel():
kernel = sk.Kernel()
if CONFIG.openai_api_type == "azure":
if CONFIG.OPENAI_API_TYPE == "azure":
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_base_url, CONFIG.openai_api_key),
AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY),
)
else:
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(CONFIG.openai_api_model, CONFIG.openai_api_key),
OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY),
)
return kernel

View file

@ -15,15 +15,15 @@ import pytest
from metagpt.config import CONFIG, Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
from metagpt.utils.git_repository import GitRepository
class Context:
def __init__(self):
self._llm_ui = None
self._llm_api = GPTAPI()
self._llm_api = LLM(provider=CONFIG.get_default_llm_provider_enum())
@property
def llm_api(self):

View file

@ -5,47 +5,47 @@
@Author : mashenquan
@File : test_brain_memory.py
"""
import json
from typing import List
import pydantic
from metagpt.memory.brain_memory import BrainMemory
from metagpt.schema import Message
def test_json():
class Input(pydantic.BaseModel):
history: List[str]
solution: List[str]
knowledge: List[str]
stack: List[str]
inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}]
for i in inputs:
v = Input(**i)
bm = BrainMemory()
for h in v.history:
msg = Message(content=h)
bm.history.append(msg.dict())
for h in v.solution:
msg = Message(content=h)
bm.solution.append(msg.dict())
for h in v.knowledge:
msg = Message(content=h)
bm.knowledge.append(msg.dict())
for h in v.stack:
msg = Message(content=h)
bm.stack.append(msg.dict())
s = bm.json()
m = json.loads(s)
bm = BrainMemory(**m)
assert bm
for v in bm.history:
msg = Message(**v)
assert msg
if __name__ == "__main__":
test_json()
# import json
# from typing import List
#
# import pydantic
#
# from metagpt.memory.brain_memory import BrainMemory
# from metagpt.schema import Message
#
#
# def test_json():
# class Input(pydantic.BaseModel):
# history: List[str]
# solution: List[str]
# knowledge: List[str]
# stack: List[str]
#
# inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}]
#
# for i in inputs:
# v = Input(**i)
# bm = BrainMemory()
# for h in v.history:
# msg = Message(content=h)
# bm.history.append(msg.dict())
# for h in v.solution:
# msg = Message(content=h)
# bm.solution.append(msg.dict())
# for h in v.knowledge:
# msg = Message(content=h)
# bm.knowledge.append(msg.dict())
# for h in v.stack:
# msg = Message(content=h)
# bm.stack.append(msg.dict())
# s = bm.json()
# m = json.loads(s)
# bm = BrainMemory(**m)
# assert bm
# for v in bm.history:
# msg = Message(**v)
# assert msg
#
#
# if __name__ == "__main__":
# test_json()

View file

@ -28,27 +28,31 @@ class TestGPT:
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"])
logger.info(answer)
assert len(answer) > 0
except openai.NotFoundError:
assert CONFIG.openai_api_type == "azure"
except openai.BadRequestError:
assert CONFIG.OPENAI_API_TYPE == "azure"
@pytest.mark.asyncio
async def test_llm_api_aask(self, llm_api):
answer = await llm_api.aask("hello chatgpt")
answer = await llm_api.aask("hello chatgpt", stream=False)
logger.info(answer)
assert len(answer) > 0
answer = await llm_api.aask("hello chatgpt", stream=True)
logger.info(answer)
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_aask_code(self, llm_api):
try:
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"])
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"], timeout=60)
logger.info(answer)
assert len(answer) > 0
except openai.NotFoundError:
assert CONFIG.openai_api_type == "azure"
except openai.BadRequestError:
assert CONFIG.OPENAI_API_TYPE == "azure"
@pytest.mark.asyncio
async def test_llm_api_costs(self, llm_api):
await llm_api.aask("hello chatgpt")
await llm_api.aask("hello chatgpt", stream=False)
costs = llm_api.get_costs()
logger.info(costs)
assert costs.total_cost > 0