encoding/test_encoding_dsv32.py
2.1 KB · 56 lines · python Raw
1 import json
2 import copy
3
4 from encoding_dsv32 import encode_messages, parse_message_from_completion_text
5
# Load the main conversation fixture. JSON text is UTF-8 by specification, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open("test_input.json", "r", encoding="utf-8") as f:
    test_dict = json.load(f)
messages = test_dict["messages"]
# The fixture keeps "tools" at its top level; attach it to the first message
# before encoding.
messages[0]["tools"] = test_dict["tools"]

# Expected prompt text; leading/trailing whitespace is stripped before comparing.
# NOTE(review): assumes the gold file is stored as UTF-8 — confirm.
with open("test_output.txt", "r", encoding="utf-8") as f:
    gold_prompt = f.read().strip()

print(messages)
print("=" * 60)

# Keyword options passed to the encode_messages calls in this script.
encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True)
prompt = encode_messages(messages, **encode_config)
print(prompt)
assert prompt == gold_prompt
print("=" * 60)
22
# Round-trip check: encode messages[4] on its own, with the four preceding
# messages supplied as context, then parse the resulting text back.
tool_call_message = messages[4]
tool_call_prompt = encode_messages(
    [tool_call_message],
    context=messages[:4],
    **encode_config,
)

# Expected value: a deep copy of the message with "id" deleted from every
# tool call, so the original fixture entry is left untouched.
# NOTE(review): presumably the parser does not return tool-call ids — confirm.
tool_call_message_wo_id = copy.deepcopy(tool_call_message)
for call in tool_call_message_wo_id["tool_calls"]:
    del call["id"]

parsed_tool_call_message = parse_message_from_completion_text(
    tool_call_prompt,
    thinking_mode="thinking",
)
# "content" is removed from the parsed result before the comparison.
parsed_tool_call_message.pop("content")
assert tool_call_message_wo_id == parsed_tool_call_message
31
# Round-trip check: encode messages[-6] on its own, with every earlier message
# supplied as context, then parse the resulting text back.
thinking_message = messages[-6]
thinking_prompt = encode_messages(
    [thinking_message],
    context=messages[:-6],
    **encode_config,
)
parsed_thinking_message = parse_message_from_completion_text(
    thinking_prompt,
    thinking_mode="thinking",
)
# "tool_calls" is removed from the parsed result before the comparison.
parsed_thinking_message.pop("tool_calls")
assert thinking_message == parsed_thinking_message
37
# "search_wo_date" fixture: encode its messages and compare with the gold file.
# Files are opened as UTF-8 explicitly (JSON is UTF-8 by specification) rather
# than with the platform's locale encoding.
with open("test_input_search_wo_date.json", "r", encoding="utf-8") as f:
    search_messages = json.load(f)["messages"]

# NOTE(review): assumes the gold file is stored as UTF-8 — confirm.
with open("test_output_search_wo_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt = f.read().strip()

search_prompt = encode_messages(search_messages, **encode_config)
assert search_prompt == search_gold_prompt
46
# "search_w_date" fixture: encode its messages and compare with the gold file.
# Files are opened as UTF-8 explicitly (JSON is UTF-8 by specification) rather
# than with the platform's locale encoding.
with open("test_input_search_w_date.json", "r", encoding="utf-8") as f:
    search_messages_w_date = json.load(f)["messages"]

# NOTE(review): assumes the gold file is stored as UTF-8 — confirm.
with open("test_output_search_w_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt_w_date = f.read().strip()

search_prompt_w_date = encode_messages(search_messages_w_date, **encode_config)
# The actual prompt is written out before the assertion, so the file exists
# even when the comparison below fails. An explicit encoding keeps the write
# from depending on the locale.
with open("test_output_search_w_date_2.txt", "w", encoding="utf-8") as f:
    f.write(search_prompt_w_date)
assert search_prompt_w_date == search_gold_prompt_w_date