Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-05-04 21:32:38 +02:00)
Merge pull request #1544 from EvensXia/fix_metagpt_from_evensxia
fix ollama_api to add llava support
commit 9db0874102
6 changed files with 345 additions and 130 deletions
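Taken together, the change splits LLMType.OLLAMA into per-endpoint flavors ("ollama", "ollama.generate", "ollama.embeddings", "ollama.embed" in config2.yaml), builds request bodies through per-API message classes, wraps raw bytes in OpenAIResponse, and lets chat/generate messages carry base64 images, which is what enables llava-style vision models. A minimal sketch of exercising the vision path (model name, base_url and api_key values are illustrative, not taken from this diff):

import asyncio

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.ollama_api import OllamaLLM

# Illustrative values: a local ollama server hosting a llava model.
config = LLMConfig(api_type=LLMType.OLLAMA, base_url="http://127.0.0.1:11434/api", model="llava", api_key="sk-")
llm = OllamaLLM(config)

img_base64 = "<base64-encoded image>"  # e.g. the output of encode_image() in the example hunk below
answer = asyncio.run(llm.aask("return `True` if this image might be a invoice, or return `False`", images=[img_base64]))
print(answer)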
@@ -15,8 +15,8 @@ async def main():
     # check if the configured llm supports llm-vision capacity. If not, it will throw a error
     invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
     img_base64 = encode_image(invoice_path)
-    res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
-    assert "true" in res.lower()
+    res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
+    assert ("true" in res.lower()) or ("invoice" in res.lower())


 if __name__ == "__main__":
@@ -26,7 +26,10 @@ class LLMType(Enum):
     GEMINI = "gemini"
     METAGPT = "metagpt"
     AZURE = "azure"
-    OLLAMA = "ollama"
+    OLLAMA = "ollama"  # /chat at ollama api
+    OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
+    OLLAMA_EMBEDDINGS = "ollama.embeddings"  # /embeddings at ollama api
+    OLLAMA_EMBED = "ollama.embed"  # /embed at ollama api
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
     MOONSHOT = "moonshot"
@@ -105,7 +108,8 @@ class LLMConfig(YamlModel):
             root_config_path = CONFIG_ROOT / "config2.yaml"
             if root_config_path.exists():
                 raise ValueError(
-                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
+                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n"
+                    f"the former will overwrite the latter. This may cause unexpected result.\n"
                 )
             elif repo_config_path.exists():
                 raise ValueError(f"Please set your API key in {repo_config_path}")
@@ -13,6 +13,7 @@ import time
 from contextlib import asynccontextmanager
 from enum import Enum
 from typing import (
+    Any,
     AsyncGenerator,
     AsyncIterator,
     Dict,
@@ -121,7 +122,7 @@ def logfmt(props):


 class OpenAIResponse:
-    def __init__(self, data, headers):
+    def __init__(self, data: Union[bytes, Any], headers: dict):
         self._headers = headers
         self.data = data

@@ -320,49 +321,6 @@ class APIRequestor:
         resp, got_stream = self._interpret_response(result, stream)
         return resp, got_stream, self.api_key

-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params,
-        headers,
-        files,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        *,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        stream: Literal[False] = ...,
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[OpenAIResponse, bool, str]:
-        pass
-
     @overload
     async def arequest(
         self,
@@ -438,8 +396,9 @@ class APIRequestor:
             "X-LLM-Client-User-Agent": json.dumps(ua),
             "User-Agent": user_agent,
         }

-        headers.update(api_key_to_header(self.api_type, self.api_key))
+        if self.api_key:
+            headers.update(api_key_to_header(self.api_type, self.api_key))

         if self.organization:
             headers["LLM-Organization"] = self.organization
@@ -3,25 +3,24 @@
 # @Desc : General Async API for http-based LLM model

 import asyncio
-from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+from typing import AsyncGenerator, Iterator, Optional, Tuple, Union

 import aiohttp
 import requests

 from metagpt.logs import logger
-from metagpt.provider.general_api_base import APIRequestor
+from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse


-def parse_stream_helper(line: bytes) -> Union[bytes, None]:
+def parse_stream_helper(line: bytes) -> Optional[bytes]:
     if line and line.startswith(b"data:"):
         if line.startswith(b"data: "):
-            # SSE event may be valid when it contain whitespace
+            # SSE event may be valid when it contains whitespace
             line = line[len(b"data: ") :]
         else:
             line = line[len(b"data:") :]
         if line.strip() == b"[DONE]":
-            # return here will cause GeneratorExit exception in urllib3
-            # and it will close http connection with TCP Reset
+            # Returning None to indicate end of stream
             return None
         else:
            return line
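As a quick sanity check of the reworked SSE handling above, a hedged sketch of what parse_stream_helper returns for the two cases shown in this hunk:

# "data:"-prefixed lines are returned with the prefix stripped;
# the [DONE] sentinel maps to None, signalling end of stream.
assert parse_stream_helper(b'data: {"done": false}') == b'{"done": false}'
assert parse_stream_helper(b"data: [DONE]") is None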
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:

 class GeneralAPIRequestor(APIRequestor):
     """
-    usage
+    Usage example:
         # full_url = "{base_url}{url}"
         requester = GeneralAPIRequestor(base_url=base_url)
         result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
         )
     """

-    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
-        # just do nothing to meet the APIRequestor process and return the raw data
-        # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
+    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
+        """
+        Process and return the response data wrapped in OpenAIResponse.

-        return rbody
+        Args:
+            rbody (bytes): The response body.
+            rcode (int): The response status code.
+            rheaders (dict): The response headers.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            OpenAIResponse: The response data wrapped in OpenAIResponse.
+        """
+        return OpenAIResponse(rbody, rheaders)

     def _interpret_response(
         self, result: requests.Response, stream: bool
-    ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
-        """Returns the response(s) and a bool indicating whether it is a stream."""
+    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
+        """
+        Interpret a synchronous response.
+
+        Args:
+            result (requests.Response): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
             return (
-                self._interpret_response_line(line, result.status_code, result.headers, stream=True)
-                for line in parse_stream(result.iter_lines())
-            ), True
+                (
+                    self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+                    for line in parse_stream(result.iter_lines())
+                ),
+                True,
+            )
         else:
             return (
                 self._interpret_response_line(
-                    result.content,  # let the caller to decode the msg
+                    result.content,  # let the caller decode the msg
                     result.status_code,
                     result.headers,
                     stream=False,
@@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):

     async def _interpret_async_response(
         self, result: aiohttp.ClientResponse, stream: bool
-    ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
+    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+        """
+        Interpret an asynchronous response.
+
+        Args:
+            result (aiohttp.ClientResponse): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and (
             "text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
         ):
             # the `Content-Type` of ollama stream resp is "application/x-ndjson"
             return (
-                self._interpret_response_line(line, result.status, result.headers, stream=True)
-                async for line in result.content
-            ), True
+                (
+                    self._interpret_response_line(line, result.status, result.headers, stream=True)
+                    async for line in result.content
+                ),
+                True,
+            )
         else:
             try:
-                await result.read()
+                response_content = await result.read()
             except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
                 raise TimeoutError("Request timed out") from e
             except aiohttp.ClientError as exp:
-                logger.warning(f"response: {result.content}, exp: {exp}")
+                logger.warning(f"response: {result}, exp: {exp}")
+                response_content = b""
             return (
                 self._interpret_response_line(
-                    await result.read(),  # let the caller to decode the msg
+                    response_content,  # let the caller decode the msg
                     result.status,
                     result.headers,
                     stream=False,
@@ -3,16 +3,189 @@
 # @Desc : self-host open llm model with ollama which isn't openai-api-compatible

 import json
+from enum import Enum, auto
+from typing import AsyncGenerator, Optional, Tuple

 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
-from metagpt.provider.general_api_requestor import GeneralAPIRequestor
+from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.utils.cost_manager import TokenCostManager


+class OllamaMessageAPI(Enum):
+    # default
+    CHAT = auto()
+    GENERATE = auto()
+    EMBED = auto()
+    EMBEDDINGS = auto()
+
+
+class OllamaMessageBase:
+    api_type = OllamaMessageAPI.CHAT
+
+    def __init__(self, model: str, **additional_kwargs) -> None:
+        self.model, self.additional_kwargs = model, additional_kwargs
+        self._image_b64_rms = len("data:image/jpeg;base64,")
+
+    @property
+    def api_suffix(self) -> str:
+        raise NotImplementedError
+
+    def apply(self, messages: list[dict]) -> dict:
+        raise NotImplementedError
+
+    def decode(self, response: OpenAIResponse) -> dict:
+        return json.loads(response.data.decode("utf-8"))
+
+    def get_choice(self, to_choice_dict: dict) -> str:
+        raise NotImplementedError
+
+    def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
+        if "type" in msg:
+            tpe = msg["type"]
+            if tpe == "text":
+                return msg["text"], None
+            elif tpe == "image_url":
+                return None, msg["image_url"]["url"][self._image_b64_rms :]
+            else:
+                raise ValueError
+        else:
+            raise ValueError
+
+
+class OllamaMessageMeta(type):
+    registed_message = {}
+
+    def __init__(cls, name, bases, attrs):
+        super().__init__(name, bases, attrs)
+        for base in bases:
+            if issubclass(base, OllamaMessageBase):
+                api_type = attrs["api_type"]
+                assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist"
+                assert isinstance(api_type, OllamaMessageAPI), "api_type not support"
+                OllamaMessageMeta.registed_message[api_type] = cls
+
+    @classmethod
+    def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]:
+        return cls.registed_message[input_type]
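A short sketch of how the metaclass registry above is meant to be used (the model name is illustrative; OllamaMessageGenerate is defined further down in this diff): subclasses register themselves by api_type, and a provider later looks up its builder with get_message.

builder_cls = OllamaMessageMeta.get_message(OllamaMessageAPI.GENERATE)  # -> OllamaMessageGenerate
builder = builder_cls(model="llava", stream=False)  # extra kwargs are merged into every request body
assert builder.api_suffix == "/generate"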
+
+
+class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.CHAT
+
+    @property
+    def api_suffix(self) -> str:
+        return "/chat"
+
+    def apply(self, messages: list[dict]) -> dict:
+        content = messages[0]["content"]
+        prompts = []
+        images = []
+        if isinstance(content, list):
+            for msg in content:
+                prompt, image = self._parse_input_msg(msg)
+                if prompt:
+                    prompts.append(prompt)
+                if image:
+                    images.append(image)
+        else:
+            prompts.append(content)
+        messes = []
+        for prompt in prompts:
+            if len(images) > 0:
+                messes.append({"role": "user", "content": prompt, "images": images})
+            else:
+                messes.append({"role": "user", "content": prompt})
+        sends = {"model": self.model, "messages": messes}
+        sends.update(self.additional_kwargs)
+        return sends
+
+    def get_choice(self, to_choice_dict: dict) -> str:
+        message = to_choice_dict["message"]
+        if message["role"] == "assistant":
+            return message["content"]
+        else:
+            raise ValueError
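To make the llava path concrete, here is a hedged sketch of the /chat payload that OllamaMessageChat.apply builds from a MetaGPT multimodal message (model name and base64 placeholder are illustrative):

chat = OllamaMessageChat(model="llava", stream=False)
payload = chat.apply(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "is this an invoice?"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<IMG_B64>"}},
            ],
        }
    ]
)
# payload == {
#     "model": "llava",
#     "messages": [{"role": "user", "content": "is this an invoice?", "images": ["<IMG_B64>"]}],
#     "stream": False,
# }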
+
+
+class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.GENERATE
+
+    @property
+    def api_suffix(self) -> str:
+        return "/generate"
+
+    def apply(self, messages: list[dict]) -> dict:
+        content = messages[0]["content"]
+        prompts = []
+        images = []
+        if isinstance(content, list):
+            for msg in content:
+                prompt, image = self._parse_input_msg(msg)
+                if prompt:
+                    prompts.append(prompt)
+                if image:
+                    images.append(image)
+        else:
+            prompts.append(content)
+        if len(images) > 0:
+            sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
+        else:
+            sends = {"model": self.model, "prompt": "\n".join(prompts)}
+        sends.update(self.additional_kwargs)
+        return sends
+
+    def get_choice(self, to_choice_dict: dict) -> str:
+        return to_choice_dict["response"]
+
+
+class OllamaMessageEmbeddings(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.EMBEDDINGS
+
+    @property
+    def api_suffix(self) -> str:
+        return "/embeddings"
+
+    def apply(self, messages: list[dict]) -> dict:
+        content = messages[0]["content"]
+        prompts = []  # NOTE: not support image to embedding
+        if isinstance(content, list):
+            for msg in content:
+                prompt, _ = self._parse_input_msg(msg)
+                if prompt:
+                    prompts.append(prompt)
+        else:
+            prompts.append(content)
+        sends = {"model": self.model, "prompt": "\n".join(prompts)}
+        sends.update(self.additional_kwargs)
+        return sends
+
+
+class OllamaMessageEmbed(OllamaMessageEmbeddings, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.EMBED
+
+    @property
+    def api_suffix(self) -> str:
+        return "/embed"
+
+    def apply(self, messages: list[dict]) -> dict:
+        content = messages[0]["content"]
+        prompts = []  # NOTE: not support image to embedding
+        if isinstance(content, list):
+            for msg in content:
+                prompt, _ = self._parse_input_msg(msg)
+                if prompt:
+                    prompts.append(prompt)
+        else:
+            prompts.append(content)
+        sends = {"model": self.model, "input": prompts}
+        sends.update(self.additional_kwargs)
+        return sends
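The two embedding builders differ only in the endpoint and the field carrying the text; a small sketch with an illustrative model name:

msg = [{"role": "user", "content": "hello"}]
OllamaMessageEmbeddings(model="nomic-embed-text").apply(messages=msg)
# -> {"model": "nomic-embed-text", "prompt": "hello"}   # POSTed to /embeddings
OllamaMessageEmbed(model="nomic-embed-text").apply(messages=msg)
# -> {"model": "nomic-embed-text", "input": ["hello"]}  # POSTed to /embed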
+
+
 @register_provider(LLMType.OLLAMA)
 class OllamaLLM(BaseLLM):
     """
@@ -20,83 +193,80 @@ class OllamaLLM(BaseLLM):
     """

     def __init__(self, config: LLMConfig):
-        self.__init_ollama(config)
         self.client = GeneralAPIRequestor(base_url=config.base_url)
         self.config = config
-        self.suffix_url = "/chat"
         self.http_method = "post"
         self.use_system_prompt = False
         self.cost_manager = TokenCostManager()
+        self.__init_ollama(config)

+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.CHAT
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
     def __init_ollama(self, config: LLMConfig):
         assert config.base_url, "ollama base url is required!"
         self.model = config.model
         self.pricing_plan = self.model
-
-    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-        kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
-        return kwargs
-
-    def get_choice_text(self, resp: dict) -> str:
-        """get the resp content from llm response"""
-        assist_msg = resp.get("message", {})
-        assert assist_msg.get("role", None) == "assistant"
-        return assist_msg.get("content")
+        ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
+        self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)

     def get_usage(self, resp: dict) -> dict:
         return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}

-    def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
-        chunk = chunk.decode(encoding)
-        return json.loads(chunk)
-
     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
-        headers = (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {
-                "Authorization": f"Bearer {self.config.api_key}",
-            }
-        )
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=self.suffix_url,
-            headers=headers,
-            params=self._const_kwargs(messages),
+            url=self.ollama_message.api_suffix,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
         )
-        resp = self._decode_and_load(resp)
-        usage = self.get_usage(resp)
-        self._update_costs(usage)
-        return resp
+        if isinstance(resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(resp)
+        elif isinstance(resp, OpenAIResponse):
+            return self._processing_openai_response(resp)
+        else:
+            raise ValueError

+    def get_choice_text(self, rsp):
+        return self.ollama_message.get_choice(rsp)
+
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        headers = (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {
-                "Authorization": f"Bearer {self.config.api_key}",
-            }
-        )
-        stream_resp, _, _ = await self.client.arequest(
+        resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=self.suffix_url,
-            headers=headers,
-            stream=True,
-            params=self._const_kwargs(messages, stream=True),
+            url=self.ollama_message.api_suffix,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
+            stream=True,
         )
+        if isinstance(resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(resp)
+        elif isinstance(resp, OpenAIResponse):
+            return self._processing_openai_response(resp)
+        else:
+            raise ValueError
+
+    def _processing_openai_response(self, openai_resp: OpenAIResponse):
+        resp = self.ollama_message.decode(openai_resp)
+        usage = self.get_usage(resp)
+        self._update_costs(usage)
+        return resp
+
+    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
         collected_content = []
         usage = {}
-        async for raw_chunk in stream_resp:
-            chunk = self._decode_and_load(raw_chunk)
+        async for raw_chunk in ag_openai_resp:
+            chunk = self.ollama_message.decode(raw_chunk)

             if not chunk.get("done", False):
-                content = self.get_choice_text(chunk)
+                content = self.ollama_message.get_choice(chunk)
                 collected_content.append(content)
                 log_llm_stream(content)
             else:
@@ -107,3 +277,55 @@ class OllamaLLM(BaseLLM):
                 self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
+
+
+@register_provider(LLMType.OLLAMA_GENERATE)
+class OllamaGenerate(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.GENERATE
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
+
+@register_provider(LLMType.OLLAMA_EMBEDDINGS)
+class OllamaEmbeddings(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.EMBEDDINGS
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}}
+
+    @property
+    def _llama_embedding_key(self) -> str:
+        return "embedding"
+
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+        resp, _, _ = await self.client.arequest(
+            method=self.http_method,
+            url=self.ollama_message.api_suffix,
+            params=self.ollama_message.apply(messages=messages),
+            request_timeout=self.get_timeout(timeout),
+        )
+        return self.ollama_message.decode(resp)[self._llama_embedding_key]
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
+
+    def get_choice_text(self, rsp):
+        return rsp
+
+
+@register_provider(LLMType.OLLAMA_EMBED)
+class OllamaEmbed(OllamaEmbeddings):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.EMBED
+
+    @property
+    def _llama_embedding_key(self) -> str:
+        return "embeddings"
@@ -3,11 +3,11 @@
 # @Desc : the unittest of ollama api

 import json
-from typing import Any, Tuple
+from typing import Any, AsyncGenerator, Tuple

 import pytest

-from metagpt.provider.ollama_api import OllamaLLM
+from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse
 from tests.metagpt.provider.mock_llm_config import mock_llm_config
 from tests.metagpt.provider.req_resp_const import (
     llm_general_chat_funcs_test,
@@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}}
 async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
     if stream:

-        class Iterator(object):
+        async def async_event_generator() -> AsyncGenerator[Any, None]:
             events = [
                 b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
                 b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
             ]
+            for event in events:
+                yield OpenAIResponse(event, {})

-            async def __aiter__(self):
-                for event in self.events:
-                    yield event
-
-        return Iterator(), None, None
+        return async_event_generator(), None, None
     else:
         raw_default_resp = default_resp.copy()
         raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
-        return json.dumps(raw_default_resp).encode(), None, None
+        return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None


 @pytest.mark.asyncio