diff --git a/config/config.yaml b/config/config.yaml index 274cdf469..428f8dae4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -16,11 +16,12 @@ RPM: 10 #Anthropic_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb - +#### For Azure you can identify your deployment via either OPENAI_API_ENGINE or DEPLOYMENT_ID #OPENAI_API_TYPE: "azure" #OPENAI_API_BASE: "YOUR_AZURE_ENDPOINT" #OPENAI_API_KEY: "YOUR_AZURE_API_KEY" #OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION" +#OPENAI_API_ENGINE: "YOUR_OPENAI_API_ENGINE" #DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID" #### for Search diff --git a/metagpt/config.py b/metagpt/config.py index 2c1096877..16c4117cd 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -56,6 +56,7 @@ class Config(metaclass=Singleton): openai.api_base = self.openai_api_base self.openai_api_type = self._get("OPENAI_API_TYPE") self.openai_api_version = self._get("OPENAI_API_VERSION") + self.openai_api_engine = self._get('OPENAI_API_ENGINE') self.openai_api_rpm = self._get("RPM", 3) self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4") self.max_tokens_rsp = self._get("MAX_TOKENS", 2048) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 79121c8de..7ef694a98 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -156,10 +156,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # iterate through the stream of events async for chunk in response: collected_chunks.append(chunk) # save the event response - chunk_message = chunk["choices"][0]["delta"] # extract the message - collected_messages.append(chunk_message) # save the message - if "content" in chunk_message: - print(chunk_message["content"], end="") + choices = chunk["choices"] + if len(choices) > 0: + chunk_message = chunk["choices"][0].get("delta", {}) # extract the message + collected_messages.append(chunk_message) # save the message + if "content" in chunk_message: + print(chunk_message["content"], end="") 
print() full_reply_content = "".join([m.get("content", "") for m in collected_messages]) @@ -170,6 +172,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def _cons_kwargs(self, messages: list[dict]) -> dict: if CONFIG.openai_api_type == "azure": kwargs = { + "engine": CONFIG.openai_api_engine, "deployment_id": CONFIG.deployment_id, "messages": messages, "max_tokens": self.get_max_tokens(messages),