format general_api_requestor params type

This commit is contained in:
better629 2023-12-22 09:51:26 +08:00 committed by geekan
parent 2502dd3651
commit bd119de2c1

View file

@ -3,7 +3,7 @@
# @Desc : General Async API for http-based LLM model
import asyncio
from typing import AsyncGenerator, Generator, Iterator, Optional, Tuple, Union
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
import aiohttp
import requests
@ -12,7 +12,7 @@ from metagpt.logs import logger
from metagpt.provider.general_api_base import APIRequestor
def parse_stream_helper(line: bytes) -> Optional[str]:
def parse_stream_helper(line: bytes) -> Union[bytes, None]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
# SSE event may be valid when it contain whitespace
@ -24,11 +24,11 @@ def parse_stream_helper(line: bytes) -> Optional[str]:
# and it will close http connection with TCP Reset
return None
else:
return line.decode("utf-8")
return line
return None
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
for line in rbody:
_line = parse_stream_helper(line)
if _line is not None:
@ -50,7 +50,7 @@ class GeneralAPIRequestor(APIRequestor):
)
"""
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str:
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
# just do nothing to meet the APIRequestor process and return the raw data
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
@ -58,7 +58,7 @@ class GeneralAPIRequestor(APIRequestor):
def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[str, Iterator[Generator]], bool]:
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
return (
@ -78,7 +78,7 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")