From 4fcf7247970614e37eaa33801db5943a287995b4 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 19 Jan 2024 17:37:12 +0800 Subject: [PATCH] replace langchain with llama-index --- .gitignore | 6 ++ examples/example.faiss | Bin 12333 -> 0 bytes examples/example.pkl | Bin 624 -> 0 bytes examples/search_kb.py | 4 +- metagpt/document.py | 27 ++++---- metagpt/document_store/base_store.py | 4 +- metagpt/document_store/faiss_store.py | 58 ++++++++++++----- metagpt/memory/memory2.py | 22 +++++++ metagpt/memory/memory_network.py | 18 ++++++ metagpt/memory/memory_storage.py | 20 +----- metagpt/memory/schema.py | 61 ++++++++++++++++++ metagpt/roles/role.py | 8 --- metagpt/utils/embedding.py | 6 +- requirements.txt | 6 +- .../document_store/test_faiss_store.py | 6 +- 15 files changed, 175 insertions(+), 71 deletions(-) delete mode 100644 examples/example.faiss delete mode 100644 examples/example.pkl create mode 100644 metagpt/memory/memory2.py create mode 100644 metagpt/memory/memory_network.py create mode 100644 metagpt/memory/schema.py diff --git a/.gitignore b/.gitignore index 4b522674e..51baa132e 100644 --- a/.gitignore +++ b/.gitignore @@ -154,6 +154,11 @@ key.yaml data data.ms examples/nb/ +examples/default__vector_store.json +examples/docstore.json +examples/graph_store.json +examples/image__vector_store.json +examples/index_store.json .chroma *~$* workspace/* @@ -168,6 +173,7 @@ output tmp.png .dependencies.json tests/metagpt/utils/file_repo_git +tests/data/rsp_cache.json *.tmp *.png htmlcov diff --git a/examples/example.faiss b/examples/example.faiss deleted file mode 100644 index 58094619004ac7b01cf52e596e1bb4bf254f4322..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12333 zcmXw<2V9SB)W=IhL&GS9h*HUjP~F$LQ<0U7lD$1t+{@(spc!C}J<_GDXs zXqd4^N-T=Sz4Nja-4QR~S$bn8tvdu!R_j=Z{&NgJvjQ*w>%!0U(#K}?bx>oG13wTp z5Prt3Ml^rMJmq!hpdSowPmpDzQMkL^66jhS0-qdH zAhLQA!l%AamfaqYy6@p*d_$o1?@i1gZ8tibHc%r*y_f1ohJzjNDzk%^m7khE7}47d z2lcbU)eb+|f}7E7asO3(%G)sTx>W);1LL6a8MEpihs5$Z zJm#ViCT#i)AD5Ryv91M{Osi&nA_s_jL7(&IF#PBq40Y4Nj~?yO)1sps(R(|dUs;Cd ze(mCVIT>KkaFuvg=ya+W+JA`V3pY>4dBtaW-{2yE^*>;H^Vy&Nhlw(HX4jRykgc5ytw5E zV|jUOKD_sJRF9Snz*bi8dG5^m^1m7_mZ28lqAj+tY0nI{<((Da!|w8wu}?6@=LFU1 zn0QU~(fp2+8orVq|M8b*eVHm(Ixd2OkOfrd{V0C_a?E7dV7wUDEdIn!x%Fr3hHR~? znS8Yd#xCE_ioRT_tL5VUf7rkk8c=vS13s3I#JuiyxbZ9k{T{D4atRJsjfeG7m!bMe z448U%Qs~*`steAl$$1?--={h5xYkPbyD|!#*S;$FJ7@zv-%OakzlHq5Ib77QAWjJd zQN!3z$Dro44;nRcg9X_fX4{9soS_H7>djhgyZ1IYeA$g=5dk>3XH&WN_uJq)W<53z zOJS!HFY+npj`EZrAz06S5c^Ya3(|YVYf$Fj4(cuY#pfnj6=ikMQwJRHti~x?us`r0 zP`%*!fv0R;M}u6iGk9un2exsuv6}Px9c$EdFwMh1?7H(Xh(2gKMGwb& zmGPJNL-6a}8BEVH6x|*2_#&&(P<0_6RwtCf^7gOkJ}xQcpL@WOnoaH-!KsHB^*{Xg z!LOiM{q>4}hgSTreK+|?=aD!w;|-(dfuI?oLJx3p#co{sua88tE}d!Fon~e~Y__=u zTMv9^qxCjQKGm)GaQmqPY;Dq)ue{VAL_JSUZdK5| zdNIa_TLbkvr1kJn7u+Bwd+70GP8+JBRm#k_9eN*G`#MVq@Ys^ltSk{)4E~wnq zh5xkP!&jT{=2R=Z=CTXU&0Yy7HO?c=84GRUh(hN)eaB$Wp7X(EzoSB$f*qHf6WRgs z?!ijssU3`uGs4a32XUJgvX(Ovczodv7`S3BZewrw<2^Uf@NFVqs)&SB*MDN)A5GDH z>VL2{qY-ls*viv2r@)9s@r=%gzuA!p)ZaMNW;6Z?oXsm!OTaYr9t^n`i}`o=GI}Q1 zwyp_Mjo6Uwui3#Xz47m$`fAj+@ofjK+lEK9uDA_$2g&@v!m+R^%0Mr z(2Hg-Q0NysH@jAGIiUwM0Ez+37>*4IrhoZM! zJ3RG=_nRJK%R66~jcpO5??Qh$&9wY}@p@d)doD}x^};5Z?r?8foX{9{{XsI%fBy$2 zug-&H$F3+9jD&70D$riX0rxfT#5UY6WjpWILgO#XTtDARK%v9;g6D#$1LLS1mIXs*(tmJxk*1umI0-(FJjH3oxOe^^1w*?ssw<=U=K)tV!1bS|awORXhisq4 z3w1_g-_$1TMU$>F`Gc48UjO`Ej1J ze48}mauMJD#1<};y7T(GC&B^zA?n7u9Ce1oqQ=mxc>w-;_?-o<94F3|g+0^c<>rQ} zn3;F(&-sJfUHMyxhx%rtZ?corRGVbcz=pcz`2y{QG}u9}30hh^ch@+Lg5 z>@IWDj*)04@ny|DB+uX}M&s!GvKjR*+}KeAsjY4K#I2`TR@PC>es0OU{E}fs$YCUH zh71;iJ=0~J@TND-W?Rzum)Ns323}6cz)-&lIKe1U^73l|kNXW#S7?NTh2e0VyYv!1 z$!&!hL*Gh)CA#poxd$#>wwV_{>m%P?Vu|}s+H>+I-r0b$xQQCvVS}&SV`U=mka2*m zoLE7<7zg#Hrb#>RuEdL5VqC=M(p<47Zh?x>N9zY?72z`THI-MJw*fFq!Ig)uqmJcCu$psFcpt>B zp9&44&4KPiB38k>OLl^+G(+Jf=Xa$*)Za|raQl8F&xar5uJJ?<#tZhHg@jE(T+G(T z6nmcWcnho>VuGZ%OnB;?8&^qZi;?G>s*J56?N8%gGZ*owo|Smv?>tWZf~13>zttc0 zYtup7x2a0HiK7ZluykVxD3Xes(-X(anl9_+gu%R=;SJGe;7049^6XEVaw*k z3O#2TkB!k>HIz4}+X?-YY4$kjGxIq7fKzX(#5Yo<;}_=sc>@YwF>!tBS`)bo`@J+0 z{?85^*u_Z$;m9|;y3h6gW2@4M#vn}*WoVn`o zr4252vaVxN@*tkkOlAF}HbF^*hhot1k3zf$UuPN0{z=Q>o@WtrT)GkF6?;M3_N!rE z&w7~idM^6RJIns!GfsM5@ZozFQthPHyEhO|Y9odoW2AFHe~&@~2751n)~mYVvH^uW z_(8T*lyQsOP3nx%H4CsLGadzl&}-7`Da%p$P$=rhMK4nQWb#Z`G2au*=3(>fp|XB= zUs!o*KQHr72j~708daLGhISQbHdKWfADhTKN8aM3@krNT-$$*(XMfGrQH5KP`jYtA z04}f6qkdbAL+iQV(lT@Pq48t*b>5PH&gl&nH}@f(0rA=$F&FZjuh#saeVOv$r!L(~ z5^u0?J`!&br~e|oUCY{dg|P#;OfWdq{9XY*+mfVnzgi28RVrJwfRMxpST*>HU@r2a z0vS>F9NzxCVX>Vubs*Y4PB(!5IF%K?$o05 zf5Pdy>b#Z%@J2`it~c1PsFTJqIuEwc|0EYY={`8pHLF7e7Mx3h$U7UTZttaMjTc~b zp$d~~jTwEON?K0ynS%#wnyIJ1ny572xXU;dJGStu^GNb^POPY$zM+6%58@Mk;>{7J zd%8PY`}7pjT!DYm&pNIrR%YaPDD*kO_pEDChCdSj;rmS!0L`Z&`4tjlC}LjbcJjjn ziyd%2_p9(e{MPOXdBrqh;6~tjzawu}brk1MFA&-n4TRk~fHs z$H*)VUqe{HrT8lG@1^O_e%7 z9<^+mv_ct+$!HGrre8XNy*8^f5 z;WhZSd9iY3Zx-Irc8CAshJvn#kxH5azr6dxHM0xga^VTHjWc4z&VmaVabW>vBffB% z5u^Kq8@6WfqRBCE>efbNUq*bTlAdwOZ$MlrSWuaTJ@EN~DUA3Tv<@fXhwmNv!xf#G zc9#-((Dej+;jAkYLrUb;oO~Svo{i>h-(N*Gp&W?~6v|)lXw+zIV&9yyekP{)>8g~` zkahu_v=!_g>f+M9#e8yFr0bijW#n!CLN9rM+j#mu6~*j|dlCKnPS2UI)iH-8%jr-v zb08k;(oA?IQWn69j{9&x-^Yqy&@IY*8Zh)gceIyB6Mt*q^w7SLIVF#epP9m`w;6e% zLK?&J$6mq{MqyIz$+)`f^wHlV^b13-_vB}iq8QB?w%Aq516yqpx}ypmCI02k0VdQN za|LrL+TUY&af{C`VX1eB{q5NEn2mMUDqpw-AN5`x%pWBdR{(#E~&TdJW_UFzeAK zNKA2~x!r-Z(+B4Ys+{HjWOtY#bhkP+)y$%_} z{WCB4l9(mXH)pP*v12XnAHSa%&Z>@CQ-0Kgi%wfP&74G<0h-A};E9P9*DSP#Z^jYm z+eAy`4UrEhC)Kig&(FEgJRr#C;t^2{y3dHDqi?6-j8wt(nW=sL><&GjdUp#haSj<%2Q1-Kkus0 zCH{PVcg(t~Ej$(fbG^(qo&`?*i=@xu9AwH=z$OgGJ4GXCUJTX!(slS(6Ur|JsL;8S zDK`G{LufZv{O2b$6DWtewyrpVqn|rLL>qTDxUna=O*Dh{+Oz4pRg72)m%mp|W(*Br z$&SYacgs$DY9VFjMAqfyG$wS4@)d~Qr?U|CDcCr>06Z3V0n#;0ZGC{X+5$lP8^Jyx zGP-tychuXnSlrcdOzgNs_UwB59+KY(KSs(;EbLSk3riizYqu9c+MgI8u8`k+YDu#h z39mbN17$w!nAl3}4Y=5|h)hg#D!tRo$2M(`BKdQH=t;3Nt9TI%w9mnq1qMv-O$ApP zHBqTP+-c8B@TpD(kqs}m)lvDl%}Q2%CspgPMZsu2eX-NSSL_y}`hj50!^1jC#O>&7 zf8V)uV?FijnJq|KPg$-n(@8Pqw2NWG)hnFXN~LUxL)MzA-hU48S_xp@&mg2d6x-GJ zBJ!kLVIKTIjvvzrDNUy0T znz@2u;A`JJu?yoHXS*VCcO8q+E>p3$(~!v@xKwb5M_FmoE^{8x?n9G5Em42u0W@saUncHjrG_0SPkq5V7Gt2F=Qg(G!X-xY zjHK)6H2VstnU?z;*;(+Vb&OQ3n^Zt~i8Qz`<*ThAGCkE)uqHnKb4%=<$ma$!k$30} z`RT7)k!Fo%EFV*h?lX~zyS#kP#C(Wc@Hu!p44WSg!`JVY8U|hmsu||5wo-!{Xdu-W zg)XM%RPZhD(%2yDL1LaI@=S$xs+4`Epr@=!d3dKnysc0MS1G$7c?SNrt6_RM9k}o$ zp~b|3s%5|ov6sVl4pEd#G=cU33e6LUormeiLy%z6RHck5_J-U%DG#qJrL<>O+3w?s zIJW!{P=4c7H|X|izu3>+17F@TA8yqhDUMbh$SGt=x3N@u>-9~gz?1RkNLSpt1wz# z#SCUoMyCmh@OQyLT<@R@1LQz3y1E^26$OEF`5HVpXC3eH)&Y7{+=l;#edbR6=3-mz zeQ@{0Hn`9Xa8iN;YkBh~e5DLW?^EeI{AF!6^wJDQyWJ;XVoC@uPd$QhPAy>bC=FO( z=?}kNZeiB%en7q9RC7&BZ0!=kONSHN$ zJ};VZ6w-eT#$+o)xzW3ZaOsc@gf|(3F24@ZHH+YA&2t>@d^2Fz2L9WRZ@ytLy!VMhmz}pkc3@H2J*x(;S7JqQbWXyU-<_4yymw-7z_277-^3%0=Mz%uHTi(`-7& zTfiC0w+bzV!L{*enxELC6hR<5Njg>Tw@nclHP^nYjx?-M-a;)P-G(|52meqcMHXL}aS;DAZ=8dZM+flUi?i@zwJvtH zui;{@JWp;WUAoVyuA(3C>i5Guw51;H4C|{0=V*dX8%Mkq>jv*q61d&wN%*CoEqw5w z#0wi-qq)4IH1AuC5C4_a&26{y6>MYA!$|jvZwD5@imWW?7GQ{7L@&!Md(W7%Yh`xX+=0qHbm+?=Jf>E2bMK0pnRLYVe50;0ltu$ehSIn9;OL>3!p!%TGpA;gU7JvyRZ!h#e2Nw%ajXqxm%iNoK|oG~jehh6LGBJ@-iGemXePj_{IM7+(!z2;6! z0MZry$>k92n7tH5{|j$OF_=iR`xEC)&VzOO$5H#rHP@;mcKm|Q3^w%C9x)s6@Si81 zTu{V!R^C9%tF_{JSnZ-TJd|k&|N5B8;@OR&x+|o2oUS2xE)V5L^kQVqtUoYqX)0eg zsEM38Y_+Q&{kz}cNwuz|y*$ujH8dNT0QQE9!R>Afta5CEUpDD+(qH0~t$fpGZ{`-S z(C072nP+XtLo;#g{wgu&ApF~VPkTnQE?$<&KVRq{<{zeoc%#JuOezx^>&*$C2wRO#c* zFYINn7{M0O+Y>*?$AZvj=hc%mchY}0ZtA0(l))26;-W&}x$TW{=(+@?*@m|pd(-Q8 zcqh$RbY13#Ru6-*^B<0S^G@>aV?3G5wAbi)(n$TAvIeOK<-}n-;6Ix}KK$YUG`{Et zl@I3A>|KM6i}NwMVSi?kuv?r1PXF4KxWf!QOqa1mYjQaGA;>zzfH;KB+Z+Jz`aZ>d z=R9#<+m2As=^;cnJO>>X*z-l1FVXqsaI9IeiBV6eRDY?*+@`SlZzPNNB$ljALgE?f z(Mo09kKT-!1ozdy%!5iAlP9iY{*yM*SvA8nwT0jSF_Re4t^(=W>XI$lVD5YxsP{1} z^ds}io=ERYMw`azs5KNh`8FS`M#GWt_83$WE0{?oJwftlyfr+R3155^orNQlWxV9D z4Zi5LB5q7UtDLj^)6(Nu@ZJxJ0fD>*JB%L=z0)R(`@wbkPU`b_IqdhdeDjuT4?cfnSXZ12x+&LY`tkxCm#daQP#e5sp zL*eyM+Lw5*A&x63pngtfVg~0p1lC+Ksz2N;&nZl>ey%yPkVVa3wt5mEpQ< zg0*TD#92UjI?Gh_(Z;tEQXL`ukOfrDw-o#cmu4rkMVT|OINMlFj80<#+ndY6pKSL= zA+a_TXxj6k)pxMyaT1QRUe49aUJ&P$%vQ_s=+xpQc3N`)r@O=gu{O6^{nXXx!Cm|? zrvap~Y}X%6lZjK}x$y1UZ&5gCq8I!;7{*_l=&HX>-f^4q1a?l}4HvEHCO;b1mi9^} zVB0qB{a-)gx}czzRZ} z2yVd^o^O8O9k@NLc!2y?U zRq`i3-nhlAH4q=E#3)d+!T^43KMEUrb^;xvu~<8K5Ozp;h;)`rusLbBL>dd5>s{bq zy89x%N9vvSgcB37mi7kh(yK1euz5e6|L6jmj5~;8c1v9^An6obdUcc!{@59-EwxD( zJ^;-F+cNx+i^wN=$J>C&OT>lbWg$Qrhk7#{X7!fj(klj(wSFS`BX;a`F5fKJ+Vw`a z$uP$3H!eKBoprgc3sg&?bDVmMJmEVVJa(Q$KFqwrI#AZI#0$+3C#KYsV^4%5ai-e2 zuq!yfESJb@_+?wAj%hTrRw2z6E1uSq}*1(b17XnXqL^{h?zW1RF_>7>~hC|8R6WBbR4 zBKasga3%oB>k5d`_{jymuy<}A`!FYpYw;CWQZWaoKOIMNs14-toHR(8?D-XFwndh} zxh8kuV*F4b&32`k=Zj>R>HR@>QpJbiXG?=Zexa5)f*2`y$9tHTTAvGs|kBZMci zjcL^w6soUo%JsmA<;D=)?}tS5#Hb#EnG})zQj_}&FHk6Rp~37Yynf);IwqhU0#MFU zv*-U(==_oHNwB*_S&DM-Y7`t6J=9l<91y}M7$y>LC*iIgHB4wpm+RpmSYKz*b4HmL zN%z3}%2@nk8Nf^)Hs*7DUHGr|iQN8CFF4UQ9~(9=!^UPG6*@P3^W!cO)3F|NcPPa9 zir|?}nfEC>7mHcN*qviAdwww8&nRYVa!I0R<)brL9lPnxbAjYtyEvV3osSb=6A!o` z%>|0Sr)-bni)Kyo%6{ZZ{i;&V;y*M#Frll$6Pr2B;hS=cfifbR?9vyz4?iUk#!kEj<+pc% zG)Q08F}lY@)?eDd3;!I?WKSP<7dplIM65%~*%Gl61X}!Me#>Xq?H2l;&gGN?h5z%w z_x3&oSb3ptuOzSv(WU|Vm_v^{p*7cD* z4d|JKZga{t%p_tw&>jbeM-{P8N)tIi<1SB4K8G7(=3$5920Z=86?hTQj|rVxy|N9> z`xF%QX`JE6U+h>Z_9Lv^W1>99?i%ewUUI4%8z6rIagW4aLhBV`LAdfe49dREL?5SU z;Du!p^&X?^v&t_mRNb9tfaaYyeyt$s9iu+RnflJzaEO@l}OJIv`Z16M;VKD zujmc#w$|*}2n$(wwBS}>yh!so7y5ota!sFNS08}4Od*HOo_YjSYX zna8YcxT!kyR9jWly+hYsKszDW+<6T<|3FWrY+tuK%gXD<+e~_)kVaz7ive)0=SgrJ z(n9^77*6$jAT*S3*moL5kFMq&*y@$Lfieb=SBMO$2xg@o7CRDji4>&Zl1J2TR z1m9r|IL)9$&q6HZgr=1!Q9p}W17ZR8_H=!^dDG6KB?PzV&8cn$@9%iahOZVw=-vJJ<$R3L0JVI_c4F}Bc=sEGXCfv=kW?HlA#Tl&z8c5UR%;(>p39+BqmMGOmkE48$Txx&oqp zN515GeA{J_+41^-&9p}~mo+@xaKWe6{Cw$sKK14VqIz%W~R_G*N&a~L*UUuPWw1!{^KOl z{z&9pnSO`S{)I2U?*fAcedNTK+;K=U{`qtQucdtl+P$I3)08*SZu4U>|KY+ZufpDe zomJvN>fg;M>Yev(onUR5xYTpzf843 zea}!?uqbgYWitctsPY7#V|$_Li95WN&2dS1Anndj@QcVI5zDuT{end2fR2~Dfcv4X zZ1SBR;O;jX583Z#w5O0MOM&C%qx|bxF0wVF-5@6og0c%vY^h5ylFqBN9|!7VCYU+O z%R!-g1S6jgVyU;E;?O%9(t9$Wf1$rqYD9web>!6iy2Za&N8;b+l)>s?crnq;i zKX&8$KVbTqUVlv&Dc<0P5}u}hJXnkhkq;-sY<>jO>=0&{!DK$p=8NHII2EJ(aEd8D z|5ub`NH5f=y&L4<2;U)%ZgHj2k2}0m;k?ti)4&qRu=d1C@y%zcgk&WM1r-w&f_7+` za%^c#!AfJ9QB}*_xzfjUzD5rPc-XkVd7|AKZHl9lx9|SrHWo{?77LcQtOnh(wes@s zr9pCW+qVrSc{{{=oR Union[pd.DataFrame, list[Document]]: suffix = data_path.suffix if ".xlsx" == suffix: data = pd.read_excel(data_path) @@ -37,14 +33,13 @@ def read_data(data_path: Path): elif ".json" == suffix: data = pd.read_json(data_path) elif suffix in (".docx", ".doc"): - data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load() + data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data() elif ".txt" == suffix: - data = TextLoader(str(data_path)).load() - text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0) - texts = text_splitter.split_documents(data) - data = texts + data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data() + node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0) + data = node_parser.get_nodes_from_documents(data) elif ".pdf" == suffix: - data = UnstructuredPDFLoader(str(data_path), mode="elements").load() + data = PDFReader.load_data(str(data_path)) else: raise NotImplementedError("File format not supported.") return data @@ -146,9 +141,9 @@ class IndexableDocument(Document): metadatas.append({}) return docs, metadatas - def _get_docs_and_metadatas_by_langchain(self) -> (list, list): + def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list): data = self.data - docs = [i.page_content for i in data] + docs = [i.text for i in data] metadatas = [i.metadata for i in data] return docs, metadatas @@ -156,7 +151,7 @@ class IndexableDocument(Document): if isinstance(self.data, pd.DataFrame): return self._get_docs_and_metadatas_by_df() elif isinstance(self.data, list): - return self._get_docs_and_metadatas_by_langchain() + return self._get_docs_and_metadatas_by_llamaindex() else: raise NotImplementedError("Data type not supported for metadata extraction.") diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py index ddc1d626b..129da4f4f 100644 --- a/metagpt/document_store/base_store.py +++ b/metagpt/document_store/base_store.py @@ -39,8 +39,8 @@ class LocalStore(BaseStore, ABC): self.store = self.write() def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): - index_file = self.cache_dir / f"{self.fname}{index_ext}" - store_file = self.cache_dir / f"{self.fname}{pkl_ext}" + index_file = self.cache_dir / "default__vector_store.json" + store_file = self.cache_dir / "docstore.json" return index_file, store_file @abstractmethod diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index 2359917d5..2136e49db 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -7,10 +7,14 @@ """ import asyncio from pathlib import Path -from typing import Optional +from typing import Any, Optional -from langchain.vectorstores import FAISS -from langchain_core.embeddings import Embeddings +import faiss +from llama_index import VectorStoreIndex, load_index_from_storage +from llama_index.embeddings import BaseEmbedding +from llama_index.schema import Document, QueryBundle, TextNode +from llama_index.storage import StorageContext +from llama_index.vector_stores import FaissVectorStore from metagpt.document import IndexableDocument from metagpt.document_store.base_store import LocalStore @@ -20,36 +24,52 @@ from metagpt.utils.embedding import get_embedding class FaissStore(LocalStore): def __init__( - self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None + self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None ): self.meta_col = meta_col self.content_col = content_col self.embedding = embedding or get_embedding() + self.store: VectorStoreIndex super().__init__(raw_data, cache_dir) - def _load(self) -> Optional["FaissStore"]: - index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss + def _load(self) -> Optional["VectorStoreIndex"]: + index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss if not (index_file.exists() and store_file.exists()): logger.info("Missing at least one of index_file/store_file, load failed and return None") return None + vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir) + storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store) + index = load_index_from_storage(storage_context) - return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname) + return index - def _write(self, docs, metadatas): - store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas) - return store + def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex: + assert len(docs) == len(metadatas) + texts_embeds = self.embedding.get_text_embedding_batch(docs) + documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)] + + [TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)] + # doc_store = SimpleDocumentStore() + # doc_store.add_documents(nodes) + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536)) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context) + + return index def persist(self): - self.store.save_local(self.raw_data_path.parent, self.fname) + self.store.storage_context.persist(self.cache_dir) + + def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs): + retriever = self.store.as_retriever(similarity_top_k=k) + rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query))) - def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs): - rsp = self.store.similarity_search(query, k=k, **kwargs) logger.debug(rsp) if expand_cols: - return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp])) + return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp])) else: - return str(sep.join([f"{x.page_content}" for x in rsp])) + return str(sep.join([f"{x.node.text}" for x in rsp])) async def asearch(self, *args, **kwargs): return await asyncio.to_thread(self.search, *args, **kwargs) @@ -67,8 +87,12 @@ class FaissStore(LocalStore): def add(self, texts: list[str], *args, **kwargs) -> list[str]: """FIXME: Currently, the store is not updated after adding.""" - return self.store.add_texts(texts) + texts_embeds = self.embedding.get_text_embedding_batch(texts) + nodes = [TextNode(embedding=embed) for embed in texts_embeds] + self.store.insert_nodes(nodes) + + return [] def delete(self, *args, **kwargs): - """Currently, langchain does not provide a delete interface.""" + """Currently, faiss does not provide a delete interface.""" raise NotImplementedError diff --git a/metagpt/memory/memory2.py b/metagpt/memory/memory2.py new file mode 100644 index 000000000..f33b740de --- /dev/null +++ b/metagpt/memory/memory2.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : memory mechanism including store/retrieval/rank + +from typing import Union, Optional +from pydantic import Field, BaseModel + +from metagpt.memory.memory_network import MemoryNetwork +from metagpt.memory.schema import MemoryNode +from metagpt.schema import Message + + +class Memory(BaseModel): + mem_network: Optional[MemoryNetwork] = Field(default_factory=MemoryNetwork, description="the network to store memory") + + def add_msg(self, message: Message): + mem_node = MemoryNode.create_mem_node_from_message(message) + self.mem_network.add_mem(mem_node) + + def add_msgs(self, messages: list[Message]): + for msg in messages: + self.add_msg(msg) diff --git a/metagpt/memory/memory_network.py b/metagpt/memory/memory_network.py new file mode 100644 index 000000000..00bc2ba78 --- /dev/null +++ b/metagpt/memory/memory_network.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the memory network to store memory segment + +from pydantic import Field, BaseModel + +from metagpt.memory.schema import MemorySegment, MemoryNode + + +class MemoryNetwork(BaseModel): + mem_seg: MemorySegment = Field(default_factory=MemorySegment, description="the memory segment to store memory nodes") + + def add_mem(self, mem_node: MemoryNode): + self.mem_seg.add_mem_node(mem_node) + + def add_mems(self, mem_nodes: list[MemoryNode]): + for mem_node in mem_nodes: + self.add_mem(mem_node) diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index c029d027b..f096cec72 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -5,11 +5,8 @@ """ from pathlib import Path -from typing import Optional -from langchain.embeddings import OpenAIEmbeddings -from langchain.vectorstores.faiss import FAISS -from langchain_core.embeddings import Embeddings +from llama_index.embeddings import BaseEmbedding from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore @@ -23,29 +20,17 @@ class MemoryStorage(FaissStore): The memory storage with Faiss as ANN search engine """ - def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None): + def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None): self.role_id: str = None self.role_mem_path: str = None self.mem_ttl: int = mem_ttl # later use self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories self._initialized: bool = False - self.embedding = embedding or OpenAIEmbeddings() - self.store: FAISS = None # Faiss engine - @property def is_initialized(self) -> bool: return self._initialized - def _load(self) -> Optional["FaissStore"]: - index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss - - if not (index_file.exists() and store_file.exists()): - logger.info("Missing at least one of index_file/store_file, load failed and return None") - return None - - return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id) - def recover_memory(self, role_id: str) -> list[Message]: self.role_id = role_id self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/") @@ -69,6 +54,7 @@ class MemoryStorage(FaissStore): return None, None index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}") storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}") + self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id) return index_fpath, storage_fpath def persist(self): diff --git a/metagpt/memory/schema.py b/metagpt/memory/schema.py new file mode 100644 index 000000000..610f54bd0 --- /dev/null +++ b/metagpt/memory/schema.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the memory schema definition + +from datetime import datetime +from enum import Enum +from typing import Optional, Union +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class MemNodeType(Enum): + OBSERVE = "observe" # memory from observation + THINK = "think" # memory from self-think/reflect + + +class MemoryNode(BaseModel): + """base unit of memory abstraction""" + + mem_node_id: UUID = Field(default_factory=uuid4(), description="unique node id") + parent_node_id: Optional[str] = Field(default=None, description="memory's parent memory node id") + node_type: MemNodeType = Field(default=MemNodeType.OBSERVE, description="memory node type") + + content: str = Field(default="", description="the memory content") + summary: Optional[str] = Field(default=None, description="the summary of the content by providers") + keywords: list[str] = Field(default=[], description="the extracted keywords of the content") + embedding: list[float] = Field(default=[], description="the embeeding of the content") + + raw_path: Optional[str] = Field(default=None, description="the relative path of the media like image") + raw_corpus: list[Union[str, dict, tuple]] = Field(default=[], description="the raw corpus of the memory") + + create_at: datetime = Field(default_factory=datetime, description="the memory create time") + access_at: datetime = Field(default_factory=datetime, description="the memory last access time") + expire_at: datetime = Field(default_factory=datetime, description="the memory expire time due to a TTL") + + importance: int = Field(default=0, ge=0, le=10, description="the memory importance") + access_cnt: int = Field(default=0, description="the memory acess count time") + + @classmethod + def create_mem_node( + cls, + content: str, + summary: Optional[str] = None, + keywords: list[str] = [], + node_type: MemNodeType = MemNodeType.OBSERVE, + ): + pass + + @classmethod + def create_mem_node_from_message(cls, message: "Message"): + pass + + +class MemorySegment(BaseModel): + """segment abstraction to store memory_node""" + + mem_nodes: list[MemoryNode] = Field(default=[], description="memory list to store MemoryNode") + + def add_mem_node(self, mem_node: MemoryNode): + self.mem_nodes.append(mem_node) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 47a4f45a7..a0f63124c 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -102,12 +102,6 @@ class RoleContext(BaseModel): ) # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 - def check(self, role_id: str): - # if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory: - # self.long_term_memory.recover_memory(role_id, self) - # self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation - pass - @property def important_memory(self) -> list[Message]: """Retrieve information corresponding to the attention action.""" @@ -300,8 +294,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): buffer during _observe. """ self.rc.watch = {any_to_str(t) for t in actions} - # check RoleContext after adding watch actions - self.rc.check(self.role_id) def is_watch(self, caused_by: str): return caused_by in self.rc.watch diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py index 21d62948c..3b5465f99 100644 --- a/metagpt/utils/embedding.py +++ b/metagpt/utils/embedding.py @@ -5,12 +5,12 @@ @Author : alexanderwu @File : embedding.py """ -from langchain_community.embeddings import OpenAIEmbeddings +from llama_index.embeddings import OpenAIEmbedding from metagpt.config2 import config -def get_embedding(): +def get_embedding() -> OpenAIEmbedding: llm = config.get_openai_llm() - embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url) + embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url) return embedding diff --git a/requirements.txt b/requirements.txt index 0a54236f0..1c3ebeca0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp==3.8.4 +aiohttp==3.8.6 #azure_storage==0.37.0 channels==4.0.0 # chromadb @@ -11,11 +11,11 @@ typer==0.9.0 # godot==0.1.1 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 -langchain==0.0.352 +llama-index==0.9.31 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai==1.6.0 +openai==1.6.1 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py index 7e2979bd4..63744ac91 100644 --- a/tests/metagpt/document_store/test_faiss_store.py +++ b/tests/metagpt/document_store/test_faiss_store.py @@ -25,7 +25,7 @@ async def test_search_json(): @pytest.mark.asyncio async def test_search_xlsx(): - store = FaissStore(EXAMPLE_PATH / "example.xlsx") + store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" result = await role.run(query) @@ -36,5 +36,5 @@ async def test_search_xlsx(): async def test_write(): store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question") _faiss_store = store.write() - assert _faiss_store.docstore - assert _faiss_store.index + assert _faiss_store.storage_context.docstore + assert _faiss_store.storage_context.vector_store.client