|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "attachments": {}, |
| 5 | + "cell_type": "markdown", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# 层级知识点预测" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "code", |
| 13 | + "execution_count": null, |
| 14 | + "metadata": {}, |
| 15 | + "outputs": [], |
| 16 | + "source": [ |
| 17 | + "import os\n", |
| 18 | + "import yaml\n", |
| 19 | + "import tqdm\n", |
| 20 | + "import torch\n", |
| 21 | + "import numpy as np\n", |
| 22 | + "from EduNLP.Pretrain import BertTokenizer\n", |
| 23 | + "from EduNLP.ModelZoo.bert import BertForKnowledgePrediction\n", |
| 24 | + "from EduNLP.Pretrain import finetune_bert_for_knowledge_prediction\n", |
| 25 | + "from EduNLP.ModelZoo import load_items\n", |
| 26 | + "\n", |
| 27 | + "from utils import compute_perfs_per_layer, get_onehot_label_topk, metric, compute_perfs" |
| 28 | + ] |
| 29 | + }, |
| 30 | + { |
| 31 | + "cell_type": "code", |
| 32 | + "execution_count": null, |
| 33 | + "metadata": {}, |
| 34 | + "outputs": [], |
| 35 | + "source": [ |
| 36 | + "train_data = load_items(\"/path/to/data/train.jsonl\")\n", |
| 37 | + "test_data = load_items(\"/path/to/data/test.jsonl\")\n", |
| 38 | + "\n", |
| 39 | + "pretrained_model_dir = \"/path/to/bert/checkpoint\"\n",
| 40 | + "checkpoint_dir = \"/path/to/knowledge_model/checkpoint\"" |
| 41 | + ] |
| 42 | + }, |
| 43 | + { |
| 44 | + "attachments": {}, |
| 45 | + "cell_type": "markdown", |
| 46 | + "metadata": {}, |
| 47 | + "source": [ |
| 48 | + "## 训练" |
| 49 | + ] |
| 50 | + }, |
| 51 | + { |
| 52 | + "cell_type": "code", |
| 53 | + "execution_count": null, |
| 54 | + "metadata": {}, |
| 55 | + "outputs": [], |
| 56 | + "source": [ |
| 57 | + "# 以bert为例\n", |
| 58 | + "data_params = {\n", |
| 59 | + " \"stem_key\": \"ques_content\",\n", |
| 60 | + " \"label_key\": \"know_list\"\n", |
| 61 | + "}\n", |
| 62 | + "train_params = {\n", |
| 63 | + " \"num_train_epochs\": 1,\n", |
| 64 | + " \"per_device_train_batch_size\": 2,\n", |
| 65 | + " \"per_device_eval_batch_size\": 2,\n", |
| 66 | + " \"no_cuda\": True,\n", |
| 67 | + "}\n", |
| 68 | + "model_params = {\n", |
| 69 | + " \"num_classes_list\": [10, 27, 963],\n", |
| 70 | + " \"num_total_classes\": 1000,\n", |
| 71 | + "}\n", |
| 72 | + " \n", |
| 73 | + "\n", |
| 74 | + "\"\"\"\n", |
| 75 | + "数据格式:\n", |
| 76 | + "{\n", |
| 77 | + " 'ques_content': 'question...',\n", |
| 78 | + " 'know_list': [lay_1_id, lay_2_id, lay_3_id]\n", |
| 79 | + "}\n", |
| 80 | + "\"\"\"\n", |
| 81 | + "\n", |
| 82 | + "# train without eval_items\n", |
| 83 | + "finetune_bert_for_knowledge_prediction(\n", |
| 84 | + " train_data,\n", |
| 85 | + " checkpoint_dir,\n", |
| 86 | + " pretrained_model=pretrained_model_dir,\n", |
| 87 | + " train_params=train_params,\n", |
| 88 | + " data_params=data_params,\n", |
| 89 | + " model_params=model_params\n", |
| 90 | + ")" |
| 91 | + ] |
| 92 | + }, |
| 93 | + { |
| 94 | + "attachments": {}, |
| 95 | + "cell_type": "markdown", |
| 96 | + "metadata": {}, |
| 97 | + "source": [ |
| 98 | + "## 加载模型和评估数据" |
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "code", |
| 103 | + "execution_count": null, |
| 104 | + "metadata": {}, |
| 105 | + "outputs": [], |
| 106 | + "source": [ |
| 107 | + "# 针对多标签任务处理标签\n", |
| 108 | + "class EvalDataset(torch.utils.data.Dataset):\n", |
| 109 | + " def __init__(self, data) -> None:\n", |
| 110 | + " self.data = data\n", |
| 111 | + " self.num_classes = model_params['num_classes_list']\n", |
| 112 | + " self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_dir)\n", |
| 113 | + "\n", |
| 114 | + " def __getitem__(self, idx):\n", |
| 115 | + " text, labels = self.data[idx][\"ques_content\"], self.data[idx][\"know_list\"]\n", |
| 116 | + " encodings = self.tokenizer(text, padding='max_length', truncation=True, return_tensors='pt')\n", |
| 117 | + " for k, v in encodings.items():\n", |
| 118 | + " encodings[k] = torch.squeeze(v, dim=0)\n", |
| 119 | + "            one_hot_labels = [1. if class_idx in labels else 0. for class_idx in range(sum(self.num_classes))]\n",
| 120 | + " return encodings, torch.FloatTensor(one_hot_labels)\n", |
| 121 | + "\n", |
| 122 | + " def __len__(self):\n", |
| 123 | + " return len(self.data)\n", |
| 124 | + "\n", |
| 125 | + "test_dataset = EvalDataset(test_data)\n", |
| 126 | + "eval_dataloader = torch.utils.data.DataLoader(\n",
| 127 | + "    test_dataset,\n",
| 128 | + " batch_size=1,\n", |
| 129 | + " shuffle=False,\n", |
| 130 | + " num_workers=4,\n", |
| 131 | + ")\n", |
| 132 | + "\n", |
| 133 | + "model = BertForKnowledgePrediction.from_pretrained(checkpoint_dir)" |
| 134 | + ] |
| 135 | + }, |
| 136 | + { |
| 137 | + "attachments": {}, |
| 138 | + "cell_type": "markdown", |
| 139 | + "metadata": {}, |
| 140 | + "source": [ |
| 141 | + "## 评估" |
| 142 | + ] |
| 143 | + }, |
| 144 | + { |
| 145 | + "cell_type": "code", |
| 146 | + "execution_count": null, |
| 147 | + "metadata": {}, |
| 148 | + "outputs": [], |
| 149 | + "source": [ |
| 150 | + "device = \"cuda\" if not train_params[\"no_cuda\"] else \"cpu\"\n", |
| 151 | + "\n", |
| 152 | + "# 层级知识标签-配置信息\n", |
| 153 | + "levels = len(model_params[\"num_classes_list\"])\n", |
| 154 | + "classes_offset_list = [0, 10, 37]\n", |
| 155 | + "classes_border_list = [[0, 9], [10, 36], [37, 999]] # 层级id边界\n",
| 156 | + "hierarchy_dict = {} # child_know_id_to_parent_know_id\n", |
| 157 | + "\n", |
| 158 | + "# 评估top_k结果\n", |
| 159 | + "top_k_list=[10, 20, 30]\n", |
| 160 | + "\n", |
| 161 | + "model.to(device).eval()\n",
| 162 | + "perfs_per_layer = [np.array([0 for _ in range(4)], dtype=np.int32) for _ in range(levels)]\n", |
| 163 | + "total_perfs = np.array([0 for _ in range(4)], dtype=np.int32)\n", |
| 164 | + "\n", |
| 165 | + "k_total_perfs_list = [ np.array([0 for _ in range(4)], dtype=np.int32)for _ in range(len(top_k_list)) ]\n", |
| 166 | + "for i, eval_batch in tqdm.tqdm(enumerate(eval_dataloader)):\n", |
| 167 | + " input_data, eval_batch_labels = eval_batch\n", |
| 168 | + " input_data = input_data.to(device)\n", |
| 169 | + " _, output_logits = model(**input_data)\n", |
| 170 | + "\n", |
| 171 | + " local_perfs_per_layer, local_perfs = compute_perfs_per_layer(\n", |
| 172 | + " output_logits.cpu().detach().numpy(),\n", |
| 173 | + " eval_batch_labels.cpu().detach().numpy(),\n", |
| 174 | + " hierarchy_dict,\n", |
| 175 | + " classes_border_list,\n", |
| 176 | + " keep_consistency=True\n", |
| 177 | + " )\n", |
| 178 | + " perfs_per_layer = [perfs_per_layer[idx] + local_perfs_per_layer[idx] for idx in range(levels)]\n", |
| 179 | + " total_perfs += local_perfs\n", |
| 180 | + " \n", |
| 181 | + " # for recall@k\n", |
| 182 | + " for i_k, k in enumerate(top_k_list):\n", |
| 183 | + " pred_topk = get_onehot_label_topk(\n", |
| 184 | + " classes_border_list, classes_offset_list, scores_list=output_logits.cpu().detach().numpy(), top_num=k)\n", |
| 185 | + " flat_pred_topk = np.array([x[3] for x in pred_topk])\n", |
| 186 | + " k_total_perfs = compute_perfs(flat_pred_topk, eval_batch_labels.cpu().detach().numpy().tolist())\n", |
| 187 | + " k_total_perfs_list[i_k] += k_total_perfs\n", |
| 188 | + "\n", |
| 189 | + "# metric for overall\n", |
| 190 | + "micro_precision, micro_recall, micro_f1, total_acc = metric(*total_perfs)\n", |
| 191 | + "print(f\"Eval Results: Micro-Precision: {micro_precision:.4f}, \"\n", |
| 192 | + " + f\"Micro-Recall: {micro_recall:.4f}, Micro-F1: {micro_f1:.4f}, Acc: {total_acc:.4f}\")\n", |
| 193 | + "\n", |
| 194 | + "# metrics for per top_k\n", |
| 195 | + "for i_k, k_total_perfs in enumerate(k_total_perfs_list):\n", |
| 196 | + " k = top_k_list[i_k]\n", |
| 197 | + " precision, recall, f1, acc = metric(*k_total_perfs)\n", |
| 198 | + " print(f\"TOPK={k}: Precision@{k}: {precision:.4f}, Recall@{k}: {recall:.4f}, F1@{k}: {f1:.4f}, Acc@{k}: {acc:.4f}\")\n", |
| 199 | + "\n", |
| 200 | + "# metrics for per layer\n", |
| 201 | + "for layer_idx, perfs in enumerate(perfs_per_layer):\n", |
| 202 | + " precision, recall, f1, acc = metric(*perfs)\n", |
| 203 | + " print(f\"Layer {layer_idx + 1}: Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {acc:.4f}\")" |
| 204 | + ] |
| 205 | + } |
| 206 | + ], |
| 207 | + "metadata": { |
| 208 | + "language_info": { |
| 209 | + "name": "python" |
| 210 | + }, |
| 211 | + "orig_nbformat": 4 |
| 212 | + }, |
| 213 | + "nbformat": 4, |
| 214 | + "nbformat_minor": 2 |
| 215 | +} |
0 commit comments