fix feature

This commit is contained in:
rainyrfeng 2023-08-31 11:05:59 +08:00
parent 65500204b8
commit 949a5074ce
3 changed files with 10 additions and 5 deletions

View file

@ -16,11 +16,12 @@ RPM: 10
#Anthropic_API_KEY: "YOUR_API_KEY"
#### If you are using Azure OpenAI, see https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb
#### You can use either ENGINE mode (set OPENAI_API_ENGINE) or DEPLOYMENT mode (set DEPLOYMENT_ID)
#OPENAI_API_TYPE: "azure"
#OPENAI_API_BASE: "YOUR_AZURE_ENDPOINT"
#OPENAI_API_KEY: "YOUR_AZURE_API_KEY"
#OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION"
#OPENAI_API_ENGINE: "YOUR_OPENAI_API_ENGINE"
#DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID"
#### for Search

View file

@ -56,6 +56,7 @@ class Config(metaclass=Singleton):
openai.api_base = self.openai_api_base
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")
self.openai_api_engine = self._get('OPENAI_API_ENGINE')
self.openai_api_rpm = self._get("RPM", 3)
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4")
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)

View file

@ -156,10 +156,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
choices = chunk["choices"]
if len(choices) > 0:
chunk_message = chunk["choices"][0].get("delta", {}) # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
@ -170,6 +172,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def _cons_kwargs(self, messages: list[dict]) -> dict:
if CONFIG.openai_api_type == "azure":
kwargs = {
"engine": CONFIG.openai_api_engine,
"deployment_id": CONFIG.deployment_id,
"messages": messages,
"max_tokens": self.get_max_tokens(messages),