From 29d36948bf31ce337ab3c3f119e268d170a8a314 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 30 Jan 2024 20:19:50 +0800 Subject: [PATCH] rag pipeline --- examples/example.json | 10 -- examples/example.xlsx | Bin 9092 -> 0 bytes examples/rag_pipeline.py | 102 ++++++++++++++++++ examples/rag_search.py | 19 ++++ metagpt/rag/__init__.py | 0 metagpt/rag/engines/__init__.py | 3 + metagpt/rag/engines/simple.py | 48 +++++++++ metagpt/rag/llm.py | 7 ++ metagpt/rag/rankers/__init__.py | 0 metagpt/rag/rankers/base.py | 20 ++++ metagpt/rag/retrievers/__init__.py | 4 + metagpt/rag/retrievers/base.py | 18 ++++ metagpt/rag/retrievers/hybrid.py | 36 +++++++ metagpt/roles/sales.py | 3 +- metagpt/tools/__init__.py | 5 + .../document_store/test_faiss_store.py | 6 +- tests/metagpt/rag/engine/test_simple.py | 67 ++++++++++++ .../rag/retrievers/test_hybrid_retriever.py | 39 +++++++ 18 files changed, 372 insertions(+), 15 deletions(-) delete mode 100644 examples/example.json delete mode 100644 examples/example.xlsx create mode 100644 examples/rag_pipeline.py create mode 100644 examples/rag_search.py create mode 100644 metagpt/rag/__init__.py create mode 100644 metagpt/rag/engines/__init__.py create mode 100644 metagpt/rag/engines/simple.py create mode 100644 metagpt/rag/llm.py create mode 100644 metagpt/rag/rankers/__init__.py create mode 100644 metagpt/rag/rankers/base.py create mode 100644 metagpt/rag/retrievers/__init__.py create mode 100644 metagpt/rag/retrievers/base.py create mode 100644 metagpt/rag/retrievers/hybrid.py create mode 100644 tests/metagpt/rag/engine/test_simple.py create mode 100644 tests/metagpt/rag/retrievers/test_hybrid_retriever.py diff --git a/examples/example.json b/examples/example.json deleted file mode 100644 index 996cbec3b..000000000 --- a/examples/example.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "source": "Which facial cleanser is good for oily skin?", - "output": "ABC cleanser is preferred by many with oily skin." - }, - { - "source": "Is L'Oreal good to use?", - "output": "L'Oreal is a popular brand with many positive reviews." - } -] \ No newline at end of file diff --git a/examples/example.xlsx b/examples/example.xlsx deleted file mode 100644 index 85fda644e2795a30709a406371627ffc2815548d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9092 zcmeHN1y@|zvTod^aSKk+hKAq}2yVfHTd)LocL?qlT!K3x(73xp@DPGS2%&-C4zDva zcg;*@-Y>X!_v*b@uXU;E_Yi!D%@JgU?S#+AB|y_MJk$$zPR zhRU-aG>FHdD$>=LG*E7;mt$%9kP~r(E1pMa#rFYk@{2du(YQseooz!<_&}=~W>nx% zmxduZA4mTnl%bQDG|@>{=kOf?`7=&ZBcqsCnSeCM#$GjU$#pULN-_&VB9Va2?9YRm zz}10f%yQjQfuW`DH6_gzApc0&4DRL>{;ZWHv$x^NCP)r3iq`uJk(*(W3XQfpvx!=> z({-vshci>tlew3s@GIcZ_;&2&iB;s*4n9gnZnE04$kTG&fp}}yP9J?AuTG}es8f~e z7G6RNBamLsvp|}l=}!O^>nI3O6_?$lZ=Wb(!Sm4{J=DuL4D0w5eXHbVEje-VW=bTq zwrp!fgE^AQ0(>k%ekDW0P5fQT`=%5=cEwXDTZmGZ4+Q7H-IkLsaopi98hYQb{k5T8 zA#do%ZM`=KM=87hT=&QTz{3LqK;>_+tk(e3ox*(WISh3eFjyKnncFyXu>YL@2gm8YhqAX}Ll<+)@wnm&?qag7G%x%Dck3fj7!M!Rp_1e)Dh2^nRinL?pdLPhj?riQnT}IxM$*nz( zrJ|`cPhohKPI~4@@&oQ9yA~N5ei20|(bM!m-2p|N6_d+qxH)n4{i@KKM!t97lO{9$ z7gCCMu|&fKos&qW_IG=`GAzN zhSX@e@h-V0Wd<4Eo%!oRxQ|A&J^T33dMk#03YpG96iH4!fmx&l1`;v= zDx8N6$L~0Cw|BBOvA4JW$zK1#3>=KYU|RmWN2QvgLLV5n9pgR>?4IdH2wZgGpoMDe z;h+vfn3o>YJn=hSC1+_d(4JS|K(G(>JRBc%yWqjzK)^riW-X1xL-fM8ITS*&A3qsK zKsDK~q70Hj!pDPt->cZiz|L^SZx%@$V#^`m|9Z*9KpIw1Oe=ZXJRQnyH34?zvU2|T zXawdnqmZqD>tJ2zEsP0J>q>4-V`Kn9KatSt>Pb?K1@l?nQR*QTp7`ktF~;**%$> zW!dH2*6yN;>}e*1;Dxlu8ubqyrg+nC>D9G%swSqny_-8(ClouUo5f}IGAv#LOUQCxcQc~GgeIuC0TwsShH{G`#fv+H zwzB>`LZ5?*w(@E4QZ`ngLSk}LTFE2DytbpqW1%>F*xzB$$RNY$a%uVWAC97_=6O z^dApd=g;tUlc*3DaF4()(i553)Cbz*TQjKT){2xltBv~ctzXfO7s-Ji+eujFe&eGXturop^X@9cE?Eu zZyBSDM&o%X>3s3U<_-!m<7r=%*P=7HNarqBtFN9bOttK$2YtQ9Oz}}%wF%FQLe_jk zNJH+C{@Po>2(9WJ1F2OrJdN^?wQrt*`z*@11NkX0Ad=H*jd(6iQG3cfd?tg0C?VDg z*Q;HKAp`KfVwK=TVS4vF!i2ZH;a=65Mqa2Z^gPSY{M9~Cjs?G8%8*D(e?KdSt+&p1 zD(6H9K~uF)x6SpR0N@g=7k0BUb9;y=v%kY$*aDWr*K@cjdShE&bUyXQKoOW4<7wlM z!d&V}0aBxoqYc2tKElyMyT(>U!U<`jC%SKg-)#3A!XpUc1H}pYe-e)2^1;1hJa&EbQy7?O%wfx*5Tl zZXs#A2GkJNHQ#47hGT5aSqE>U zUA9T~3CE^)5WA*7BDNGsLLl1o?RU$X0#}uv|vwtRH4I z=GJ;M?$r9BVkWZ~BQ;{==3xK9xh?dYwO66Dn{%P*?Uj?OXW^k#%e7nPqrJAK=*0e` z`>1WFqeGTE!AHjTTzOx$9e857co@^O?*0{_d+09Zx3GX-h5`U!{f4;Vt!!^aDJ@{LsBL|XX{8iv-AL-4qi3A;!zIS zt+4ksJ*V5`u9ZTW6o5uCvT=Uid6lm}1>I!gg(3s=g*=9daT#An84p8|V=CfB0d?AC z5I908xS8o5Q4;sgQOi37kLQEUIDg#@5eA*=n2~+qwAnb#>|=@~c%F4hnbIi~#;UVe zs&4eE=yE9O+I4xkw%B2)(8+5NW+mBxTmC}u`I~P<7t*45mQZW2@hl&H1r~-Xv5?Pl znH#G1LwF}>2Kd;=(w`<1#T|H;kd+1oDY8V!KmmmNYM zX>1~RseMoc$yy~kPL`z{ocM3vf5Vv(V9cuGX-$=+au6ax&2sd6Zm^|>63Q0H_o^&K zfK5fT^YcPHKic-kBfH6`4qgSrBGDV&NS&@}4Aw;oOr38mXgm&{>-1Qf)6`|yHrPSR zLl1RzFqF9!35JTZS!w zY7rBg5`yG$v#dtZFOt!<)s~MfS;Aj_dw(M00G+x;b%pXFKI ziAU4I@t+_j!N&7t9UpX|*3x6X_j5kkT+;eR@st6w!fBTdMHNJg;{(*$DCHxu56-(n zG$#s(KUNnKE;@RhOba8uO8g+X?(tbnAuqh`lL;XHa;qa$^W*1nu zVCZ!dhDeUyW2V=~@8-nv@(3k^w(VsvQ5bG#M0sz}{*B1;YaVb6VGtRo{&{ZyjmTUq z%x%p%ew#Uef(5Fh7e&rT;Ky+)ig|Ky$-6mK;GzDMIHmh4Z*qNSWf&n8uu;X6*OVqsX8+;W!Zoz?l!xiz&TbVTcG7~ zB0J4XCJmz}F~z7u<@jk<^rtlXnG}C#ekN`oW!dCtpgEe|6$aa_d&8>V9_{BuVC)B! zYoC{{oaxt$ZFDPN?^DNTRWj@bY1{Ii+3 z^rxG5=X8i&;ysn8P?FIOq|legNM^YNfM5z zaPxx!Z=pj%vGe!< zSzM;Nxs~Yy{FmIx0K*6cbH1LG^aW;l`iu=k1!dnsY)-ez17C5r&VvNt>mbF#3y~wi zaDD?WdK69CZE+#`&?Yv0WbsnaDs=f+FuFx^xmSi4uK_6^dowIe*RC!2D_%eQTL}X7 zne|p#X?D5`Vjr`OYjN?yk~FUOq66YhX4Q_`?}4YRUEqV7!zn2nx@P(?_UzVhPW(mF zuw~T#Cne+b5znEm8@Or&o(t(N-4|)?60=YG;26kzeWPBoewln6l=-Z>Sk5!3)c)%i z;HW~{;*FyiBWkO5@wjP}!BQ4~UBQ!F-Lm#5PN_kQ=nw8E24P5#cM+=7-a%T-D?v3-!tMF1D zu^g}jB2s9Somd~bKYRX!28h>vL=mc(83u<=SFsWycRaSS^I{SIUdTMIH9+7^&|uZ2 zhJ3bSa;N$#F5lZ-2$j?;1FFSdg=4Dwn<;8!dL5J+bCnU<6E>jAHUK4WChiR?xaQjYkO z4uPcmG8V{Hd9`*gm86QjSlUSOKXOj@wiHAv@6P*{c(x4=SeDf3YZ{EthQJT|v=LMf zPn(&ge^{x^D5YIFbtnI-AgV%8kxIOM2KWmvXa7Gr7}~?=cHNYbn~2rohLK z)j#$L*H8uH-krn3xPCQ;2%q6%DBuXgaO z=r^LQq+g`yX{U;bYex{=;lnM|*5+VEiH!(|dgJtcoHH=9hA?6K%KiJx5sQVk$Z3J( zSUSNd6}S8}vz#8Di0ZK^WmYN@<;9Hdyghc|v~q;=x5~?}spdwUkuATzF79XfKI@XS zyC<}TRdFmaTgy^pcc%n$sSiTlo#Hd^^Ha727_|PxD<8TlzJoU>Qt8?PuNf~ zm1lV`ZC$1LpB~jRiVkH@=(;{*+mp|}y#D&7ZSECHW1z5?gFk5f9JUxCirN~9%bT(QT`{2ry8{_8Ly%z`F-LwTw z6vd6_2u(fjqM!$caqpK`2S@cXFzASg(>roGta{Jxe495~I;qMtJy+?T3A$Uq6|4IU zokbtF@>W6h%9hSsta3$f*KJ%OKW#BC5NO)*mLa2nR4l)nwN-#9rdNx^;zj5&CSaQ+B!`O#GZzY=bV#Wd2M?Wg$zwow;0$y#AKPlDbKZ%c1E zjIEk?0aV-!_xFWzpaUh{1O#nH#iSVv0yX#;aFI=efbVrP8kc<5j^glJOX?DiwQtME zCnfz;pRz09gvG=+opildCrF+1k;j*ZARpJ{%`3lBiteX0&%<-vP*-VFgfh?Q^SV%v z?SnAYHPpJ^0KKtdsHlkuF48(H%=ad^L3Z7icJ4R4xYCW`OofT}MPX|B{kM)wB%>H1 zuDSV~#9L-Zkw;1sTJ(i>bcZrx1MU{#+KB@0Qz40C8}UDN~7wx$7zhy1@kW-iZ;Jc`!x>G!&z(-S7hv zT}~f7=68_uIrh-i_V^#zp*Sgmwd`3EVB(?N36<>%?8M#hcBp z(II$PVcFuf$9^Iqhua{m{9QBLfJhm|>6{rV=MdVaZ5^t@hs_|_fgGb#?)}3m=e#v1 z>uZ`WZ?eBPt@|6&9L0#!acjH;z6xi-~nmScM}>X5{%m!N%c&lvfnOG?|KD=4|&-fqa2(%{~QTn?@!dS&_2 z?3%>Uw{%p$bgbToRF&I&v?`_fZvOoG(J1!jhwBy9Ik8O^CrXyMSR}gmXXN5o%s|BB zCqpS7t{Wbm_l=3Pd?a-9HR*Zo!o9JO8=j$9EJ~yt_MDkKzqlDd!5Z2QI7v-_qzCLp zY>Yn>brC={iyNH$I5-@CUFp8d=Gx~97AO8`MEi-$r|K{Rn!`*CgcUT+>`hgi>>Zpr zOzoY_|7gYiuZ{`xuIR*eSSyMk#GQVH<~WC66e6yfC?6oGiWWhOn`yq<0Q_LO7z&B& z8tCgYtvTVQJ|g0mBUgJHICVkcFS0YTYX*`(M1wH1q~|fACZHs0w8f8NE{~PmGapA6 zCC=0)#^+&?a1Y1DsTHGEaO>~N2R0y)OePf4JU6#6uc9R(Z-5t5T%(wdh?2*!T2|%w zkPi)HhZ{8JAAo1uo>lt-@SVe-A{y3-dlA=S1D9pYBo+WcmJti~ZzqC){S7VN2C5Ud z73i1Pexb>%Q-UsAtDg?ERKF3lG;O*Cc4QhDHfr37^R}7<949~N;a1Vxz1_RoyxE_- z@$pnXdikZYrf^{5O9L^|kGk=k6BXC725}VRgJzY&iW+shq?4#Cf^A0`Pb62Wf)O$b}eB9|mrKdYes^u+y$!t*?sCz8R~4)?E~ zXzbwdKSzX_?2nO^*lh>n^H3O{V}~BZWZ5xAKGWnZB=^$(0Fd+EF%u+utPt8qHk@Og z`Uv?H%kA@0Wf!F>B&K-c^Nbz6|HckuW|k#nB)i7D>ff_QL1I1IW|S=9BV-CJ23&Uq)0k=GV-@ z?pR)lC7ynd9c8{Luql6jlSlH1_|e^aOVa0)sdjufDS`D~3k?%w)P6-B4Ua@J-xsfxr$h#M`Y-n z3g@gGKVF#S(ir1thGXX=DPw+Wj2Xslmj{K(r8CX8^#~=Sk#{8?6UlWZ#waAkF{sh0K9lvV$JInql4*j%_7fBhdNYq2^2 diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py new file mode 100644 index 000000000..5b47cec62 --- /dev/null +++ b/examples/rag_pipeline.py @@ -0,0 +1,102 @@ +"""RAG pipeline""" +import asyncio + +import faiss +from llama_index import ( + ServiceContext, + SimpleDirectoryReader, + StorageContext, + VectorStoreIndex, +) +from llama_index.postprocessor import LLMRerank +from llama_index.retrievers import BM25Retriever, VectorIndexRetriever +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.const import EXAMPLE_PATH +from metagpt.rag.llm import get_default_llm +from metagpt.rag.retrievers import SimpleHybridRetriever +from metagpt.utils.embedding import get_embedding + +DOC_PATH = EXAMPLE_PATH / "data/rag.txt" +QUESTION = "What are key qualities to be a good writer?" +TOPK = 5 + + +def print_result(nodes, extra="retrieve"): + """print retrieve/rerank result""" + print("-" * 50) + print(f"{extra} result") + for i, node in enumerate(nodes): + print(f"{i}. {node.text[:10]}..., {node.score}") + + +async def rag_pipeline(): + """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: + + -------------------------------------------------- + faiss retrieve result + 0. I highly r..., 0.3958844542503357 + 1. I wrote cu..., 0.41629382967948914 + 2. Productivi..., 0.4318419098854065 + 3. Some sort ..., 0.45991092920303345 + -------------------------------------------------- + bm25 retrieve result + 0. I highly r..., 0.19445682103516615 + 1. Some sort ..., 0.18688966233196197 + 2. Productivi..., 0.17071309618829872 + 3. I wrote cu..., 0.15878996566615383 + -------------------------------------------------- + hybrid retrieve result + 0. I highly r..., 0.3958844542503357 + 1. I wrote cu..., 0.41629382967948914 + 2. Productivi..., 0.4318419098854065 + 3. Some sort ..., 0.45991092920303345 + -------------------------------------------------- + llm ranker result + 0. Productivi..., 10.0 + 1. I wrote cu..., 7.0 + 2. I highly r..., 5.0 + """ + # Documents, there are many readers can load documents. + documents = SimpleDirectoryReader(input_files=[DOC_PATH]).load_data() + + # Service Conext, a bundle of resources for llm/embedding/node_parse. + service_context = ServiceContext.from_defaults(llm=get_default_llm(), embed_model=get_embedding()) + + # Nodes, chunks of documents. + node_parser = service_context.node_parser + nodes = node_parser.get_nodes_from_documents(documents) + + # Index-FAISS + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536)) # dimensions of text-ada-embedding-002 + storage_context = StorageContext.from_defaults(vector_store=vector_store) + vector_index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context) + + # Retriever-FAISS + faiss_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=TOPK) + faiss_retrieve_nodes = await faiss_retriever.aretrieve(QUESTION) + print_result(faiss_retrieve_nodes, extra="faiss retrieve") + + # Retriever-BM25 + bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=TOPK) + bm25_retrieve_nodes = await bm25_retriever.aretrieve(QUESTION) + print_result(bm25_retrieve_nodes, extra="bm25 retrieve") + + # Retriever-Hybrid + hybrid_retriever = SimpleHybridRetriever(faiss_retriever, bm25_retriever) + hybrid_retrieve_nodes = await hybrid_retriever.aretrieve(QUESTION) + print_result(hybrid_retrieve_nodes, extra="hybrid retrieve") + + # Ranker-LLM + llm_ranker = LLMRerank(top_n=TOPK, service_context=service_context) + llm_rank_nodes = llm_ranker.postprocess_nodes(faiss_retrieve_nodes, query_str=QUESTION) + print_result(llm_rank_nodes, extra="llm ranker") + + +async def main(): + """RAG pipeline""" + await rag_pipeline() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/rag_search.py b/examples/rag_search.py new file mode 100644 index 000000000..222573476 --- /dev/null +++ b/examples/rag_search.py @@ -0,0 +1,19 @@ +"""Agent with RAG search""" +import asyncio + +from examples.rag_pipeline import DOC_PATH, QUESTION, TOPK +from metagpt.logs import logger +from metagpt.rag.engines import SimpleEngine +from metagpt.roles import Sales + + +async def search(): + """Agent with RAG search""" + store = SimpleEngine.from_docs(input_files=[DOC_PATH], similarity_top_k=TOPK) + role = Sales(profile="Sales", store=store) + result = await role.run(QUESTION) + logger.info(result) + + +if __name__ == "__main__": + asyncio.run(search()) diff --git a/metagpt/rag/__init__.py b/metagpt/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py new file mode 100644 index 000000000..7b4e37e88 --- /dev/null +++ b/metagpt/rag/engines/__init__.py @@ -0,0 +1,3 @@ +from metagpt.rag.engines.simple import SimpleEngine + +__all__ = ["SimpleEngine"] diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py new file mode 100644 index 000000000..7532f6620 --- /dev/null +++ b/metagpt/rag/engines/simple.py @@ -0,0 +1,48 @@ +"""Simple Engine.""" +from typing import Optional + +from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex +from llama_index.constants import DEFAULT_SIMILARITY_TOP_K +from llama_index.embeddings.base import BaseEmbedding +from llama_index.llms.llm import LLM +from llama_index.query_engine import RetrieverQueryEngine +from llama_index.retrievers import VectorIndexRetriever + +from metagpt.rag.llm import get_default_llm +from metagpt.utils.embedding import get_embedding + + +class SimpleEngine(RetrieverQueryEngine): + """ + SimpleEngine is a search engine that uses a vector index for retrieving documents. + """ + + @classmethod + def from_docs( + cls, + input_dir: str = None, + input_files: list = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + # node parser kwargs + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + # retrieve kwargs + similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + ) -> "SimpleEngine": + """This engine is designed to be simple and straightforward""" + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() + service_context = ServiceContext.from_defaults( + embed_model=embed_model or get_embedding(), + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + llm=llm or get_default_llm(), + ) + index = VectorStoreIndex.from_documents(documents, service_context=service_context) + retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k) + + return SimpleEngine(retriever=retriever) + + async def asearch(self, content: str, **kwargs) -> str: + """Inplement tools.SearchInterface""" + return await self.aquery(content) diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py new file mode 100644 index 000000000..e67be1416 --- /dev/null +++ b/metagpt/rag/llm.py @@ -0,0 +1,7 @@ +from llama_index.llms import OpenAI + +from metagpt.config2 import config + + +def get_default_llm() -> OpenAI: + return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key) diff --git a/metagpt/rag/rankers/__init__.py b/metagpt/rag/rankers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/rankers/base.py b/metagpt/rag/rankers/base.py new file mode 100644 index 000000000..482fc4aef --- /dev/null +++ b/metagpt/rag/rankers/base.py @@ -0,0 +1,20 @@ +"""Base Ranker.""" + +from abc import abstractmethod +from typing import Optional + +from llama_index import QueryBundle +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import NodeWithScore + + +class RAGRanker(BaseNodePostprocessor): + """inherit from llama_index""" + + @abstractmethod + def _postprocess_nodes( + self, + nodes: list[NodeWithScore], + query_bundle: Optional[QueryBundle] = None, + ) -> list[NodeWithScore]: + """postprocess nodes.""" diff --git a/metagpt/rag/retrievers/__init__.py b/metagpt/rag/retrievers/__init__.py new file mode 100644 index 000000000..799766870 --- /dev/null +++ b/metagpt/rag/retrievers/__init__.py @@ -0,0 +1,4 @@ +"""init""" +from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever + +__all__ = ["SimpleHybridRetriever"] diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py new file mode 100644 index 000000000..c0291f217 --- /dev/null +++ b/metagpt/rag/retrievers/base.py @@ -0,0 +1,18 @@ +"""Base retriever.""" + + +from abc import abstractmethod + +from llama_index.retrievers import BaseRetriever +from llama_index.schema import NodeWithScore, QueryType + + +class RAGRetriever(BaseRetriever): + """inherit from llama_index""" + + @abstractmethod + async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]: + """retrieve nodes""" + + def _retrieve(self, query: QueryType) -> list[NodeWithScore]: + """retrieve nodes""" diff --git a/metagpt/rag/retrievers/hybrid.py b/metagpt/rag/retrievers/hybrid.py new file mode 100644 index 000000000..e6b526b38 --- /dev/null +++ b/metagpt/rag/retrievers/hybrid.py @@ -0,0 +1,36 @@ +"""Hybrid retriever.""" +from llama_index.schema import QueryType + +from metagpt.rag.retrievers.base import RAGRetriever + + +class SimpleHybridRetriever(RAGRetriever): + """ + SimpleHybridRetriever is a composite retriever that aggregates search results from multiple retrievers. + """ + + def __init__(self, *retrievers): + self.retrievers: list[RAGRetriever] = retrievers + super().__init__() + + async def _aretrieve(self, query: QueryType, **kwargs): + """ + Asynchronously retrieves and aggregates search results from all configured retrievers. + + This method queries each retriever in the `retrievers` list with the given query and + additional keyword arguments. It then combines the results, ensuring that each node is + unique, based on the node's ID. + """ + all_nodes = [] + for retriever in self.retrievers: + nodes = await retriever.aretrieve(query, **kwargs) + all_nodes.extend(nodes) + + # combine all nodes + result = [] + node_ids = set() + for n in all_nodes: + if n.node.node_id not in node_ids: + result.append(n) + node_ids.add(n.node.node_id) + return result diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index bc449b5cd..e5cb12778 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -11,7 +11,6 @@ from typing import Optional from pydantic import Field, model_validator from metagpt.actions import SearchAndSummarize, UserRequirement -from metagpt.document_store.base_store import BaseStore from metagpt.roles import Role from metagpt.tools.search_engine import SearchEngine @@ -27,7 +26,7 @@ class Sales(Role): "delivered with the professionalism and courtesy expected of a seasoned sales guide." ) - store: Optional[BaseStore] = Field(default=None, exclude=True) + store: Optional[object] = Field(default=None, exclude=True) # must inplement tools.SearchInterface @model_validator(mode="after") def validate_stroe(self): diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index c1f604df9..8d265e9f3 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum): def __missing__(cls, key): """Default type conversion""" return cls.CUSTOM + + +class SearchInterface: + async def asearch(self, *args, **kwargs): + ... diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py index 1a159b413..0c5a55e0f 100644 --- a/tests/metagpt/document_store/test_faiss_store.py +++ b/tests/metagpt/document_store/test_faiss_store.py @@ -28,7 +28,7 @@ def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int async def test_search_json(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) - store = FaissStore(EXAMPLE_PATH / "example.json") + store = FaissStore(EXAMPLE_PATH / "data/example.json") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" result = await role.run(query) @@ -39,7 +39,7 @@ async def test_search_json(mocker): async def test_search_xlsx(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) - store = FaissStore(EXAMPLE_PATH / "example.xlsx") + store = FaissStore(EXAMPLE_PATH / "data/example.xlsx") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" result = await role.run(query) @@ -50,7 +50,7 @@ async def test_search_xlsx(mocker): async def test_write(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) - store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question") + store = FaissStore(EXAMPLE_PATH / "data/example.xlsx", meta_col="Answer", content_col="Question") _faiss_store = store.write() assert _faiss_store.storage_context.docstore assert _faiss_store.storage_context.vector_store.client diff --git a/tests/metagpt/rag/engine/test_simple.py b/tests/metagpt/rag/engine/test_simple.py new file mode 100644 index 000000000..4eb1d0b6d --- /dev/null +++ b/tests/metagpt/rag/engine/test_simple.py @@ -0,0 +1,67 @@ +from unittest.mock import AsyncMock + +import pytest + +from metagpt.rag import SimpleEngine + + +class TestSimpleEngineFromDocs: + def test_from_docs(self, mocker): + # Mock dependencies + mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") + mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] + + mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults") + mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") + mock_vector_index_retriever = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever") + + # Setup + input_dir = "test_dir" + input_files = ["test_file1", "test_file2"] + embed_model = mocker.MagicMock() + llm = mocker.MagicMock() + chunk_size = 100 + chunk_overlap = 10 + similarity_top_k = 5 + + # Execute + engine = SimpleEngine.from_docs( + input_dir=input_dir, + input_files=input_files, + embed_model=embed_model, + llm=llm, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + similarity_top_k=similarity_top_k, + ) + + # Assertions + mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) + mock_service_context.assert_called_once_with( + embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm + ) + mock_vector_store_index.assert_called_once_with( + ["document1", "document2"], service_context=mock_service_context.return_value + ) + mock_vector_index_retriever.assert_called_once_with( + index=mock_vector_store_index.return_value, similarity_top_k=similarity_top_k + ) + assert isinstance(engine, SimpleEngine) + + @pytest.mark.asyncio + async def test_asearch_calls_aquery(self, mocker): + # Mock + test_query = "test query" + expected_result = "expected result" + mock_aquery = AsyncMock(return_value=expected_result) + + # Setup + engine = SimpleEngine(retriever=mocker.MagicMock()) + engine.aquery = mock_aquery + + # Execute + result = await engine.asearch(test_query) + + # Assertions + mock_aquery.assert_called_once_with(test_query) + assert result == expected_result diff --git a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py new file mode 100644 index 000000000..62d976ba2 --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py @@ -0,0 +1,39 @@ +from unittest.mock import AsyncMock + +import pytest +from llama_index.schema import NodeWithScore, TextNode + +from metagpt.rag.retrievers import SimpleHybridRetriever + + +class TestSimpleHybridRetriever: + @pytest.mark.asyncio + async def test_aretrieve(self): + question = "test query" + + # Create mock retrievers + mock_retriever1 = AsyncMock() + mock_retriever1.aretrieve.return_value = [ + NodeWithScore(node=TextNode(id_="1"), score=1.0), + NodeWithScore(node=TextNode(id_="2"), score=0.95), + ] + + mock_retriever2 = AsyncMock() + mock_retriever2.aretrieve.return_value = [ + NodeWithScore(node=TextNode(id_="2"), score=0.95), + NodeWithScore(node=TextNode(id_="3"), score=0.8), + ] + + # Instantiate the SimpleHybridRetriever with the mock retrievers + hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2) + + # Call the _aretrieve method + results = await hybrid_retriever._aretrieve(question) + + # Check if the results are as expected + assert len(results) == 3 # Should be 3 unique nodes + assert set(node.node.node_id for node in results) == {"1", "2", "3"} + + # Check if the scores are correct (assuming you want the highest score) + node_scores = {node.node.node_id: node.score for node in results} + assert node_scores["2"] == 0.95