From bd119de2c1c324508ea634f954dbc4c014a08821 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 22 Dec 2023 09:51:26 +0800 Subject: [PATCH] format general_api_requestor params type --- metagpt/provider/general_api_requestor.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index 8b06b9388..cf31fd629 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -3,7 +3,7 @@ # @Desc : General Async API for http-based LLM model import asyncio -from typing import AsyncGenerator, Generator, Iterator, Optional, Tuple, Union +from typing import AsyncGenerator, Generator, Iterator, Tuple, Union import aiohttp import requests @@ -12,7 +12,7 @@ from metagpt.logs import logger from metagpt.provider.general_api_base import APIRequestor -def parse_stream_helper(line: bytes) -> Optional[str]: +def parse_stream_helper(line: bytes) -> Union[bytes, None]: if line and line.startswith(b"data:"): if line.startswith(b"data: "): # SSE event may be valid when it contain whitespace @@ -24,11 +24,11 @@ def parse_stream_helper(line: bytes) -> Optional[str]: # and it will close http connection with TCP Reset return None else: - return line.decode("utf-8") + return line return None -def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]: +def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]: for line in rbody: _line = parse_stream_helper(line) if _line is not None: @@ -50,7 +50,7 @@ class GeneralAPIRequestor(APIRequestor): ) """ - def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str: + def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes: # just do nothing to meet the APIRequestor process and return the raw data # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases. @@ -58,7 +58,7 @@ class GeneralAPIRequestor(APIRequestor): def _interpret_response( self, result: requests.Response, stream: bool - ) -> Tuple[Union[str, Iterator[Generator]], bool]: + ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: """Returns the response(s) and a bool indicating whether it is a stream.""" if stream and "text/event-stream" in result.headers.get("Content-Type", ""): return ( @@ -78,7 +78,7 @@ class GeneralAPIRequestor(APIRequestor): async def _interpret_async_response( self, result: aiohttp.ClientResponse, stream: bool - ) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]: + ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]: if stream and ( "text/event-stream" in result.headers.get("Content-Type", "") or "application/x-ndjson" in result.headers.get("Content-Type", "")