fix ollama to add llava support

EvensXia 2024-10-28 17:55:50 +08:00
parent 8b209d4e17
commit 3fefc48efa
3 changed files with 121 additions and 101 deletions

metagpt/provider/general_api_base.py

@@ -320,49 +320,6 @@ class APIRequestor:
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
@overload
async def arequest(
self,
method,
url,
params,
headers,
files,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass
@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
*,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass
@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: Literal[False] = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
pass
@overload
async def arequest(
self,

metagpt/provider/general_api_requestor.py

@@ -3,25 +3,24 @@
# @Desc : General Async API for http-based LLM model
import asyncio
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
import aiohttp
import requests
from metagpt.logs import logger
from metagpt.provider.general_api_base import APIRequestor
from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
def parse_stream_helper(line: bytes) -> Union[bytes, None]:
def parse_stream_helper(line: bytes) -> Optional[bytes]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
# SSE event may be valid when it contain whitespace
# SSE event may be valid when it contains whitespace
line = line[len(b"data: ") :]
else:
line = line[len(b"data:") :]
if line.strip() == b"[DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
# Returning None to indicate end of stream
return None
else:
return line
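The helper's job is unchanged by this diff: strip the SSE "data: " / "data:" prefix and map the [DONE] sentinel to None. A self-contained re-implementation for illustration, not an import from MetaGPT:

from typing import Optional

def _parse_sse_line(line: bytes) -> Optional[bytes]:
    # Mirrors parse_stream_helper: strip the "data: " / "data:" prefix,
    # drop the [DONE] sentinel, and ignore any other line.
    if line and line.startswith(b"data:"):
        payload = line[len(b"data: "):] if line.startswith(b"data: ") else line[len(b"data:"):]
        return None if payload.strip() == b"[DONE]" else payload
    return None

assert _parse_sse_line(b'data: {"x": 1}') == b'{"x": 1}'
assert _parse_sse_line(b"data: [DONE]") is None
assert _parse_sse_line(b": keep-alive") is None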
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
class GeneralAPIRequestor(APIRequestor):
"""
usage
Usage example:
# full_url = "{base_url}{url}"
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
)
"""
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
# just do nothing to meet the APIRequestor process and return the raw data
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
"""
Process and return the response data wrapped in OpenAIResponse.
return rbody
Args:
rbody (bytes): The response body.
rcode (int): The response status code.
rheaders (dict): The response headers.
stream (bool): Whether the response is a stream.
Returns:
OpenAIResponse: The response data wrapped in OpenAIResponse.
"""
return OpenAIResponse(rbody, rheaders)
def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
"""
Interpret a synchronous response.
Args:
result (requests.Response): The response object.
stream (bool): Whether the response is a stream.
Returns:
Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
"""
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
return (
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
for line in parse_stream(result.iter_lines())
), True
(
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
for line in parse_stream(result.iter_lines())
),
True,
)
else:
return (
self._interpret_response_line(
result.content, # let the caller to decode the msg
result.content, # let the caller decode the msg
result.status_code,
result.headers,
stream=False,
@@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
"""
Interpret an asynchronous response.
Args:
result (aiohttp.ClientResponse): The response object.
stream (bool): Whether the response is a stream.
Returns:
Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
"""
content_type = result.headers.get("Content-Type", "")
if stream and (
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)
async for line in result.content
), True
(
self._interpret_response_line(line, result.status, result.headers, stream=True)
async for line in result.content
),
True,
)
else:
try:
await result.read()
response_content = await result.read()
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise TimeoutError("Request timed out") from e
except aiohttp.ClientError as exp:
logger.warning(f"response: {result.content}, exp: {exp}")
logger.warning(f"response: {result}, exp: {exp}")
response_content = b""
return (
self._interpret_response_line(
await result.read(), # let the caller to decode the msg
response_content, # let the caller decode the msg
result.status,
result.headers,
stream=False,

metagpt/provider/ollama_api.py

@@ -3,12 +3,13 @@
# @Desc : self-host open llm model with ollama which isn't openai-api-compatible
import json
from typing import AsyncGenerator, Tuple
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import TokenCostManager
@@ -34,65 +35,68 @@ class OllamaLLM(BaseLLM):
self.pricing_plan = self.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})
assert assist_msg.get("role", None) == "assistant"
return assist_msg.get("content")
if assist_msg.get("role", None) == "assistant": # chat
return assist_msg.get("content")
else: # llava
return resp["response"]
def get_usage(self, resp: dict) -> dict:
return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
chunk = chunk.decode(encoding)
return json.loads(chunk)
def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict:
return json.loads(openai_resp.data.decode(encoding))
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
messages, fixed_suffix_url = self._apply_llava(messages)
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
params=self._const_kwargs(messages),
url=fixed_suffix_url,
headers=self._get_headers(),
params=messages,
request_timeout=self.get_timeout(timeout),
)
resp = self._decode_and_load(resp)
usage = self.get_usage(resp)
self._update_costs(usage)
return resp
if isinstance(resp, AsyncGenerator):
return await self._processing_openai_response_async_generator(resp)
elif isinstance(resp, OpenAIResponse):
return self._processing_openai_response(resp)
else:
raise NotImplementedError
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
messages, fixed_suffix_url = self._apply_llava(messages, stream=True)
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
url=fixed_suffix_url,
headers=self._get_headers(),
stream=True,
params=self._const_kwargs(messages, stream=True),
params=messages,
request_timeout=self.get_timeout(timeout),
)
if isinstance(stream_resp, AsyncGenerator):
return await self._processing_openai_response_async_generator(stream_resp)
elif isinstance(stream_resp, OpenAIResponse):
return self._processing_openai_response(stream_resp)
else:
raise NotImplementedError
def _processing_openai_response(self, openai_resp: OpenAIResponse):
resp = self._decode_and_load(openai_resp)
usage = self.get_usage(resp)
self._update_costs(usage)
return resp
async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
collected_content = []
usage = {}
async for raw_chunk in stream_resp:
async for raw_chunk in ag_openai_resp:
chunk = self._decode_and_load(raw_chunk)
if not chunk.get("done", False):
@@ -107,3 +111,29 @@ class OllamaLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
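The async-generator branch collects message content while done is false and reads the token counts from the final chunk, which then feed get_usage and _update_costs. A self-contained sketch with made-up chunks:

chunks = [
    {"message": {"role": "assistant", "content": "Hel"}, "done": False},
    {"message": {"role": "assistant", "content": "lo"}, "done": False},
    {"message": {"role": "assistant", "content": ""}, "done": True,
     "prompt_eval_count": 12, "eval_count": 2},
]

collected, usage = [], {}
for chunk in chunks:
    if not chunk.get("done", False):
        collected.append(chunk["message"]["content"])
    else:  # the final chunk carries the counters get_usage maps to token usage
        usage = {"prompt_tokens": chunk.get("prompt_eval_count", 0),
                 "completion_tokens": chunk.get("eval_count", 0)}

assert "".join(collected) == "Hello"
assert usage == {"prompt_tokens": 12, "completion_tokens": 2}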
def _get_headers(self):
return (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {"Authorization": f"Bearer {self.config.api_key}"}
)
def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]:
llava = False
if isinstance(messages[0]["content"], str):
return self._const_kwargs(messages, stream), self.suffix_url
if any(len(msg["content"]) >= 2 for msg in messages):
assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type"
llava = True
if not llava:
return self._const_kwargs(messages, stream), self.suffix_url
assert len(messages) <= 1, "llava image calls do not support batch messages"
contents = messages[0]["content"]
return {
"model": self.model,
"prompt": contents[0]["text"],
"images": [i["image_url"]["url"][23:] for i in contents[1:]],
}, "/generate"