import json
import pickle
import numpy as np
import os
import glob
import torch
from torch.utils import data

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")


def get_def2id_defEmbs(def_dir):
    """Load the definition-to-id mapping and the matrix of definition embeddings."""
    with open(os.path.join(def_dir, 'def2id'), 'rb') as f:
        def2id = pickle.load(f)
    all_def_embs = np.load(os.path.join(def_dir, 'all_def_embs.npy'))
    # Transpose so a (bs, 512) batch can be matmul'd against the (512, #defs) matrix.
    all_def_embs = torch.tensor(all_def_embs).to(device).transpose(0, 1).contiguous()
    return def2id, all_def_embs


def get_pretrained_w2v(path, dim):
    """Read a word2vec-style text file into a {word: np.ndarray} dict."""
    w2v = dict()
    with open(path, 'r') as f:
        next(f)  # skip the word2vec header line (vocabulary size and dimension)
        for idx, line in enumerate(f):
            word, vec = line.strip().split(' ', 1)
            vec = np.fromstring(vec, sep=' ', dtype=np.float32)
            if len(vec) != dim:  # drop malformed lines
                continue
            if word not in w2v:
                w2v[word] = vec
    print("Num pretrained word vectors:", len(w2v))
    return w2v


def get_voc(voc_path, pre_path, words_path, dim):
    """Load a cached Voc, or build it from the pretrained embeddings and cache it."""
    try:
        voc = torch.load(voc_path)
    except FileNotFoundError:
        print("Voc not found! Building Voc from pretrained word embeddings ...")
        w2v = get_pretrained_w2v(pre_path, dim)
        voc = Voc()
        words = set(open(words_path).read().splitlines())
        for w in words:
            if w in w2v:
                voc.add_word(w, w2v[w])
        torch.save(voc, voc_path)
    print("Voc size:", voc.n_words)
    return voc


class Voc:
    """Vocabulary holding a word-to-index map and the matching embedding rows."""

    def __init__(self):
        self.word2index = {}
        self.embedding = []
        self.n_words = 0

    def add_word(self, word, vec):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.embedding.append(vec)
            self.n_words += 1


# TODO: improve the efficiency of __getitem__ and preprocess
class myDataset(data.Dataset):
    def __init__(self, params, mode, input_file, ctx_file, def_file, def2id, voc, visualize):
        self.isVis = visualize
        self.isRev = params.reverse
        self.mode = mode
        self.model_type = params.model_type
        self.zero_shot = params.zero
        self.dataset = []
        if self.zero_shot:
            with open(params.unseen_path, 'rb') as f:
                self.unseen_voc = pickle.load(f)
        self.preprocess(input_file, ctx_file, def_file, def2id, voc, params.syn_path)
        self.num_data = len(self.dataset)

    def preprocess(self, input_file, ctx_file, def_file, def2id, voc, syn_path):
        # Context features may be plain context embeddings, ELMo, BERT-base, or BERT-large.
        ctx_vecs = np.load(ctx_file)
        def_vecs = np.load(def_file)
        print('context-dependent embedding:', ctx_vecs.shape)
        print('definition embedding:', def_vecs.shape)
        assert len(ctx_vecs) == len(def_vecs), "input error: file sizes mismatch!"
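        # In reverse-dictionary mode, line i of syn_path is assumed to hold the
        # space-separated synonyms of the keyword on line i of input_file
        # (this matches the synonyms[i].split() lookup below).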
        if self.isRev:
            synonyms = open(syn_path).read().splitlines()
        oov = 0
        with open(input_file, 'r') as f:
            for i, line in enumerate(f):
                keyword, context, defin = line.split(';')
                keyword = keyword.strip()
                context = context.strip()
                defin = defin.strip()
                if keyword not in voc.word2index:
                    oov += 1
                    continue
                if self.zero_shot:
                    # Exclude unseen-set words during training; keep only them at test time.
                    skip = (keyword not in self.unseen_voc) if self.mode == 'test' else (keyword in self.unseen_voc)
                    if skip:
                        continue
                if self.isRev:
                    syns = set(voc.word2index[w] for w in synonyms[i].split() if w in voc.word2index)
                    if len(syns) == 0:
                        continue  # no in-vocabulary synonyms
                    syns.add(voc.word2index[keyword])
                    self.dataset.append([-1, list(syns), ctx_vecs[i], def_vecs[i], keyword, context, defin])
                else:
                    self.dataset.append([def2id[defin], voc.word2index[keyword], ctx_vecs[i], def_vecs[i], keyword, context, defin])
        print('Num oov:', oov)

    def __getitem__(self, index):
        defID, wordID, ctx_vec, def_vec, keyword, context, defin = self.dataset[index]
        if self.isVis:
            return (torch.tensor(defID), torch.tensor(wordID), torch.FloatTensor(ctx_vec),
                    torch.FloatTensor(def_vec), keyword, context, defin)
        elif self.isRev:
            return (torch.tensor(wordID), torch.FloatTensor(ctx_vec),
                    torch.FloatTensor(def_vec), keyword, context, defin)
        else:
            return (torch.tensor(defID), torch.tensor(wordID),
                    torch.FloatTensor(ctx_vec), torch.FloatTensor(def_vec))

    def __len__(self):
        return self.num_data


def get_loader(params, input_file, ctx_file, def_file, def2id, voc, batch_size, mode, visualize=False):
    dataset = myDataset(params, mode, input_file, ctx_file, def_file, def2id, voc, visualize)
    dataloader = data.DataLoader(dataset, batch_size,
                                 shuffle=(mode == 'train'), drop_last=(mode == 'train'))
    print("Got {} dataloader, size: {}!".format(mode, dataset.num_data))
    return dataloader
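

# A minimal usage sketch, not part of the original module. It assumes a `params`
# object exposing the attributes read above (reverse, model_type, zero,
# unseen_path, syn_path); every file path and the model_type value below are
# hypothetical placeholders, so adjust them to your own setup.
if __name__ == '__main__':
    from argparse import Namespace

    params = Namespace(reverse=False, model_type='base', zero=False,
                       unseen_path=None, syn_path=None)
    voc = get_voc('voc.pt', 'pretrained_w2v.txt', 'words.txt', dim=300)
    def2id, all_def_embs = get_def2id_defEmbs('def_dir')
    train_loader = get_loader(params, 'train.txt', 'train_ctx.npy', 'train_def.npy',
                              def2id, voc, batch_size=64, mode='train')
    for defID, wordID, ctx_vec, def_vec in train_loader:
        # In the default (non-reverse, non-visualize) mode each batch yields
        # definition ids, word ids, context vectors, and definition vectors.
        print(defID.shape, wordID.shape, ctx_vec.shape, def_vec.shape)
        break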