Merge pull request #1544 from EvensXia/fix_metagpt_from_evensxia

fix ollama_api to add llava support
Alexander Wu 2024-10-30 10:14:53 +08:00 committed by GitHub
commit 9db0874102
6 changed files with 345 additions and 130 deletions

View file

@@ -15,8 +15,8 @@ async def main():
# check if the configured llm supports the llm-vision capability. If not, it will throw an error
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
-res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
-assert "true" in res.lower()
+res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
+assert ("true" in res.lower()) or ("invoice" in res.lower())
if __name__ == "__main__":
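To exercise this example against a local llava model, a minimal config sketch (the base_url, port, and model name below are assumptions based on ollama's defaults, not part of this diff):

    from metagpt.configs.llm_config import LLMConfig, LLMType
    from metagpt.provider.ollama_api import OllamaLLM

    # Assumed setup: a local ollama server on its default port with `llava` pulled.
    config = LLMConfig(api_type=LLMType.OLLAMA, base_url="http://127.0.0.1:11434/api", model="llava")
    llm = OllamaLLM(config)
    # llm.aask(msg=..., images=[img_base64]) then builds the multimodal payload
    # via the OllamaMessageChat.apply() added in this PR.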

View file

@@ -26,7 +26,10 @@ class LLMType(Enum):
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE = "azure"
-OLLAMA = "ollama"
+OLLAMA = "ollama" # /chat at ollama api
+OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api
+OLLAMA_EMBEDDINGS = "ollama.embeddings" # /embeddings at ollama api
+OLLAMA_EMBED = "ollama.embed" # /embed at ollama api
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"
@@ -105,7 +108,8 @@ class LLMConfig(YamlModel):
root_config_path = CONFIG_ROOT / "config2.yaml"
if root_config_path.exists():
raise ValueError(
f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n"
f"the former will overwrite the latter. This may cause unexpected result.\n"
)
elif repo_config_path.exists():
raise ValueError(f"Please set your API key in {repo_config_path}")

View file

@@ -13,6 +13,7 @@ import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import (
+Any,
AsyncGenerator,
AsyncIterator,
Dict,
@@ -121,7 +122,7 @@ def logfmt(props):
class OpenAIResponse:
-def __init__(self, data, headers):
+def __init__(self, data: Union[bytes, Any], headers: dict):
self._headers = headers
self.data = data
@@ -320,49 +321,6 @@ class APIRequestor:
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
-@overload
-async def arequest(
-self,
-method,
-url,
-params,
-headers,
-files,
-stream: Literal[True],
-request_id: Optional[str] = ...,
-request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-pass
-@overload
-async def arequest(
-self,
-method,
-url,
-params=...,
-headers=...,
-files=...,
-*,
-stream: Literal[True],
-request_id: Optional[str] = ...,
-request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-pass
-@overload
-async def arequest(
-self,
-method,
-url,
-params=...,
-headers=...,
-files=...,
-stream: Literal[False] = ...,
-request_id: Optional[str] = ...,
-request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-) -> Tuple[OpenAIResponse, bool, str]:
-pass
-@overload
-async def arequest(
-self,
@@ -438,8 +396,8 @@ class APIRequestor:
"X-LLM-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
}
-headers.update(api_key_to_header(self.api_type, self.api_key))
+if self.api_key:
+headers.update(api_key_to_header(self.api_type, self.api_key))
if self.organization:
headers["LLM-Organization"] = self.organization

View file

@@ -3,25 +3,24 @@
# @Desc : General Async API for http-based LLM model
import asyncio
-from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
import aiohttp
import requests
from metagpt.logs import logger
-from metagpt.provider.general_api_base import APIRequestor
+from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
-def parse_stream_helper(line: bytes) -> Union[bytes, None]:
+def parse_stream_helper(line: bytes) -> Optional[bytes]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
-# SSE event may be valid when it contain whitespace
+# SSE event may be valid when it contains whitespace
line = line[len(b"data: ") :]
else:
line = line[len(b"data:") :]
if line.strip() == b"[DONE]":
-# return here will cause GeneratorExit exception in urllib3
-# and it will close http connection with TCP Reset
+# Returning None to indicate end of stream
return None
else:
return line
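For reference, the helper's behavior on representative SSE lines (hand-written samples, not captured server output):

    # `data: ` payloads are unwrapped; the [DONE] sentinel and blank
    # keep-alive lines both map to None (end of stream / nothing to emit).
    assert parse_stream_helper(b'data: {"done": false}') == b'{"done": false}'
    assert parse_stream_helper(b"data: [DONE]") is None
    assert parse_stream_helper(b"") is None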
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
class GeneralAPIRequestor(APIRequestor):
"""
-usage
+Usage example:
# full_url = "{base_url}{url}"
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
)
"""
-def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
-# just do nothing to meet the APIRequestor process and return the raw data
-# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
-return rbody
+def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
+"""
+Process and return the response data wrapped in OpenAIResponse.
+Args:
+rbody (bytes): The response body.
+rcode (int): The response status code.
+rheaders (dict): The response headers.
+stream (bool): Whether the response is a stream.
+Returns:
+OpenAIResponse: The response data wrapped in OpenAIResponse.
+"""
+return OpenAIResponse(rbody, rheaders)
def _interpret_response(
self, result: requests.Response, stream: bool
-) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
-"""Returns the response(s) and a bool indicating whether it is a stream."""
+) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
+"""
+Interpret a synchronous response.
+Args:
+result (requests.Response): The response object.
+stream (bool): Whether the response is a stream.
+Returns:
+Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+"""
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
return (
-self._interpret_response_line(line, result.status_code, result.headers, stream=True)
-for line in parse_stream(result.iter_lines())
-), True
+(
+self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+for line in parse_stream(result.iter_lines())
+),
+True,
+)
else:
return (
self._interpret_response_line(
-result.content, # let the caller to decode the msg
+result.content, # let the caller decode the msg
result.status_code,
result.headers,
stream=False,
@@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
-) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
+) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+"""
+Interpret an asynchronous response.
+Args:
+result (aiohttp.ClientResponse): The response object.
+stream (bool): Whether the response is a stream.
+Returns:
+Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+"""
content_type = result.headers.get("Content-Type", "")
if stream and (
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
-self._interpret_response_line(line, result.status, result.headers, stream=True)
-async for line in result.content
-), True
+(
+self._interpret_response_line(line, result.status, result.headers, stream=True)
+async for line in result.content
+),
+True,
+)
else:
try:
-await result.read()
+response_content = await result.read()
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise TimeoutError("Request timed out") from e
except aiohttp.ClientError as exp:
logger.warning(f"response: {result.content}, exp: {exp}")
logger.warning(f"response: {result}, exp: {exp}")
response_content = b""
return (
self._interpret_response_line(
-await result.read(), # let the caller to decode the msg
+response_content, # let the caller decode the msg
result.status,
result.headers,
stream=False,

View file

@@ -3,16 +3,189 @@
# @Desc : self-host open llm model with ollama which isn't openai-api-compatible
import json
+from enum import Enum, auto
+from typing import AsyncGenerator, Optional, Tuple
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
-from metagpt.provider.general_api_requestor import GeneralAPIRequestor
+from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import TokenCostManager
+class OllamaMessageAPI(Enum):
+# default
+CHAT = auto()
+GENERATE = auto()
+EMBED = auto()
+EMBEDDINGS = auto()
+class OllamaMessageBase:
+api_type = OllamaMessageAPI.CHAT
+def __init__(self, model: str, **additional_kwargs) -> None:
+self.model, self.additional_kwargs = model, additional_kwargs
+self._image_b64_rms = len("data:image/jpeg;base64,")
+@property
+def api_suffix(self) -> str:
+raise NotImplementedError
+def apply(self, messages: list[dict]) -> dict:
+raise NotImplementedError
+def decode(self, response: OpenAIResponse) -> dict:
+return json.loads(response.data.decode("utf-8"))
+def get_choice(self, to_choice_dict: dict) -> str:
+raise NotImplementedError
+def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
+if "type" in msg:
+tpe = msg["type"]
+if tpe == "text":
+return msg["text"], None
+elif tpe == "image_url":
+return None, msg["image_url"]["url"][self._image_b64_rms :]
+else:
+raise ValueError
+else:
+raise ValueError
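_parse_input_msg consumes the OpenAI-style content parts that BaseLLM builds for multimodal messages; `_image_b64_rms` is the length of the `data:image/jpeg;base64,` prefix, which gets sliced off because ollama wants the bare base64 string. Hand-written samples of the two accepted shapes:

    # {"type": "text", ...}       -> ("describe this image", None)
    # {"type": "image_url", ...}  -> (None, "iVBORw0KG...")  # data-URL prefix stripped
    text_part = {"type": "text", "text": "describe this image"}
    image_part = {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,iVBORw0KG..."}}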
+class OllamaMessageMeta(type):
+registed_message = {}
+def __init__(cls, name, bases, attrs):
+super().__init__(name, bases, attrs)
+for base in bases:
+if issubclass(base, OllamaMessageBase):
+api_type = attrs["api_type"]
+assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist"
+assert isinstance(api_type, OllamaMessageAPI), "api_type not support"
+OllamaMessageMeta.registed_message[api_type] = cls
+@classmethod
+def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]:
+return cls.registed_message[input_type]
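The metaclass turns endpoint selection into a dict lookup: defining a subclass with an `api_type` registers it, and each provider resolves its message builder at init time. For example:

    # Resolve the builder for /chat and produce a payload (sample values).
    builder_cls = OllamaMessageMeta.get_message(OllamaMessageAPI.CHAT)
    builder = builder_cls(model="llava", stream=False)
    builder.apply(messages=[{"role": "user", "content": "hello"}])
    # -> {"model": "llava", "messages": [{"role": "user", "content": "hello"}], "stream": False}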
+class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
+api_type = OllamaMessageAPI.CHAT
+@property
+def api_suffix(self) -> str:
+return "/chat"
+def apply(self, messages: list[dict]) -> dict:
+content = messages[0]["content"]
+prompts = []
+images = []
+if isinstance(content, list):
+for msg in content:
+prompt, image = self._parse_input_msg(msg)
+if prompt:
+prompts.append(prompt)
+if image:
+images.append(image)
+else:
+prompts.append(content)
+messes = []
+for prompt in prompts:
+if len(images) > 0:
+messes.append({"role": "user", "content": prompt, "images": images})
+else:
+messes.append({"role": "user", "content": prompt})
+sends = {"model": self.model, "messages": messes}
+sends.update(self.additional_kwargs)
+return sends
+def get_choice(self, to_choice_dict: dict) -> str:
+message = to_choice_dict["message"]
+if message["role"] == "assistant":
+return message["content"]
+else:
+raise ValueError
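For a text part plus an image part, apply() yields the /chat payload ollama expects, with the image list attached to the user message; a sample (base64 shortened, kwargs as passed by OllamaLLM below, reusing text_part/image_part from the sample above):

    OllamaMessageChat(model="llava", options={"temperature": 0.3}, stream=False).apply(
        messages=[{"role": "user", "content": [text_part, image_part]}]
    )
    # -> {"model": "llava",
    #     "messages": [{"role": "user", "content": "describe this image", "images": ["iVBORw0KG..."]}],
    #     "options": {"temperature": 0.3}, "stream": False}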
+class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
+api_type = OllamaMessageAPI.GENERATE
+@property
+def api_suffix(self) -> str:
+return "/generate"
+def apply(self, messages: list[dict]) -> dict:
+content = messages[0]["content"]
+prompts = []
+images = []
+if isinstance(content, list):
+for msg in content:
+prompt, image = self._parse_input_msg(msg)
+if prompt:
+prompts.append(prompt)
+if image:
+images.append(image)
+else:
+prompts.append(content)
+if len(images) > 0:
+sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
+else:
+sends = {"model": self.model, "prompt": "\n".join(prompts)}
+sends.update(self.additional_kwargs)
+return sends
+def get_choice(self, to_choice_dict: dict) -> str:
+return to_choice_dict["response"]
+class OllamaMessageEmbeddings(OllamaMessageBase, metaclass=OllamaMessageMeta):
+api_type = OllamaMessageAPI.EMBEDDINGS
+@property
+def api_suffix(self) -> str:
+return "/embeddings"
+def apply(self, messages: list[dict]) -> dict:
+content = messages[0]["content"]
+prompts = [] # NOTE: not support image to embedding
+if isinstance(content, list):
+for msg in content:
+prompt, _ = self._parse_input_msg(msg)
+if prompt:
+prompts.append(prompt)
+else:
+prompts.append(content)
+sends = {"model": self.model, "prompt": "\n".join(prompts)}
+sends.update(self.additional_kwargs)
+return sends
+class OllamaMessageEmbed(OllamaMessageEmbeddings, metaclass=OllamaMessageMeta):
+api_type = OllamaMessageAPI.EMBED
+@property
+def api_suffix(self) -> str:
+return "/embed"
+def apply(self, messages: list[dict]) -> dict:
+content = messages[0]["content"]
+prompts = [] # NOTE: not support image to embedding
+if isinstance(content, list):
+for msg in content:
+prompt, _ = self._parse_input_msg(msg)
+if prompt:
+prompts.append(prompt)
+else:
+prompts.append(content)
+sends = {"model": self.model, "input": prompts}
+sends.update(self.additional_kwargs)
+return sends
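The two embedding builders differ only in the payload key: the older /embeddings endpoint takes a single `prompt` string, while the newer /embed takes an `input` list. Same message, two shapes (the model name is an arbitrary example):

    msgs = [{"role": "user", "content": "hello world"}]
    OllamaMessageEmbeddings(model="nomic-embed-text").apply(msgs)
    # -> {"model": "nomic-embed-text", "prompt": "hello world"}
    OllamaMessageEmbed(model="nomic-embed-text").apply(msgs)
    # -> {"model": "nomic-embed-text", "input": ["hello world"]}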
@register_provider(LLMType.OLLAMA)
class OllamaLLM(BaseLLM):
"""
@@ -20,83 +193,80 @@ class OllamaLLM(BaseLLM):
"""
def __init__(self, config: LLMConfig):
-self.__init_ollama(config)
self.client = GeneralAPIRequestor(base_url=config.base_url)
self.config = config
-self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self.cost_manager = TokenCostManager()
+self.__init_ollama(config)
+@property
+def _llama_api_inuse(self) -> OllamaMessageAPI:
+return OllamaMessageAPI.CHAT
+@property
+def _llama_api_kwargs(self) -> dict:
+return {"options": {"temperature": 0.3}, "stream": self.config.stream}
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
self.pricing_plan = self.model
-def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
-return kwargs
-def get_choice_text(self, resp: dict) -> str:
-"""get the resp content from llm response"""
-assist_msg = resp.get("message", {})
-assert assist_msg.get("role", None) == "assistant"
-return assist_msg.get("content")
+ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
+self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
def get_usage(self, resp: dict) -> dict:
return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
-def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
-chunk = chunk.decode(encoding)
-return json.loads(chunk)
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
resp, _, _ = await self.client.arequest(
method=self.http_method,
-url=self.suffix_url,
headers=headers,
-params=self._const_kwargs(messages),
+url=self.ollama_message.api_suffix,
+params=self.ollama_message.apply(messages=messages),
request_timeout=self.get_timeout(timeout),
)
-resp = self._decode_and_load(resp)
-usage = self.get_usage(resp)
-self._update_costs(usage)
-return resp
+if isinstance(resp, AsyncGenerator):
+return await self._processing_openai_response_async_generator(resp)
+elif isinstance(resp, OpenAIResponse):
+return self._processing_openai_response(resp)
+else:
+raise ValueError
+def get_choice_text(self, rsp):
+return self.ollama_message.get_choice(rsp)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
-stream_resp, _, _ = await self.client.arequest(
+resp, _, _ = await self.client.arequest(
method=self.http_method,
-url=self.suffix_url,
headers=headers,
-stream=True,
-params=self._const_kwargs(messages, stream=True),
+url=self.ollama_message.api_suffix,
+params=self.ollama_message.apply(messages=messages),
request_timeout=self.get_timeout(timeout),
+stream=True,
)
+if isinstance(resp, AsyncGenerator):
+return await self._processing_openai_response_async_generator(resp)
+elif isinstance(resp, OpenAIResponse):
+return self._processing_openai_response(resp)
+else:
+raise ValueError
+def _processing_openai_response(self, openai_resp: OpenAIResponse):
+resp = self.ollama_message.decode(openai_resp)
+usage = self.get_usage(resp)
+self._update_costs(usage)
+return resp
+async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
collected_content = []
usage = {}
-async for raw_chunk in stream_resp:
-chunk = self._decode_and_load(raw_chunk)
+async for raw_chunk in ag_openai_resp:
+chunk = self.ollama_message.decode(raw_chunk)
if not chunk.get("done", False):
-content = self.get_choice_text(chunk)
+content = self.ollama_message.get_choice(chunk)
collected_content.append(content)
log_llm_stream(content)
else:
@@ -107,3 +277,55 @@ class OllamaLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
+@register_provider(LLMType.OLLAMA_GENERATE)
+class OllamaGenerate(OllamaLLM):
+@property
+def _llama_api_inuse(self) -> OllamaMessageAPI:
+return OllamaMessageAPI.GENERATE
+@property
+def _llama_api_kwargs(self) -> dict:
+return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+@register_provider(LLMType.OLLAMA_EMBEDDINGS)
+class OllamaEmbeddings(OllamaLLM):
+@property
+def _llama_api_inuse(self) -> OllamaMessageAPI:
+return OllamaMessageAPI.EMBEDDINGS
+@property
+def _llama_api_kwargs(self) -> dict:
+return {"options": {"temperature": 0.3}}
+@property
+def _llama_embedding_key(self) -> str:
+return "embedding"
+async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+resp, _, _ = await self.client.arequest(
+method=self.http_method,
+url=self.ollama_message.api_suffix,
+params=self.ollama_message.apply(messages=messages),
+request_timeout=self.get_timeout(timeout),
+)
+return self.ollama_message.decode(resp)[self._llama_embedding_key]
+async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
+def get_choice_text(self, rsp):
+return rsp
+@register_provider(LLMType.OLLAMA_EMBED)
+class OllamaEmbed(OllamaEmbeddings):
+@property
+def _llama_api_inuse(self) -> OllamaMessageAPI:
+return OllamaMessageAPI.EMBED
+@property
+def _llama_embedding_key(self) -> str:
+return "embeddings"

View file

@@ -3,11 +3,11 @@
# @Desc : the unittest of ollama api
import json
-from typing import Any, Tuple
+from typing import Any, AsyncGenerator, Tuple
import pytest
-from metagpt.provider.ollama_api import OllamaLLM
+from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
@@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
if stream:
-class Iterator(object):
+async def async_event_generator() -> AsyncGenerator[Any, None]:
events = [
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
]
+for event in events:
+yield OpenAIResponse(event, {})
-async def __aiter__(self):
-for event in self.events:
-yield event
-return Iterator(), None, None
+return async_event_generator(), None, None
else:
raw_default_resp = default_resp.copy()
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
-return json.dumps(raw_default_resp).encode(), None, None
+return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None
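Wiring the mock in presumably mirrors the project's other provider tests, patching arequest before driving OllamaLLM (a sketch, not part of this diff):

    @pytest.mark.asyncio
    async def test_ollama_acompletion_sketch(mocker):
        # Assumed patch target; mirrors the import path used by the provider.
        mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
        llm = OllamaLLM(mock_llm_config)
        resp = await llm.acompletion([{"role": "user", "content": "hello"}])
        assert llm.get_choice_text(resp) == resp_cont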
@pytest.mark.asyncio