From 7bb6c8c750bafb634f7da4e473cb7dcc66aa1d1c Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 9 Dec 2024 13:21:28 -0800 Subject: [PATCH] add test and fix bug --- model_server/src/core/function_calling.py | 20 +- model_server/tests/core/test_cases.json | 949 ------------------ model_server/tests/core/test_hallucination.py | 205 ++-- 3 files changed, 109 insertions(+), 1065 deletions(-) delete mode 100644 model_server/tests/core/test_cases.json diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index eec24dc4..6b3bc6d6 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -435,6 +435,8 @@ class ArchFunctionHandler(ArchBaseHandler): """ Engage parameter gathering for tool calls """ + + # TODO: log enaging parameter gathering prefill_response = self.client.chat.completions.create( messages=self._add_prefill_message(messages), model=self.model_name, @@ -472,32 +474,34 @@ class ArchFunctionHandler(ArchBaseHandler): ) # initialize the hallucination handler, which is an iterator - hallu_handler = HallucinationStateHandler( + self.hallu_handler = HallucinationStateHandler( response_iterator=response, function=req.tools ) model_response, has_tool_call = "", None - for token in hallu_handler: + for token in self.hallu_handler: # check if the first token is - if len(hallu_handler.tokens) > 0 and has_tool_call == False: - if hallu_handler.tokens[0] == "": + if len(self.hallu_handler.tokens) > 0 and has_tool_call == None: + if self.hallu_handler.tokens[0] == "": has_tool_call = True else: has_tool_call = False break - if hallu_handler.hallucination == True: + + # if the model is hallucinating, start parameter gathering + if self.hallu_handler.hallucination == True: prefill_response = self._engage_parameter_gathering(messages) model_response = prefill_response.choices[0].message.content break # start parameter gathering if the model is not generating tool calls - if hallu_handler.hallucination == False: - model_response = "".join(hallu_handler.tokens) + if self.hallu_handler.hallucination == False: + model_response = "".join(self.hallu_handler.tokens) # start parameter gathering if the model is not generating tool calls if has_tool_call is False: - prefill_response = await self._engage_parameter_gathering(messages) + prefill_response = self._engage_parameter_gathering(messages) model_response = prefill_response.choices[0].message.content # Extract tool calls from model response diff --git a/model_server/tests/core/test_cases.json b/model_server/tests/core/test_cases.json deleted file mode 100644 index 229b4ea1..00000000 --- a/model_server/tests/core/test_cases.json +++ /dev/null @@ -1,949 +0,0 @@ -[ - { - "case": "tool_call_halluciation", - "tokens": [ - "" - ], - "expect": 1, - "logprobs": [ - [ - -0.3333307206630707, - -1.5310522317886353, - -3.5098977088928223, - -3.9004578590393066, - -5.775152683258057, - -5.814209461212158, - -5.9574151039123535, - -6.0094895362854, - -6.0094895362854, - -6.673445224761963 - ] - ] - }, - { - "case": "parameter_value_hallucination", - "expect": 0, - "tokens": [ - "", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Sea", - ",", - " Australia", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "1", - "'}}\n", - "" - ], - "logprobs": [ - [ - -0.008103232830762863, - -5.085402488708496, - -6.777836799621582, - -7.558959007263184, - -9.850253105163574, - -10.266852378845215, - -10.540244102478027, - -10.722506523132324, - -10.800618171691895, - -10.917786598205566 - ], - [ - 0.0, - -23.25142478942871, - -25.139137268066406, - -26.2847843170166, - -28.992677688598633, - -29.070789337158203, - -29.55248260498047, - -29.91700553894043, - -30.20341682434082, - -30.307567596435547 - ], - [ - 0.0, - -21.66313934326172, - -23.06916046142578, - -23.32953453063965, - -25.65988540649414, - -25.985353469848633, - -26.519121170043945, - -27.07892417907715, - -27.977216720581055, - -28.458908081054688 - ], - [ - 0.0, - -28.094383239746094, - -28.56305694580078, - -29.109844207763672, - -29.44832992553711, - -31.79170036315918, - -32.0, - -32.05207443237305, - -32.31244659423828, - -32.364524841308594 - ], - [ - 0.0, - -30.489830017089844, - -31.140766143798828, - -31.81774139404297, - -34.525634765625, - -35.8275032043457, - -36.504478454589844, - -39.05614471435547, - -40.123680114746094, - -40.696502685546875 - ], - [ - 0.0, - -25.646865844726562, - -26.66232681274414, - -27.781936645507812, - -28.979660034179688, - -31.140764236450195, - -31.92188835144043, - -31.973962783813477, - -33.04149627685547, - -33.58828353881836 - ], - [ - 0.0, - -23.511798858642578, - -24.136695861816406, - -25.230268478393555, - -25.777053833007812, - -25.80309295654297, - -26.45402717590332, - -26.636289596557617, - -26.740440368652344, - -26.896663665771484 - ], - [ - 0.0, - -22.366153717041016, - -24.683483123779297, - -26.610252380371094, - -26.610252380371094, - -27.313264846801758, - -27.67778778076172, - -28.510986328125, - -28.615135192871094, - -29.13588523864746 - ], - [ - 0.0, - -22.52237319946289, - -24.292919158935547, - -24.344993591308594, - -24.39706802368164, - -24.73555564880371, - -29.943042755126953, - -29.969079971313477, - -30.021154403686523, - -30.0341739654541 - ], - [ - 0.0, - -30.17738151550293, - -30.411718368530273, - -30.88039207458496, - -30.984540939331055, - -31.270952224731445, - -31.895851135253906, - -32.46867370605469, - -32.624900817871094, - -33.484134674072266 - ], - [ - 0.0, - -28.146459579467773, - -29.396255493164062, - -30.099267959594727, - -31.127744674682617, - -31.179821014404297, - -32.807159423828125, - -33.7445068359375, - -33.770545959472656, - -34.069976806640625 - ], - [ - 0.0, - -26.323841094970703, - -26.558177947998047, - -30.515867233276367, - -30.932466506958008, - -31.37510108947754, - -31.531326293945312, - -31.70056915283203, - -32.065093994140625, - -32.364524841308594 - ], - [ - 0.0, - -26.922698974609375, - -30.28152847290039, - -31.505287170410156, - -33.30187225341797, - -33.73148727416992, - -34.27827453613281, - -34.33034896850586, - -34.460533142089844, - -34.720909118652344 - ], - [ - 0.0, - -21.532955169677734, - -26.94873809814453, - -29.109848022460938, - -30.80228042602539, - -31.55736541748047, - -33.484134674072266, - -34.681854248046875, - -35.384864807128906, - -35.853538513183594 - ], - [ - 0.0, - -19.502033233642578, - -20.46541976928711, - -24.110658645629883, - -24.501218795776367, - -25.256305694580078, - -25.82912826538086, - -25.881202697753906, - -26.063465118408203, - -26.063465118408203 - ], - [ - 0.0, - -24.37103271484375, - -25.256305694580078, - -25.933277130126953, - -26.714401245117188, - -28.2506103515625, - -31.010576248168945, - -32.07810974121094, - -34.62977981567383, - -35.241661071777344 - ], - [ - -1.1920922133867862e-06, - -14.398697853088379, - -14.424736976623535, - -17.158666610717773, - -17.41904067993164, - -18.200162887573242, - -18.434499740600586, - -18.66883659362793, - -19.71033477783203, - -19.71033477783203 - ], - [ - -0.0001445904199499637, - -8.98305892944336, - -11.35246467590332, - -13.1490478515625, - -13.669795989990234, - -14.073375701904297, - -14.516012191772461, - -14.555068969726562, - -15.622602462768555, - -15.635622024536133 - ], - [ - -0.44747352600097656, - -1.0202960968017578, - -8.467000961303711, - -10.914518356323242, - -11.25300407409668, - -11.435266494750977, - -12.346576690673828, - -13.075624465942383, - -13.12769889831543, - -13.231849670410156 - ], - [ - -3.123767137527466, - -1.1188862323760986, - -1.639634370803833, - -2.0562336444854736, - -2.8633930683135986, - -2.9675419330596924, - -3.4882919788360596, - -3.69659161567688, - -4.217339515686035, - -4.243376731872559 - ], - [ - -7.199982064776123e-05, - -9.76410961151123, - -11.144091606140137, - -16.507802963256836, - -17.132701873779297, - -17.44515037536621, - -17.9138240814209, - -18.33042335510254, - -18.9162654876709, - -19.39795684814453 - ], - [ - 0.0, - -22.991050720214844, - -23.824249267578125, - -24.969894409179688, - -25.46460723876953, - -25.829130172729492, - -26.480066299438477, - -26.909683227539062, - -27.33930206298828, - -27.391376495361328 - ], - [ - -0.21928852796554565, - -1.625309705734253, - -9.775025367736816, - -12.977627754211426, - -16.388530731201172, - -17.091541290283203, - -19.044347763061523, - -19.38283348083496, - -19.460947036743164, - -19.59113311767578 - ], - [ - 0.0, - -24.006507873535156, - -27.443450927734375, - -27.729862213134766, - -28.12042236328125, - -28.276647567749023, - -28.927583694458008, - -30.099267959594727, - -31.479251861572266, - -32.07810974121094 - ], - [ - 0.0, - -18.17412567138672, - -18.772987365722656, - -21.689178466796875, - -21.92351531982422, - -23.7200984954834, - -23.79821014404297, - -23.79821014404297, - -24.032546997070312, - -25.308382034301758 - ], - [ - -0.12947827577590942, - -2.1083219051361084, - -12.419143676757812, - -15.23118782043457, - -15.595710754394531, - -15.830047607421875, - -17.001731872558594, - -17.60059356689453, - -18.121341705322266, - -18.251529693603516 - ], - [ - 0.0, - -19.449962615966797, - -24.371034622192383, - -24.917821884155273, - -25.529701232910156, - -25.85516929626465, - -26.037429809570312, - -26.115543365478516, - -26.623271942138672, - -26.649309158325195 - ], - [ - -0.03332124650478363, - -3.4181859493255615, - -15.759925842285156, - -15.812002182006836, - -16.593124389648438, - -17.894996643066406, - -18.09027671813965, - -18.79328727722168, - -19.144792556762695, - -20.147233963012695 - ], - [ - 0.0, - -21.142393112182617, - -22.157852172851562, - -23.511798858642578, - -24.657445907592773, - -25.021968841552734, - -25.5427188873291, - -25.59479331970215, - -25.75101661682129, - -25.95931625366211 - ], - [ - 0.0, - -23.04312515258789, - -24.94385528564453, - -26.323841094970703, - -27.54759979248047, - -28.563060760498047, - -29.786819458007812, - -30.620018005371094, - -30.69812774658203, - -31.08869171142578 - ], - [ - 0.0, - -26.167617797851562, - -28.771360397338867, - -29.55248260498047, - -30.906429290771484, - -31.114728927612305, - -31.414159774780273, - -31.622459411621094, - -31.713590621948242, - -31.726608276367188 - ], - [ - -0.05012698099017143, - -3.018392562866211, - -11.740934371948242, - -13.146955490112305, - -13.797887802124023, - -14.943536758422852, - -16.037107467651367, - -16.375595092773438, - -16.714080810546875, - -17.36501693725586 - ], - [ - -0.9704352021217346, - -0.7360983490943909, - -2.1941938400268555, - -4.225115776062012, - -5.0062360763549805, - -5.2666120529174805, - -5.839434623718262, - -7.2714948654174805, - -8.33902645111084, - -8.495253562927246 - ], - [ - -0.014467108063399792, - -4.258565902709961, - -8.789079666137695, - -10.429437637329102, - -10.793962478637695, - -11.835458755493164, - -11.939607620239258, - -13.31959342956543, - -13.866378784179688, - -15.038063049316406 - ], - [ - 0.0, - -20.08787727355957, - -21.350692749023438, - -21.415786743164062, - -21.50691795349121, - -21.50691795349121, - -22.7176570892334, - -24.13669776916504, - -24.188772201538086, - -24.34499740600586 - ] - ] - }, - { - "case": "fail_case", - "expect": 0, - "tokens": [ - "", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Seattle", - ",", - " WA", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "7", - "'}}\n", - "" - ], - "logprobs": [ - [ - -0.00013815402053296566, - -9.113236427307129, - -10.571331977844238, - -14.099404335021973, - -14.28166675567627, - -15.583537101745605, - -15.81787395477295, - -16.143341064453125, - -16.143341064453125, - -16.260509490966797 - ], - [ - 0.0, - -26.896663665771484, - -27.32628059387207, - -27.41741180419922, - -32.07810974121094, - -32.07810974121094, - -32.28641128540039, - -32.29943084716797, - -32.44263458251953, - -32.520748138427734 - ], - [ - 0.0, - -22.444263458251953, - -24.527257919311523, - -27.15703773498535, - -28.016273498535156, - -28.2506103515625, - -28.693246841430664, - -29.070789337158203, - -29.565500259399414, - -29.812854766845703 - ], - [ - 0.0, - -27.860050201416016, - -28.641170501708984, - -29.448333740234375, - -30.932466506958008, - -31.63547706604004, - -32.33848571777344, - -32.85923767089844, - -33.17168426513672, - -33.45809555053711 - ], - [ - 0.0, - -31.81774139404297, - -31.895854949951172, - -32.05207824707031, - -35.43694305419922, - -36.3482551574707, - -38.61351013183594, - -39.26444625854492, - -40.61839294433594, - -41.71196365356445 - ], - [ - 0.0, - -27.33930206298828, - -27.834014892578125, - -28.849472045898438, - -30.567943572998047, - -32.98942565917969, - -33.067535400390625, - -33.067535400390625, - -35.67127990722656, - -35.69731903076172 - ], - [ - 0.0, - -25.33441925048828, - -26.063465118408203, - -26.219690322875977, - -26.2457275390625, - -26.53213882446289, - -27.365337371826172, - -28.354759216308594, - -28.667207717895508, - -28.74532127380371 - ], - [ - 0.0, - -24.423107147216797, - -24.579330444335938, - -26.81855010986328, - -28.12042236328125, - -28.32872200012207, - -28.61513328552246, - -29.16191864013672, - -29.187957763671875, - -29.240032196044922 - ], - [ - 0.0, - -22.027664184570312, - -23.850284576416016, - -23.980472564697266, - -24.292922973632812, - -24.787633895874023, - -29.279088973999023, - -29.55248260498047, - -29.903987884521484, - -30.190399169921875 - ], - [ - 0.0, - -31.609439849853516, - -31.817739486694336, - -32.54678726196289, - -32.676971435546875, - -32.781124114990234, - -32.98942565917969, - -33.106590270996094, - -33.57526397705078, - -34.369407653808594 - ], - [ - 0.0, - -29.34418296813965, - -29.63059425354004, - -30.021156311035156, - -30.984540939331055, - -33.21073913574219, - -34.30431365966797, - -34.56468963623047, - -34.70789337158203, - -34.79902648925781 - ], - [ - 0.0, - -25.438566207885742, - -25.69894027709961, - -30.190397262573242, - -30.802276611328125, - -31.58340072631836, - -31.609437942504883, - -31.64849281311035, - -31.973960876464844, - -32.29943084716797 - ], - [ - 0.0, - -27.157039642333984, - -32.104148864746094, - -32.33848571777344, - -34.04393768310547, - -34.12205505371094, - -34.40846252441406, - -34.42148208618164, - -34.772987365722656, - -34.87713623046875 - ], - [ - 0.0, - -24.813671112060547, - -26.974777221679688, - -31.010578155517578, - -31.08869171142578, - -32.1822624206543, - -35.33279037475586, - -35.489013671875, - -36.999183654785156, - -37.88446044921875 - ], - [ - 0.0, - -20.46541976928711, - -20.647682189941406, - -23.069164276123047, - -24.136699676513672, - -25.438570022583008, - -25.646869659423828, - -26.193655014038086, - -26.297805786132812, - -26.506103515625 - ], - [ - 0.0, - -27.18307113647461, - -28.30268096923828, - -28.56305694580078, - -29.526439666748047, - -32.416595458984375, - -35.202598571777344, - -36.426361083984375, - -39.31651306152344, - -39.38160705566406 - ], - [ - 0.0, - -18.7469482421875, - -20.100894927978516, - -21.402767181396484, - -21.428804397583008, - -22.20992660522461, - -22.34011459350586, - -22.730674743652344, - -23.069162368774414, - -23.980472564697266 - ], - [ - -3.576278118089249e-07, - -15.2579345703125, - -16.481693267822266, - -17.991863250732422, - -19.215621948242188, - -20.25712013244629, - -21.350692749023438, - -22.314077377319336, - -22.496337890625, - -22.938974380493164 - ], - [ - -0.08506780862808228, - -2.506549835205078, - -14.848289489746094, - -15.473188400268555, - -16.33242416381836, - -16.358461380004883, - -16.566761016845703, - -17.03543472290039, - -17.686370849609375, - -17.816556930541992 - ], - [ - -0.0194891095161438, - -4.445854187011719, - -5.591499328613281, - -5.956024169921875, - -6.685070037841797, - -13.142353057861328, - -13.558952331542969, - -15.173273086547852, - -15.303461074829102, - -15.85024642944336 - ], - [ - -0.0005990855861455202, - -7.4212646484375, - -15.675132751464844, - -15.72720718383789, - -16.76870346069336, - -16.76870346069336, - -17.706050872802734, - -18.669435501098633, - -19.398483276367188, - -19.658857345581055 - ], - [ - 0.0, - -24.110658645629883, - -25.829130172729492, - -26.011390686035156, - -26.011390686035156, - -26.532140731811523, - -26.58421516418457, - -27.651750564575195, - -27.75589942932129, - -28.055330276489258 - ], - [ - -1.1408883333206177, - -0.38580334186553955, - -7.494022369384766, - -12.519245147705078, - -14.576202392578125, - -16.034297943115234, - -16.945608139038086, - -17.908992767333984, - -18.664077758789062, - -19.34105110168457 - ], - [ - 0.0, - -26.688365936279297, - -29.83889389038086, - -30.177383422851562, - -30.64605712890625, - -31.244916915893555, - -31.270954132080078, - -32.83319854736328, - -34.655818939208984, - -34.89015579223633 - ], - [ - 0.0, - -18.929210662841797, - -19.16354751586914, - -23.589908599853516, - -24.683481216430664, - -24.995929718017578, - -25.516677856445312, - -25.542715072631836, - -25.77705192565918, - -26.063465118408203 - ], - [ - -0.2519786059856415, - -1.5017764568328857, - -12.437495231628418, - -15.457839012145996, - -15.744250297546387, - -16.837820053100586, - -17.41064453125, - -17.56686782836914, - -17.61894416809082, - -18.035541534423828 - ], - [ - 0.0, - -20.517494201660156, - -24.683483123779297, - -25.67290496826172, - -26.58421516418457, - -27.651750564575195, - -27.781936645507812, - -27.912124633789062, - -28.09438705444336, - -28.445892333984375 - ], - [ - -3.40932747349143e-05, - -10.284820556640625, - -18.252273559570312, - -20.17904281616211, - -21.663175582885742, - -22.027700424194336, - -22.288074493408203, - -22.704673767089844, - -23.12127113342285, - -23.277496337890625 - ], - [ - 0.0, - -22.60049057006836, - -25.46460723876953, - -25.829130172729492, - -26.063467025756836, - -27.287227630615234, - -27.391376495361328, - -27.4694881439209, - -27.67778778076172, - -28.055330276489258 - ], - [ - 0.0, - -23.902362823486328, - -28.823436737060547, - -29.240036010742188, - -29.31814956665039, - -29.917007446289062, - -30.021160125732422, - -31.21887969970703, - -32.416603088378906, - -32.416603088378906 - ], - [ - 0.0, - -28.641170501708984, - -31.947925567626953, - -32.59886169433594, - -33.848655700683594, - -34.109031677246094, - -34.73393249511719, - -35.02033996582031, - -35.02033996582031, - -36.074859619140625 - ], - [ - -0.013183215633034706, - -4.335395336151123, - -19.619365692138672, - -20.035964965820312, - -20.244266510009766, - -21.311800003051758, - -21.441987991333008, - -22.561595916748047, - -23.108383178710938, - -23.264606475830078 - ], - [ - -8.344646857949556e-07, - -14.190400123596191, - -15.9088716506958, - -18.17412567138672, - -18.46053695678711, - -18.46053695678711, - -18.512611389160156, - -18.90317153930664, - -19.059398651123047, - -19.085433959960938 - ], - [ - 0.0, - -17.70545196533203, - -18.903175354003906, - -20.829944610595703, - -22.574451446533203, - -22.860862731933594, - -23.069162368774414, - -23.32953643798828, - -23.694061279296875, - -24.188772201538086 - ], - [ - 0.0, - -20.022781372070312, - -21.038240432739258, - -21.220502853393555, - -22.496337890625, - -22.769729614257812, - -23.589908599853516, - -23.65500259399414, - -23.94141387939453, - -24.266881942749023 - ] - ] - } -] diff --git a/model_server/tests/core/test_hallucination.py b/model_server/tests/core/test_hallucination.py index 3de80c37..fcbbd962 100644 --- a/model_server/tests/core/test_hallucination.py +++ b/model_server/tests/core/test_hallucination.py @@ -1,20 +1,14 @@ -import json -import pytest import os +from src.commons.globals import handler_map +from src.core.model_utils import ChatMessage, Message +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch +from src.main import app +from src.commons.globals import handler_map -from src.core.hallucination_handler import HallucinationStateHandler - - -# Get the directory of the current file -current_dir = os.path.dirname(__file__) - -# Construct the full path to the JSON file -json_file_path = os.path.join(current_dir, "test_cases.json") - -with open(json_file_path) as f: - test_cases = json.load(f) - +# define function get_weather_api = { "type": "function", "function": { @@ -43,111 +37,106 @@ get_weather_api = { }, }, } -function_description = get_weather_api["function"] -if type(function_description) != list: - function_description = [get_weather_api["function"]] -# [TODO] Review: update the following code -@pytest.mark.parametrize("case", test_cases) -def test_hallucination(case): - state = HallucinationStateHandler( - response_iterator=None, function=function_description +def get_hallucination_data_complex(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle?") + message2 = Message( + role="assistant", content="Can you specify the unit you want the weather in?" ) - for token, logprob in zip(case["tokens"], case["logprobs"]): - if token != "": - state.append_and_check_token_hallucination(token, logprob) - if state.hallucination: - break - assert state.hallucination == case["expect"] + message3 = Message(role="user", content="In celcius please!") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1, message2, message3], tools=tools) + + return req, True, True, True -# [TODO] Review: update the following code -@pytest.mark.parametrize("is_hallucinate_sample", [True, False]) -def test_hallucination_prompt(is_hallucinate_sample): - TASK_PROMPT = """ - You are a helpful assistant. - """.strip() +def get_hallucination_data_easy(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle?") - TOOL_PROMPT = """ - # Tools + # Create a list of tools + tools = [get_weather_api] - You may call one or more functions to assist with the user query. + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) - You are provided with function signatures within XML tags: - - {tool_text} - - """.strip() + # model will hallucinate + return req, True, True, True - FORMAT_PROMPT = """ - For each function call, return a json object with function name and arguments within XML tags: - - {"name": , "arguments": } - - """.strip() - def convert_tools(tools): - return "\n".join([json.dumps(tool) for tool in tools]) +def get_hallucination_data_medium(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in?") - def format_prompt(tools): - tool_text = convert_tools(tools) + # Create a list of tools + tools = [get_weather_api] - return ( - TASK_PROMPT - + "\n\n" - + TOOL_PROMPT.format(tool_text=tool_text) - + "\n\n" - + FORMAT_PROMPT - + "\n" + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + # first token will not be tool call + return req, True, False, True + + +def get_complete_data(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle in 7 days?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + return req, True, False, False + + +def get_irrelevant_data(): + # Create instances of the Message class + message1 = Message(role="user", content="What is 1+1?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + return req, False, False, False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "get_data_func", + [ + get_hallucination_data_complex, + get_hallucination_data_easy, + get_hallucination_data_medium, + get_complete_data, + get_irrelevant_data, + ], +) +async def test_function_calling(get_data_func): + req, intent, hallucination, parameter_gathering = get_data_func() + + intent_response = await handler_map["Arch-Intent"].chat_completion(req) + + assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent + + if intent: + function_calling_response = await handler_map["Arch-Function"].chat_completion( + req ) + assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination + response_txt = function_calling_response.choices[0].message.content - openai_format_tools = [get_weather_api] - - system_prompt = format_prompt(openai_format_tools) - - from openai import OpenAI - - client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") - - # List models API - model = client.models.list().data[0].id - assert model == "Arch-Function" - if not is_hallucinate_sample: - messages = [ - {"role": "system", "content": system_prompt}, - # {"role": "user", "content": "can you help me check weather?"}, - {"role": "user", "content": "How is the weather in Seattle in 7 days?"}, - # {"role": "assistant", "content": "Of course!"}, - # {"role": "user", "content": "Seattle please"} - ] - else: - messages = [ - {"role": "system", "content": system_prompt}, - # {"role": "user", "content": "can you help me check weather?"}, - {"role": "user", "content": "How is the weather in Seattle in days?"}, - # {"role": "assistant", "content": "Of course!"}, - # {"role": "user", "content": "Seattle please"} - ] - - extra_body = { - "temperature": 0.6, - "top_p": 1.0, - "top_k": 50, - # "continue_final_message": True, - # "add_generation_prompt": False, - "logprobs": True, - "top_logprobs": 10, - } - - resp = client.chat.completions.create( - model="Arch-Function", messages=messages, extra_body=extra_body, stream=True - ) - - hallu = HallucinationStateHandler( - response_iterator=resp, function=function_description - ) - - for token in hallu: - assert len(hallu.tokens) >= 0 - assert hallu.hallucination == is_hallucinate_sample + if parameter_gathering: + prefill_prefix = handler_map["Arch-Function"].prefill_prefix + assert any( + response_txt.startswith(prefix) for prefix in prefill_prefix + ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"