# encoding/test_encoding_dsv32.py
"""Script-style checks for ``encode_messages`` and
``parse_message_from_completion_text`` against gold fixture files.

All checks run at module level as bare ``assert`` statements. Fixture paths
are relative, so they are resolved against the current working directory.
"""
import json
import copy

from encoding_dsv32 import encode_messages, parse_message_from_completion_text


def _load_json(path):
    """Return the parsed JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    NOTE(review): assumes the fixtures are stored as UTF-8 -- confirm.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _read_stripped(path):
    """Return the text of *path* (decoded as UTF-8) with surrounding
    whitespace stripped."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read().strip()


# --- Full conversation: encoded prompt must match the gold prompt ----------
test_dict = _load_json("test_input.json")
messages = test_dict["messages"]
# The tool definitions are attached to the first message before encoding.
messages[0]["tools"] = test_dict["tools"]

gold_prompt = _read_stripped("test_output.txt")

print(messages)
print("=" * 60)

encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True)
prompt = encode_messages(messages, **encode_config)
print(prompt)
assert prompt == gold_prompt
print("=" * 60)

# --- Round trip of the message at index 4 (has "tool_calls") ---------------
tool_call_message = messages[4]
tool_call_prompt = encode_messages([tool_call_message], context=messages[:4], **encode_config)
# The expected message is compared without the per-call "id" fields; work on
# a deep copy so the shared `messages` list is left untouched.
tool_call_message_wo_id = copy.deepcopy(tool_call_message)
for tool_call in tool_call_message_wo_id["tool_calls"]:
    tool_call.pop("id")
parsed_tool_call_message = parse_message_from_completion_text(tool_call_prompt, thinking_mode="thinking")
# "content" is removed from the parsed result before the comparison.
parsed_tool_call_message.pop("content")
assert tool_call_message_wo_id == parsed_tool_call_message

# --- Round trip of the sixth-from-last message ------------------------------
thinking_message = messages[-6]
thinking_prompt = encode_messages([thinking_message], context=messages[:-6], **encode_config)
parsed_thinking_message = parse_message_from_completion_text(thinking_prompt, thinking_mode="thinking")
# "tool_calls" is removed from the parsed result before the comparison.
parsed_thinking_message.pop("tool_calls")
assert thinking_message == parsed_thinking_message

# --- Search conversation fixture, "wo_date" variant -------------------------
search_messages = _load_json("test_input_search_wo_date.json")["messages"]
search_gold_prompt = _read_stripped("test_output_search_wo_date.txt")

search_prompt = encode_messages(search_messages, **encode_config)
assert search_prompt == search_gold_prompt

# --- Search conversation fixture, "w_date" variant --------------------------
search_messages_w_date = _load_json("test_input_search_w_date.json")["messages"]
search_gold_prompt_w_date = _read_stripped("test_output_search_w_date.txt")

search_prompt_w_date = encode_messages(search_messages_w_date, **encode_config)
# The actual prompt is dumped to disk before the comparison, so the file is
# written whether or not the assertion below passes.
with open("test_output_search_w_date_2.txt", "w", encoding="utf-8") as f:
    f.write(search_prompt_w_date)
assert search_prompt_w_date == search_gold_prompt_w_date