tested for embeddings/embed

This commit is contained in:
EvensXia 2024-10-30 09:47:10 +08:00
parent 187e512547
commit f2aa4e3f9d
3 changed files with 49 additions and 14 deletions

View file

@ -14,9 +14,10 @@ async def main():
# check if the configured llm supports the llm-vision capability. If not, it will throw an error
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
assert ("true" in res.lower()) or ("invoice" in res.lower())
encode_image(invoice_path)
# res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
await llm.aask(msg="hello")
# assert ("true" in res.lower()) or ("invoice" in res.lower())
if __name__ == "__main__":

View file

@ -28,7 +28,8 @@ class LLMType(Enum):
AZURE = "azure"
OLLAMA = "ollama" # /chat at ollama api
OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api
OLLAMA_EMBEDDING = "ollama.embeddings" # /embeddings at ollama api
OLLAMA_EMBEDDINGS = "ollama.embeddings" # /embeddings at ollama api
OLLAMA_EMBED = "ollama.embed" # /embed at ollama api
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"

View file

@ -20,6 +20,7 @@ class OllamaMessageAPI(Enum):
CHAT = auto()
GENERATE = auto()
EMBED = auto()
EMBEDDINGS = auto()
class OllamaMessageBase:
@ -141,24 +142,50 @@ class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
return to_choice_dict["response"]
class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta):
api_type = OllamaMessageAPI.EMBED
class OllamaMessageEmbeddings(OllamaMessageBase, metaclass=OllamaMessageMeta):
api_type = OllamaMessageAPI.EMBEDDINGS
@property
def api_suffix(self) -> str:
return "/embeddings"
def apply(self, messages: list[dict]) -> dict:
prompts = []
for msg in messages:
prompt, _ = self._parse_input_msg(msg)
if prompt:
prompts.append(prompt)
content = messages[0]["content"]
prompts = [] # NOTE: image-to-embedding is not supported
if isinstance(content, list):
for msg in content:
prompt, _ = self._parse_input_msg(msg)
if prompt:
prompts.append(prompt)
else:
prompts.append(content)
sends = {"model": self.model, "prompt": "\n".join(prompts)}
sends.update(self.additional_kwargs)
return sends
class OllamaMessageEmbed(OllamaMessageEmbeddings, metaclass=OllamaMessageMeta):
    """Request builder for the Ollama ``/embed`` API.

    The payload carries the collected texts as a list under the ``"input"``
    key, together with the model name and any additional keyword arguments.
    """

    api_type = OllamaMessageAPI.EMBED

    @property
    def api_suffix(self) -> str:
        """Return the API path suffix used for this message type."""
        return "/embed"

    def apply(self, messages: list[dict]) -> dict:
        """Build the request payload from the first entry of ``messages``.

        Only ``messages[0]["content"]`` is read. When that content is a list,
        each item is passed through ``_parse_input_msg`` and the non-empty
        first values it returns are kept; otherwise the content itself is
        used as the single input.

        Args:
            messages: Chat-style message dicts; the first one must have a
                ``"content"`` key.

        Returns:
            A dict with ``"model"`` and ``"input"`` keys, updated with
            ``self.additional_kwargs``.
        """
        body = messages[0]["content"]

        # NOTE: image-to-embedding is not supported, so the second value
        # returned by ``_parse_input_msg`` is ignored.
        if isinstance(body, list):
            texts = [text for text, _ in map(self._parse_input_msg, body) if text]
        else:
            texts = [body]

        payload = {"model": self.model, "input": texts}
        payload.update(self.additional_kwargs)
        return payload
@register_provider(LLMType.OLLAMA)
class OllamaLLM(BaseLLM):
"""
@ -263,11 +290,11 @@ class OllamaGenerate(OllamaLLM):
return {"options": {"temperature": 0.3}, "stream": self.config.stream}
@register_provider(LLMType.OLLAMA_EMBEDDING)
@register_provider(LLMType.OLLAMA_EMBEDDINGS)
class OllamaEmbeddings(OllamaLLM):
@property
def _llama_api_inuse(self) -> OllamaMessageAPI:
return OllamaMessageAPI.EMBED
return OllamaMessageAPI.EMBEDDINGS
@property
def _llama_api_kwargs(self) -> dict:
@ -277,7 +304,6 @@ class OllamaEmbeddings(OllamaLLM):
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.ollama_message.api_suffix,
headers=self.headers,
params=self.ollama_message.apply(messages=messages),
request_timeout=self.get_timeout(timeout),
)
@ -288,3 +314,10 @@ class OllamaEmbeddings(OllamaLLM):
def get_choice_text(self, rsp):
return rsp
@register_provider(LLMType.OLLAMA_EMBED)
class OllamaEmbed(OllamaLLM):
    """Provider registered under ``LLMType.OLLAMA_EMBED``."""

    @property
    def _llama_api_inuse(self) -> OllamaMessageAPI:
        """Select the ``EMBED`` message API for this provider."""
        api_in_use = OllamaMessageAPI.EMBED
        return api_in_use