{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "s_qNSzzyaCbD" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "jmjh290raIky", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "J0Qjg6vuaHNt" }, "source": [ "# 基于注意力的神经机器翻译" ] }, { "cell_type": "markdown", "metadata": { "id": "AOpGoE2T-YXS" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " 在 TensorFlow.org 上查看\n", " \n", " \n", " \n", " 在 Google Colab 运行\n", " \n", " \n", " \n", " 在 GitHub 上查看源代码\n", " \n", " 下载此 notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "8dEwzVWg0f-E" }, "source": [ "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的\n", "[官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议, 请提交 pull request 到\n", "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文,请加入\n", "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。" ] }, { "cell_type": "markdown", "metadata": { "id": "CiwtNgENbx2g" }, "source": [ "此笔记本训练一个将西班牙语翻译为英语的序列到序列(sequence to sequence,简写为 seq2seq)模型。此例子难度较高,需要对序列到序列模型的知识有一定了解。\n", "\n", "训练完此笔记本中的模型后,你将能够输入一个西班牙语句子,例如 *\"¿todavia estan en casa?\"*,并返回其英语翻译 *\"are you still at home?\"*\n", "\n", "对于一个简单的例子来说,翻译质量令人满意。但是更有趣的可能是生成的注意力图:它显示在翻译过程中,输入句子的哪些部分受到了模型的注意。\n", "\n", "\"spanish-english\n", "\n", "请注意:运行这个例子用一个 P100 GPU 需要花大约 10 分钟。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tnxXKDjq3jEL", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker\n", "from sklearn.model_selection import train_test_split\n", "\n", "import unicodedata\n", "import re\n", "import numpy as np\n", "import os\n", "import io\n", "import time" ] }, { "cell_type": "markdown", "metadata": { "id": "wfodePkj3jEa" }, "source": [ "## 下载和准备数据集\n", "\n", "我们将使用 http://www.manythings.org/anki/ 提供的一个语言数据集。这个数据集包含如下格式的语言翻译对:\n", "\n", "```\n", "May I borrow this book?\t¿Puedo tomar prestado este libro?\n", "```\n", "\n", "这个数据集中有很多种语言可供选择。我们将使用英语 - 西班牙语数据集。为方便使用,我们在谷歌云上提供了此数据集的一份副本。但是你也可以自己下载副本。下载完数据集后,我们将采取下列步骤准备数据:\n", "\n", "1. 给每个句子添加一个 *开始* 和一个 *结束* 标记(token)。\n", "2. 删除特殊字符以清理句子。\n", "3. 创建一个单词索引和一个反向单词索引(即一个从单词映射至 id 的词典和一个从 id 映射至单词的词典)。\n", "4. 将每个句子填充(pad)到最大长度。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kRVATYOgJs1b", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 下载文件\n", "path_to_zip = tf.keras.utils.get_file(\n", " 'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n", " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rd0jw-eC3jEh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 将 unicode 文件转换为 ascii\n", "def unicode_to_ascii(s):\n", " return ''.join(c for c in unicodedata.normalize('NFD', s)\n", " if unicodedata.category(c) != 'Mn')\n", "\n", "\n", "def preprocess_sentence(w):\n", " w = unicode_to_ascii(w.lower().strip())\n", "\n", " # 在单词与跟在其后的标点符号之间插入一个空格\n", " # 例如: \"he is a boy.\" => \"he is a boy .\"\n", " # 参考:https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", "\n", " # 除了 (a-z, A-Z, \".\", \"?\", \"!\", \",\"),将所有字符替换为空格\n", " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", "\n", " w = w.rstrip().strip()\n", "\n", " # 给句子加上开始和结束标记\n", " # 以便模型知道何时开始和结束预测\n", " w = ' ' + w + ' '\n", " return w" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "opI2GzOt479E", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "en_sentence = u\"May I borrow this book?\"\n", "sp_sentence = u\"¿Puedo tomar prestado este libro?\"\n", "print(preprocess_sentence(en_sentence))\n", "print(preprocess_sentence(sp_sentence).encode('utf-8'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OHn4Dct23jEm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 1. 去除重音符号\n", "# 2. 清理句子\n", "# 3. 返回这样格式的单词对:[ENGLISH, SPANISH]\n", "def create_dataset(path, num_examples):\n", " lines = io.open(path, encoding='UTF-8').read().strip().split('\\n')\n", "\n", " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", "\n", " return zip(*word_pairs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cTbSbBz55QtF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "en, sp = create_dataset(path_to_file, None)\n", "print(en[-1])\n", "print(sp[-1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OmMZQpdO60dt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bIOn8RCNDJXG", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def tokenize(lang):\n", " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(\n", " filters='')\n", " lang_tokenizer.fit_on_texts(lang)\n", "\n", " tensor = lang_tokenizer.texts_to_sequences(lang)\n", "\n", " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,\n", " padding='post')\n", "\n", " return tensor, lang_tokenizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eAY9k49G3jE_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def load_dataset(path, num_examples=None):\n", " # 创建清理过的输入输出对\n", " targ_lang, inp_lang = create_dataset(path, num_examples)\n", "\n", " input_tensor, inp_lang_tokenizer = tokenize(inp_lang)\n", " target_tensor, targ_lang_tokenizer = tokenize(targ_lang)\n", "\n", " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" ] }, { "cell_type": "markdown", "metadata": { "id": "GOi42V79Ydlr" }, "source": [ "### 限制数据集的大小以加快实验速度(可选)\n", "\n", "在超过 10 万个句子的完整数据集上训练需要很长时间。为了更快地训练,我们可以将数据集的大小限制为 3 万个句子(当然,翻译质量也会随着数据的减少而降低):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cnxC7q-j3jFD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 尝试实验不同大小的数据集\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)\n", "\n", "# 计算目标张量的最大长度 (max_length)\n", "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4QILQkOs3jFG", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 采用 80 - 20 的比例切分训练集和验证集\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# 显示长度\n", "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lJPmLZGMeD5q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def convert(lang, tensor):\n", " for t in tensor:\n", " if t!=0:\n", " print (\"%d ----> %s\" % (t, lang.index_word[t]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VXukARTDd7MT", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print (\"Input Language; index to word mapping\")\n", "convert(inp_lang, input_tensor_train[0])\n", "print ()\n", "print (\"Target Language; index to word mapping\")\n", "convert(targ_lang, target_tensor_train[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "rgCLkfv5uO3d" }, "source": [ "### 创建一个 tf.data 数据集" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TqHsArVZ3jFS", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", "steps_per_epoch = len(input_tensor_train)//BATCH_SIZE\n", "embedding_dim = 256\n", "units = 1024\n", "vocab_inp_size = len(inp_lang.word_index)+1\n", "vocab_tar_size = len(targ_lang.word_index)+1\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qc6-NK1GtWQt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "example_input_batch, example_target_batch = next(iter(dataset))\n", "example_input_batch.shape, example_target_batch.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "TNfHIF71ulLu" }, "source": [ "## 编写编码器 (encoder) 和解码器 (decoder) 模型\n", "\n", "实现一个基于注意力的编码器 - 解码器模型。关于这种模型,你可以阅读 TensorFlow 的 [神经机器翻译 (序列到序列) 教程](https://github.com/tensorflow/nmt)。本示例采用一组更新的 API。此笔记本实现了上述序列到序列教程中的 [注意力方程式](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism)。下图显示了注意力机制为每个输入单词分配一个权重,然后解码器将这个权重用于预测句子中的下一个单词。下图和公式是 [Luong 的论文](https://arxiv.org/abs/1508.04025v5)中注意力机制的一个例子。\n", "\n", "\"attention\n", "\n", "输入经过编码器模型,编码器模型为我们提供形状为 *(批大小,最大长度,隐藏层大小)* 的编码器输出和形状为 *(批大小,隐藏层大小)* 的编码器隐藏层状态。\n", "\n", "下面是所实现的方程式:\n", "\n", "\"attention\n", "\"attention\n", "\n", "本教程的编码器采用 [Bahdanau 注意力](https://arxiv.org/pdf/1409.0473.pdf)。在用简化形式编写之前,让我们先决定符号:\n", "\n", "* FC = 完全连接(密集)层\n", "* EO = 编码器输出\n", "* H = 隐藏层状态\n", "* X = 解码器输入\n", "\n", "以及伪代码:\n", "\n", "* `score = FC(tanh(FC(EO) + FC(H)))`\n", "* `attention weights = softmax(score, axis = 1)`。 Softmax 默认被应用于最后一个轴,但是这里我们想将它应用于 *第一个轴*, 因为分数 (score) 的形状是 *(批大小,最大长度,隐藏层大小)*。最大长度 (`max_length`) 是我们的输入的长度。因为我们想为每个输入分配一个权重,所以 softmax 应该用在这个轴上。\n", "* `context vector = sum(attention weights * EO, axis = 1)`。选择第一个轴的原因同上。\n", "* `embedding output` = 解码器输入 X 通过一个嵌入层。\n", "* `merged vector = concat(embedding output, context vector)`\n", "* 此合并后的向量随后被传送到 GRU\n", "\n", "每个步骤中所有向量的形状已在代码的注释中阐明:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nZ2rI24i3jFg", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", " super(Encoder, self).__init__()\n", " self.batch_sz = batch_sz\n", " self.enc_units = enc_units\n", " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", " self.gru = tf.keras.layers.GRU(self.enc_units,\n", " return_sequences=True,\n", " return_state=True,\n", " recurrent_initializer='glorot_uniform')\n", "\n", " def call(self, x, hidden):\n", " x = self.embedding(x)\n", " output, state = self.gru(x, initial_state = hidden)\n", " return output, state\n", "\n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "60gSVh05Jl6l", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "\n", "# 样本输入\n", "sample_hidden = encoder.initialize_hidden_state()\n", "sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)\n", "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n", "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "umohpBN2OM94", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class BahdanauAttention(tf.keras.layers.Layer):\n", " def __init__(self, units):\n", " super(BahdanauAttention, self).__init__()\n", " self.W1 = tf.keras.layers.Dense(units)\n", " self.W2 = tf.keras.layers.Dense(units)\n", " self.V = tf.keras.layers.Dense(1)\n", "\n", " def call(self, query, values):\n", " # 隐藏层的形状 == (批大小,隐藏层大小)\n", " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n", " # 这样做是为了执行加法以计算分数 \n", " hidden_with_time_axis = tf.expand_dims(query, 1)\n", "\n", " # 分数的形状 == (批大小,最大长度,1)\n", " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n", " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n", " score = self.V(tf.nn.tanh(\n", " self.W1(values) + self.W2(hidden_with_time_axis)))\n", "\n", " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n", " attention_weights = tf.nn.softmax(score, axis=1)\n", "\n", " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n", " context_vector = attention_weights * values\n", " context_vector = tf.reduce_sum(context_vector, axis=1)\n", "\n", " return context_vector, attention_weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k534zTHiDjQU", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "attention_layer = BahdanauAttention(10)\n", "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n", "\n", "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n", "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yJ_B3mhW3jFk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", " super(Decoder, self).__init__()\n", " self.batch_sz = batch_sz\n", " self.dec_units = dec_units\n", " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", " self.gru = tf.keras.layers.GRU(self.dec_units,\n", " return_sequences=True,\n", " return_state=True,\n", " recurrent_initializer='glorot_uniform')\n", " self.fc = tf.keras.layers.Dense(vocab_size)\n", "\n", " # 用于注意力\n", " self.attention = BahdanauAttention(self.dec_units)\n", "\n", " def call(self, x, hidden, enc_output):\n", " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n", " context_vector, attention_weights = self.attention(hidden, enc_output)\n", "\n", " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n", " x = self.embedding(x)\n", "\n", " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n", " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", "\n", " # 将合并后的向量传送到 GRU\n", " output, state = self.gru(x)\n", "\n", " # 输出的形状 == (批大小 * 1,隐藏层大小)\n", " output = tf.reshape(output, (-1, output.shape[2]))\n", "\n", " # 输出的形状 == (批大小,vocab)\n", " x = self.fc(output)\n", "\n", " return x, state, attention_weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P5UY8wko3jFp", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n", "\n", "sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)),\n", " sample_hidden, sample_output)\n", "\n", "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))" ] }, { "cell_type": "markdown", "metadata": { "id": "_ch_71VbIRfK" }, "source": [ "## 定义优化器和损失函数" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WmTHr5iV3jFr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam()\n", "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", " from_logits=True, reduction='none')\n", "\n", "def loss_function(real, pred):\n", " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", " loss_ = loss_object(real, pred)\n", "\n", " mask = tf.cast(mask, dtype=loss_.dtype)\n", " loss_ *= mask\n", "\n", " return tf.reduce_mean(loss_)" ] }, { "cell_type": "markdown", "metadata": { "id": "DMVWzzsfNl4e" }, "source": [ "## 检查点(基于对象保存)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zj8bXQTgNwrF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", " encoder=encoder,\n", " decoder=decoder)" ] }, { "cell_type": "markdown", "metadata": { "id": "hpObfY22IddU" }, "source": [ "## 训练\n", "\n", "1. 将 *输入* 传送至 *编码器*,编码器返回 *编码器输出* 和 *编码器隐藏层状态*。\n", "2. 将编码器输出、编码器隐藏层状态和解码器输入(即 *开始标记*)传送至解码器。\n", "3. 解码器返回 *预测* 和 *解码器隐藏层状态*。\n", "4. 解码器隐藏层状态被传送回模型,预测被用于计算损失。\n", "5. 使用 *教师强制 (teacher forcing)* 决定解码器的下一个输入。\n", "6. *教师强制* 是将 *目标词* 作为 *下一个输入* 传送至解码器的技术。\n", "7. 最后一步是计算梯度,并将其应用于优化器和反向传播。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sC9ArXSsVfqn", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "@tf.function\n", "def train_step(inp, targ, enc_hidden):\n", " loss = 0\n", "\n", " with tf.GradientTape() as tape:\n", " enc_output, enc_hidden = encoder(inp, enc_hidden)\n", "\n", " dec_hidden = enc_hidden\n", "\n", " dec_input = tf.expand_dims([targ_lang.word_index['']] * BATCH_SIZE, 1)\n", "\n", " # 教师强制 - 将目标词作为下一个输入\n", " for t in range(1, targ.shape[1]):\n", " # 将编码器输出 (enc_output) 传送至解码器\n", " predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n", "\n", " loss += loss_function(targ[:, t], predictions)\n", "\n", " # 使用教师强制\n", " dec_input = tf.expand_dims(targ[:, t], 1)\n", "\n", " batch_loss = (loss / int(targ.shape[1]))\n", "\n", " variables = encoder.trainable_variables + decoder.trainable_variables\n", "\n", " gradients = tape.gradient(loss, variables)\n", "\n", " optimizer.apply_gradients(zip(gradients, variables))\n", "\n", " return batch_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ddefjBMa3jF0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "EPOCHS = 10\n", "\n", "for epoch in range(EPOCHS):\n", " start = time.time()\n", "\n", " enc_hidden = encoder.initialize_hidden_state()\n", " total_loss = 0\n", "\n", " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n", " batch_loss = train_step(inp, targ, enc_hidden)\n", " total_loss += batch_loss\n", "\n", " if batch % 100 == 0:\n", " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", " batch,\n", " batch_loss.numpy()))\n", " # 每 2 个周期(epoch),保存(检查点)一次模型\n", " if (epoch + 1) % 2 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", "\n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", " total_loss / steps_per_epoch))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" ] }, { "cell_type": "markdown", "metadata": { "id": "mU3Ce8M6I3rz" }, "source": [ "## 翻译\n", "\n", "* 评估函数类似于训练循环,不同之处在于在这里我们不使用 *教师强制*。每个时间步的解码器输入是其先前的预测、隐藏层状态和编码器输出。\n", "* 当模型预测 *结束标记* 时停止预测。\n", "* 存储 *每个时间步的注意力权重*。\n", "\n", "请注意:对于一个输入,编码器输出仅计算一次。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EbQpyYs13jF_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def evaluate(sentence):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", "\n", " sentence = preprocess_sentence(sentence)\n", "\n", " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n", " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n", " maxlen=max_length_inp,\n", " padding='post')\n", " inputs = tf.convert_to_tensor(inputs)\n", "\n", " result = ''\n", "\n", " hidden = [tf.zeros((1, units))]\n", " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", " dec_input = tf.expand_dims([targ_lang.word_index['']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input,\n", " dec_hidden,\n", " enc_out)\n", "\n", " # 存储注意力权重以便后面制图\n", " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", " predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", " result += targ_lang.index_word[predicted_id] + ' '\n", "\n", " if targ_lang.index_word[predicted_id] == '':\n", " return result, sentence, attention_plot\n", "\n", " # 预测的 ID 被输送回模型\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s5hQWlbN3jGF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 注意力权重制图函数\n", "def plot_attention(attention, sentence, predicted_sentence):\n", " fig = plt.figure(figsize=(10,10))\n", " ax = fig.add_subplot(1, 1, 1)\n", " ax.matshow(attention, cmap='viridis')\n", "\n", " fontdict = {'fontsize': 14}\n", "\n", " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sl9zUHzg3jGI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def translate(sentence):\n", " result, sentence, attention_plot = evaluate(sentence)\n", "\n", " print('Input: %s' % (sentence))\n", " print('Predicted translation: {}'.format(result))\n", "\n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" ] }, { "cell_type": "markdown", "metadata": { "id": "n250XbnjOaqP" }, "source": [ "## 恢复最新的检查点并验证" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UJpT9D5_OgP6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 恢复检查点目录 (checkpoint_dir) 中最新的检查点\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrAM0FDomq3E", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "translate(u'hace mucho frio aqui.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zSx2iM36EZQZ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "translate(u'esta es mi vida.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A3LLCx3ZE0Ls", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "translate(u'¿todavia estan en casa?')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DUQVLVqUE1YW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 错误的翻译\n", "translate(u'trata de averiguarlo.')" ] }, { "cell_type": "markdown", "metadata": { "id": "RTe5P5ioMJwN" }, "source": [ "## 下一步\n", "\n", "* [下载一个不同的数据集](http://www.manythings.org/anki/)实验翻译,例如英语到德语或者英语到法语。\n", "* 实验在更大的数据集上训练,或者增加训练周期。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "nmt_with_attention.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }