This commit is contained in:
ziming 2023-09-05 22:27:21 +08:00
commit 761d60f26d
3 changed files with 25 additions and 22 deletions

View file

@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
self.openai_api_rpm = self._get("RPM", 3)
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4")
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
self.deployment_name = self._get('DEPLOYMENT_NAME')
self.deployment_id = self._get("DEPLOYMENT_ID")
self.claude_api_key = self._get("Anthropic_API_KEY")

View file

@ -162,10 +162,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
choices = chunk["choices"]
if len(choices) > 0:
chunk_message = chunk["choices"][0].get("delta", {}) # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
@ -174,25 +176,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return full_reply_content
def _cons_kwargs(self, messages: list[dict]) -> dict:
    """Build the keyword arguments for a completion request.

    The arguments common to every backend are assembled once; exactly one
    backend selector is then merged in:

    * Azure (``CONFIG.openai_api_type == "azure"``): ``engine`` when
      ``CONFIG.deployment_name`` is set, otherwise ``deployment_id``.
    * Any other API type: ``model`` taken from ``self.model``.

    Args:
        messages: The chat messages to send; also passed to
            ``self.get_max_tokens`` to compute ``max_tokens``.

    Returns:
        A dict of keyword arguments (presumably forwarded to the
        chat-completion call -- confirm against the caller).

    Raises:
        ValueError: On Azure, when both or neither of ``DEPLOYMENT_NAME``
            and ``DEPLOYMENT_ID`` are configured.
    """
    # Arguments shared by every backend, built once so the two branches
    # below cannot drift apart (e.g. one of them forgetting ``timeout``).
    kwargs = {
        "messages": messages,
        "max_tokens": self.get_max_tokens(messages),
        "n": 1,
        "stop": None,
        "temperature": 0.3,
        "timeout": 3,
    }

    if CONFIG.openai_api_type == "azure":
        # Azure needs exactly one deployment selector; validate before use
        # so a None selector is never sent alongside the real one.
        if CONFIG.deployment_name and CONFIG.deployment_id:
            raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
        if not CONFIG.deployment_name and not CONFIG.deployment_id:
            raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
        if CONFIG.deployment_name:
            kwargs_mode = {"engine": CONFIG.deployment_name}
        else:
            kwargs_mode = {"deployment_id": CONFIG.deployment_id}
    else:
        kwargs_mode = {"model": self.model}

    kwargs.update(kwargs_mode)
    return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict: