mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
update ollama
This commit is contained in:
parent
fdb834674d
commit
a5f33e2d51
4 changed files with 60 additions and 28 deletions
|
|
@ -15,8 +15,8 @@ async def main():
|
|||
# check if the configured llm supports llm-vision capacity. If not, it will throw a error
|
||||
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
|
||||
assert "true" in res.lower()
|
||||
res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
|
||||
assert ("true" in res.lower()) or ("invoice" in res.lower())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -106,8 +106,8 @@ class LLMConfig(YamlModel):
|
|||
root_config_path = CONFIG_ROOT / "config2.yaml"
|
||||
if root_config_path.exists():
|
||||
raise ValueError(
|
||||
f"Please set your API key in {root_config_path}. If you also set your config in {
|
||||
repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
|
||||
f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n"
|
||||
f"the former will overwrite the latter. This may cause unexpected result.\n"
|
||||
)
|
||||
elif repo_config_path.exists():
|
||||
raise ValueError(f"Please set your API key in {repo_config_path}")
|
||||
|
|
|
|||
|
|
@ -396,8 +396,8 @@ class APIRequestor:
|
|||
"X-LLM-Client-User-Agent": json.dumps(ua),
|
||||
"User-Agent": user_agent,
|
||||
}
|
||||
|
||||
headers.update(api_key_to_header(self.api_type, self.api_key))
|
||||
if self.api_key:
|
||||
headers.update(api_key_to_header(self.api_type, self.api_key))
|
||||
|
||||
if self.organization:
|
||||
headers["LLM-Organization"] = self.organization
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class OllamaMessageBase:
|
|||
|
||||
def __init__(self, model: str, **additional_kwargs) -> None:
|
||||
self.model, self.additional_kwargs = model, additional_kwargs
|
||||
self._image_b64_rms = len("data:image/jpeg;base64,")
|
||||
|
||||
@property
|
||||
def api_suffix(self) -> str:
|
||||
|
|
@ -38,15 +39,16 @@ class OllamaMessageBase:
|
|||
def decode(self, response: OpenAIResponse) -> dict:
|
||||
return json.loads(response.data.decode("utf-8"))
|
||||
|
||||
def get_choice(self, to_choice_dict: dict) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
if "role" in msg:
|
||||
return msg["content"], None
|
||||
elif "type" in msg:
|
||||
if "type" in msg:
|
||||
tpe = msg["type"]
|
||||
if tpe == "text":
|
||||
return msg["text"], None
|
||||
elif tpe == "image_url":
|
||||
return None, msg["image_url"]["url"]
|
||||
return None, msg["image_url"]["url"][self._image_b64_rms :]
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
|
|
@ -84,18 +86,35 @@ class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
|
|||
return "/chat"
|
||||
|
||||
def apply(self, messages: list[dict]) -> dict:
|
||||
content = messages[0]["content"]
|
||||
prompts = []
|
||||
images = []
|
||||
for msg in messages:
|
||||
prompt, image = self._parse_input_msg(msg)
|
||||
if prompt:
|
||||
prompts.append(prompt)
|
||||
if image:
|
||||
images.append(image)
|
||||
sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
prompt, image = self._parse_input_msg(msg)
|
||||
if prompt:
|
||||
prompts.append(prompt)
|
||||
if image:
|
||||
images.append(image)
|
||||
else:
|
||||
prompts.append(content)
|
||||
messes = []
|
||||
for prompt in prompts:
|
||||
if len(images) > 0:
|
||||
messes.append({"role": "user", "content": "\n".join(prompts), "images": images})
|
||||
else:
|
||||
messes.append({"role": "user", "content": "\n".join(prompts)})
|
||||
sends = {"model": self.model, "messages": messes}
|
||||
sends.update(self.additional_kwargs)
|
||||
return sends
|
||||
|
||||
def get_choice(self, to_choice_dict: dict) -> str:
|
||||
message = to_choice_dict["message"]
|
||||
if message["role"] == "assistant":
|
||||
return message["content"]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
|
||||
api_type = OllamaMessageAPI.GENERATE
|
||||
|
|
@ -104,6 +123,29 @@ class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
|
|||
def api_suffix(self) -> str:
|
||||
return "/generate"
|
||||
|
||||
def apply(self, messages: list[dict]) -> dict:
|
||||
content = messages[0]["content"]
|
||||
prompts = []
|
||||
images = []
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
prompt, image = self._parse_input_msg(msg)
|
||||
if prompt:
|
||||
prompts.append(prompt)
|
||||
if image:
|
||||
images.append(image)
|
||||
else:
|
||||
prompts.append(content)
|
||||
if len(images) > 0:
|
||||
sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
|
||||
else:
|
||||
sends = {"model": self.model, "prompt": "\n".join(prompts)}
|
||||
sends.update(self.additional_kwargs)
|
||||
return sends
|
||||
|
||||
def get_choice(self, to_choice_dict: dict) -> str:
|
||||
return to_choice_dict["response"]
|
||||
|
||||
|
||||
class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta):
|
||||
api_type = OllamaMessageAPI.EMBED
|
||||
|
|
@ -137,13 +179,6 @@ class OllamaLLM(BaseLLM):
|
|||
self.cost_manager = TokenCostManager()
|
||||
self.__init_ollama(config)
|
||||
|
||||
def _get_headers(self):
|
||||
return (
|
||||
None
|
||||
if not self.config.api_key or self.config.api_key == "sk-"
|
||||
else {"Authorization": f"Bearer {self.config.api_key}"}
|
||||
)
|
||||
|
||||
@property
|
||||
def _llama_api_inuse(self) -> OllamaMessageAPI:
|
||||
return OllamaMessageAPI.CHAT
|
||||
|
|
@ -158,7 +193,6 @@ class OllamaLLM(BaseLLM):
|
|||
self.pricing_plan = self.model
|
||||
ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
|
||||
self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
|
||||
self.headers = self._get_headers()
|
||||
|
||||
def get_usage(self, resp: dict) -> dict:
|
||||
return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
|
||||
|
|
@ -167,7 +201,6 @@ class OllamaLLM(BaseLLM):
|
|||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.ollama_message.api_suffix,
|
||||
headers=self.headers,
|
||||
params=self.ollama_message.apply(messages=messages),
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
)
|
||||
|
|
@ -185,7 +218,6 @@ class OllamaLLM(BaseLLM):
|
|||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.ollama_message.api_suffix,
|
||||
headers=self.headers,
|
||||
params=self.ollama_message.apply(messages=messages),
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
stream=True,
|
||||
|
|
@ -210,7 +242,7 @@ class OllamaLLM(BaseLLM):
|
|||
chunk = self.ollama_message.decode(raw_chunk)
|
||||
|
||||
if not chunk.get("done", False):
|
||||
content = self.get_choice_text(chunk)
|
||||
content = self.ollama_message.get_choice(chunk)
|
||||
collected_content.append(content)
|
||||
log_llm_stream(content)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue