triple_data.py
import json
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import numpy as np
import torch
from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding

from utils import normalize_instruction

QUERY_KEY = "query"
DOC_KEY = "doc"


def load_medi_data(data_path: str) -> Dict[str, List[Dict[str, Any]]]:
    """Load MEDI-style training triples from JSON and group them by task name."""
    with open(data_path, 'r') as f:
        training_triples = json.load(f)
    task_to_dataset: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for triple in training_triples:
        task_name = triple['task_name']
        task_to_dataset[task_name].append(triple)
    return task_to_dataset
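
# A minimal sketch of the expected on-disk format, inferred from the field
# accesses above and in TripleCollator below; the values are hypothetical,
# not taken from the actual MEDI release:
#
#   [
#     {
#       "task_name": "nli",
#       "query": ["Represent the sentence:", "A man is playing guitar."],
#       "pos":   ["Represent the sentence:", "Someone plays an instrument."],
#       "neg":   ["Represent the sentence:", "The weather is cold today."]
#     },
#     ...
#   ]
#
# TripleCollator also tolerates plain strings for query/pos/neg when no
# instruction is attached (see the isinstance checks in __call__).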


class MultiDataset:
    def __init__(
        self,
        data_path: str,
        batch_size: int,
    ):
        self.task_to_dataset = load_medi_data(data_path)
        self.batch_size = batch_size
        self.task_to_datasize: Dict[str, int] = {task: len(task_data) for task, task_data in self.task_to_dataset.items()}
        # Only the data indices are shuffled, i.e., (task name, local data idx)
        # pairs; the examples themselves stay where they were loaded.
        self.task_data_idxs = self.batched_shuffle(self.task_to_datasize, self.batch_size)  # List[Dict[str, Any]]

    def __len__(self):
        return len(self.task_data_idxs)

    def shuffle_batch(self):
        """Should be called at the beginning of each epoch."""
        self.task_data_idxs = self.batched_shuffle(self.task_to_datasize, self.batch_size)

    @staticmethod
    def batched_shuffle(task_to_datasize: Dict[str, int], batch_size: int) -> List[Dict[str, Any]]:
        """Shuffle each task's indices, cut them into full batches, then shuffle the batch order across tasks."""
        task_idxs_batches = []  # list of batches; each batch is a dict holding a task name and its in-task idxs
        for task, data_size in task_to_datasize.items():
            shuffled_idxs = np.random.permutation(data_size)
            local_batched_shuffled_idxs = [shuffled_idxs[i:i + batch_size] for i in range(0, data_size, batch_size)]
            if len(local_batched_shuffled_idxs[-1]) < batch_size:  # drop the trailing incomplete batch
                local_batched_shuffled_idxs.pop()
            task_idxs_batches.extend([{"task_name": task, "batch_idxs": idxs} for idxs in local_batched_shuffled_idxs])
        random.shuffle(task_idxs_batches)
        # Flatten back into single {"task", "idx"} entries; every consecutive run
        # of batch_size entries still comes from one task.
        batched_task_idx = []
        for task_batch in task_idxs_batches:
            batched_task_idx.extend([{"task": task_batch['task_name'], "idx": idx} for idx in task_batch['batch_idxs']])
        return batched_task_idx
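
    # Worked example with hypothetical sizes: for batch_size=2 and
    # task_to_datasize={"nli": 5, "qa": 4}, "nli" contributes two full batches
    # (its fifth, leftover index is dropped) and "qa" contributes two, so the
    # flattened result has 8 entries, e.g.
    #   [{"task": "qa", "idx": 3}, {"task": "qa", "idx": 0},
    #    {"task": "nli", "idx": 4}, {"task": "nli", "idx": 1}, ...]
    # where entries 2k and 2k+1 always share the same task.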

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        task_data_idx = self.task_data_idxs[idx]
        task_name = task_data_idx['task']
        local_idx = task_data_idx['idx']
        example = self.task_to_dataset[task_name][local_idx]
        return example
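
# Note: batches remain task-homogeneous only when the indices are consumed in
# order (e.g., a DataLoader with shuffle=False and the same batch_size), since
# batched_shuffle has already randomized the order.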


@dataclass
class TripleCollator(DataCollatorWithPadding):
    max_q_len: int = 32
    max_d_len: int = 128
    with_prompt: bool = False
    with_instruction: bool = False
    mask_instruction_pooling: bool = True
    input_keys = ['query', 'pos', 'neg']
    key2prompt = {"query": QUERY_KEY, "pos": DOC_KEY, "neg": DOC_KEY}

    def __post_init__(self):
        assert not (self.with_prompt and self.with_instruction), "Cannot add a prompt and an instruction at the same time."

    def __call__(self, features):
        collated_batch = {}
        for key in self.input_keys:
            texts: Union[List[str], List[List[str]]] = [f[key] for f in features]
            # key2prompt maps the feature key to its role: queries are truncated
            # to max_q_len, positive/negative documents to max_d_len.
            max_len = self.max_d_len if self.key2prompt[key] == DOC_KEY else self.max_q_len
            if self.with_instruction:  # prepend the per-example instruction
                assert isinstance(texts[0], list), "No instruction in input text."
                instructions = [normalize_instruction(text[0]) for text in texts]
                # It seems that some instructions are dropped in the MEDI data.
                texts = ['{}: {}'.format(instruction, text[1]) for instruction, text in zip(instructions, texts)]
                instruction_mask = self.tokenizer(
                    instructions,
                    padding='max_length',
                    truncation=True,
                    max_length=max_len,
                    return_tensors='pt',
                    add_special_tokens=True,
                    return_token_type_ids=False,
                    return_attention_mask=True,
                )['attention_mask']  # Tensor of shape (batch_size, max_seq_len)
                # instruction_mask[:, 0] = 0  # unmask the CLS token; commented out since this only works for BERT-family models
            else:  # do not add an instruction
                if isinstance(texts[0], list):  # input format is [instruction, text]; keep only the text
                    texts = [text[1] for text in texts]  # List[str]
                if self.with_prompt:  # add a simple role prompt instead
                    texts = ['{}: {}'.format(self.key2prompt[key], text) for text in texts]
            text_batch = self.tokenizer(
                texts,
                padding='max_length',
                truncation=True,
                max_length=max_len,
                return_tensors="pt",
            )
            if self.with_instruction and self.mask_instruction_pooling:
                # Keep for pooling only the positions that are real text tokens
                # (attention_mask) but not part of the tokenized instruction.
                text_batch["pooling_mask"] = (~(instruction_mask.bool()) & text_batch["attention_mask"].bool())
            collated_batch[key] = text_batch
        return collated_batch
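

if __name__ == "__main__":
    # Minimal usage sketch. The data path and checkpoint name below are
    # assumptions for illustration, not part of the original module.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical checkpoint
    dataset = MultiDataset(data_path="medi_data.json", batch_size=8)  # hypothetical path
    collator = TripleCollator(tokenizer=tokenizer, with_instruction=True)

    # shuffle=False keeps the task-homogeneous batches from batched_shuffle intact.
    loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=collator)

    dataset.shuffle_batch()  # re-shuffle at the start of each epoch
    for batch in loader:
        query_inputs = batch["query"]  # BatchEncoding with input_ids, attention_mask, pooling_mask
        break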