final_input = concat(context, embed(x)): x and the context vector are concatenated together
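A minimal shape-level sketch of this concatenation (the batch, embedding, and hidden sizes here are made up for illustration):

import torch

batch, emb_dim, hid_dim = 4, 256, 1024
context = torch.randn(batch, hid_dim)    # context vector from attention
x_emb = torch.randn(batch, 1, emb_dim)   # embed(x): the current target token
# Concatenate the context onto the embedded token along the feature axis.
final_input = torch.cat([context.unsqueeze(1), x_emb], dim=-1)
print(final_input.shape)                 # torch.Size([4, 1, 1280])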
Machine Translation in Practice
We first validate the model on a small English-Spanish translation dataset, about 110,000 sentence pairs in total.
Hands-on seq2seq with attention (Sequence-to-Sequence)
Steps
preprocessing data: map the text to ids and build the dataset; use a word-level Tokenizer
build model
build the encoder (using a GRU)
build the attention: implement Bahdanau attention; this is the key and most difficult part (see the sketch after this list)
build the decoder: also a GRU, a lighter-weight gated RNN in the same family as the LSTM
loss & optimizer: customize how the gradients are applied
train: run the training step once per epoch
evaluation (accuracy is not a good fit here; use BLEU)
given a sentence, return the translated results
visualize results (attention): plot the attention scores
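Since the Bahdanau attention step is flagged above as the hard part, here is a minimal PyTorch sketch of the additive score v^T tanh(W1*h_enc + W2*s_dec); the layer names W1/W2/V and the units size are illustrative, not from the source:

import torch
from torch import nn

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size, units):
        super().__init__()
        self.W1 = nn.Linear(hidden_size, units)  # projects encoder outputs
        self.W2 = nn.Linear(hidden_size, units)  # projects the decoder state
        self.V = nn.Linear(units, 1)             # collapses to a scalar score

    def forward(self, decoder_state, encoder_outputs):
        # decoder_state: (batch, hidden); encoder_outputs: (batch, seq_len, hidden)
        query = decoder_state.unsqueeze(1)                    # (batch, 1, hidden)
        scores = self.V(torch.tanh(self.W1(encoder_outputs) + self.W2(query)))
        weights = torch.softmax(scores, dim=1)                # (batch, seq_len, 1)
        context = (weights * encoder_outputs).sum(dim=1)      # (batch, hidden)
        return context, weights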
Data Preprocessing
Remove the accents from the Spanish text.
import unicodedata
import re
from collections import Counter

import torch
from sklearn.model_selection import train_test_split
# Spanish contains special characters, so we convert Unicode to ASCII.
# This shrinks the value range, since the Unicode space is too large.
def unicode_to_ascii(s):
    # NFD decomposition splits each character into a base character plus
    # combining marks; category Mn is the accent marks, which we drop.
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

# Try it on a sample pair.
# The u prefix marks the string literal as Unicode.
en_sentence = u"May I borrow this book?"
sp_sentence = u"¿Puedo tomar prestado este libro?"
def preprocess_sentence(w):
    # Lowercase and strip surrounding whitespace so the id vocabulary stays small.
    w = unicode_to_ascii(w.lower().strip())
    # Insert a space between a word and the punctuation that follows it.
    # eg: "he is a boy." => "he is a boy . "
    # Reference: https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
    w = re.sub(r"([?.!,¿])", r" \1 ", w)
    # The substitution above can leave multiple spaces; collapse them into one.
    w = re.sub(r'[" "]+', " ", w)
    return w.strip()  # trim the leading/trailing space added by the padding step
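Running the two sample sentences through the function (assuming the final strip-and-return line above) shows the effect:

print(preprocess_sentence(en_sentence))  # may i borrow this book ?
print(preprocess_sentence(sp_sentence))  # ¿ puedo tomar prestado este libro ?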
def get_word_idx(ds, mode="src", threshold=2):
    # Build the vocabulary; a vocab is like a dictionary for the language,
    # and we can check its size afterwards.
    word2idx = {
        "[PAD]": 0,  # padding token
        "[BOS]": 1,  # begin of sentence
        "[UNK]": 2,  # unknown token
        "[EOS]": 3,  # end of sentence
    }
    idx2word = {value: key for key, value in word2idx.items()}
    index = len(idx2word)
    threshold = 1  # tokens that appear fewer times than this are discarded
    # If the dataset were many gigabytes we would stream it in a for loop
    # instead of joining everything into one string.
    word_list = " ".join([pair[0 if mode == "src" else 1] for pair in ds]).split()
    # Count word frequencies; Counter behaves like a dict whose keys are
    # words and whose values are occurrence counts.
    counter = Counter(word_list)
    print("word count:", len(counter))

    for token, count in counter.items():
        if count >= threshold:  # only tokens at or above the threshold enter the vocab
            word2idx[token] = index  # forward vocabulary
            idx2word[index] = token  # reverse vocabulary
            index += 1
    return word2idx, idx2word
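Usage might look like the following on a toy list of (English, Spanish) pairs; the pair layout (source at index 0, target at index 1) follows the pair[0 if mode=="src" else 1] lookup above, and we assume the function returns both dicts:

pairs = [(preprocess_sentence(en_sentence), preprocess_sentence(sp_sentence))]
word2idx, idx2word = get_word_idx(pairs, mode="src")
print(len(word2idx))  # 4 special tokens + every source word that clears the threshold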
def encode(self, text_list, padding_first=False, add_bos=True, add_eos=True, return_mask=False):
    """If padding_first == True, padding is added at the front; otherwise at the back.
    return_mask: whether to also return a mask marking which positions are
    padding and which are real tokens.
    """
    max_length = min(self.max_length,
                     add_eos + add_bos + max([len(text) for text in text_list]))
    indices_list = []
    for text in text_list:
        # Words missing from the vocabulary fall back to unk_idx; indices is
        # the list of word ids for one sample.
        indices = [self.word2idx.get(word, self.unk_idx)
                   for word in text[:max_length - add_bos - add_eos]]
        if add_bos:
            indices = [self.bos_idx] + indices
        if add_eos:
            indices = indices + [self.eos_idx]
        if padding_first:  # pad at the front; this is a tunable choice
            indices = [self.pad_idx] * (max_length - len(indices)) + indices
        else:  # pad at the back
            indices = indices + [self.pad_idx] * (max_length - len(indices))
        indices_list.append(indices)
    input_ids = torch.tensor(indices_list)  # convert to a tensor
    # The mask has the same shape as input_ids: 0 marks a real token, 1 marks
    # padding; it is later used to cancel padding's influence.
    masks = (input_ids == self.pad_idx).to(dtype=torch.int64)
    return input_ids if not return_mask else (input_ids, masks)
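The source only says the mask cancels padding's influence; one typical (hypothetical) way to consume it is to push padded positions to -inf before an attention softmax:

scores = torch.randn(2, 5)                      # raw attention scores, (batch, seq_len)
mask = torch.tensor([[0, 0, 0, 1, 1],
                     [0, 0, 1, 1, 1]])          # 1 marks padding, as returned by encode
masked_scores = scores.masked_fill(mask.bool(), float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)  # padded positions receive zero weight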
def decode(self, indices_list, remove_bos=True, remove_eos=True, remove_pad=True, split=False):
    text_list = []
    for indices in indices_list:
        text = []
        for index in indices:
            word = self.idx2word.get(index, "[UNK]")  # unknown indices fall back to [UNK]
            if remove_bos and word == "[BOS]":
                continue
            if remove_eos and word == "[EOS]":  # stop once end-of-sentence is reached
                break
            if remove_pad and word == "[PAD]":  # stop once padding is reached
                break
            text.append(word)  # collect the word
        # Join the collected words back into a sentence, unless the caller
        # asked for the split word list.
        text_list.append(" ".join(text) if not split else text)
    return text_list
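encode and decode are written as methods of a tokenizer class that is not shown here. A minimal sketch of a constructor that wires in the vocabularies; the max_length default and the attribute wiring are assumptions chosen to match the attribute names the two methods use:

class Tokenizer:
    def __init__(self, word2idx, idx2word, max_length=50):  # max_length default is assumed
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.max_length = max_length
        self.pad_idx = word2idx["[PAD]"]
        self.bos_idx = word2idx["[BOS]"]
        self.unk_idx = word2idx["[UNK]"]
        self.eos_idx = word2idx["[EOS]"]

    # encode(...) and decode(...) defined above are methods of this class.

With that, trg_tokenizer = Tokenizer(word2idx, idx2word) gives the object exercised in the test below.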
# trg_tokenizer.encode([["hello"], ["hello", "world"]], add_bos=True, add_eos=False, return_mask=True)
raw_text = ["hello world".split(),
            "tokenize text datas with batch".split(),
            "this is a test".split()]
indices, mask = trg_tokenizer.encode(raw_text, padding_first=False,
                                     add_bos=True, add_eos=True, return_mask=True)
decode_text = trg_tokenizer.decode(indices.tolist(),
                                   remove_bos=False, remove_eos=False, remove_pad=False)
print("raw text" + '-' * 10)
for raw in raw_text:
    print(raw)
print("mask" + '-' * 10)
for m in mask:
    print(m)
print("indices" + '-' * 10)
for index in indices:
    print(index)
print("decode text" + '-' * 10)
for decode in decode_text:
    print(decode)