TensorFlow从1到2(十)带注意力机制的神经网络机器翻译

俺踏月色而来 2019-04-30 10:05:00 阅读数:151 评论数:0 收藏数:0

基本概念

机器翻译和语音识别是最早开展的两项人工智能研究。今天也取得了最显著的商业成果。
早先的机器翻译实际脱胎于电子词典,能力更擅长于词或者短语的翻译。那时候的翻译通常会将一句话打断为一系列的片段,随后通过复杂的程序逻辑对每一个片段进行翻译,最终组合在一起。所得到的翻译结果应当说似是而非,最大的问题是可读性和连贯性非常差。
实际从机器学习的观点来讲,这种翻译方式,也不符合人类在做语言翻译时所做的动作。其实以神经网络为代表的机器学习,更多的都是在“模仿”人类的行为习惯。
一名职业翻译通常是这样做:首先完整听懂要翻译的语句,将语义充分理解,随后把理解到的内容,用目标语言复述出来。
而现在的机器翻译,也正是这样做的,谷歌的seq2seq是这一模式的开创者。
如果用计算机科学的语言来说,这一过程很像一个编解码过程。原始的语句进入编码器,得到一组用于代表原始语句“内涵”的数组。这些数组中的数字就是原始语句所代表的含义,只是这个含义人类无法读懂,是需要由神经网络模型去理解的。随后解码过程,将“有含义的数字”解码为对应的目标语言。从而完成整个翻译过程。这样的得到的翻译结果,非常流畅,具有更好的可读性。

(图片来自谷歌NMT文档)

注意力机制是人类特有的大脑思维方式,比如看到下面这幅照片:

(图片来自互联网)
照片的内容实际很多,甚至如果从数学上说,背景树林的复杂度要高于前景。但看到照片的人,都会先注意到迎面而来的飞盘,随后是投掷者,接着是图像右侧的小孩子。其它的信息都被忽略了。
这是人类在上万年的进化中所形成的本能。对于快速向自己移动的物体首先会看到、识别危险、并且快速应对。接着是可能对自己造成威胁的同类或者生物。为了做到集注,不得不忽略看起来无关紧要的东西。
在机器学习中引入注意力模型,在图像处理、机器翻译、策略博弈等各个领域中都有应用。这里的注意力机制有两个作用:一是降低模型的复杂度或者计算量,把主要资源分配给更重要的内容。二是对应把最相关的输入导出到相关的输出,更有针对性的得到结果。

在机器翻译领域,前面我们已经确定和解释了编码、解码模型。那么第二点的输入输出相关性就显得更重要。
我们举例来说明:比如英文“I love you”,翻译为中文是“我爱你”。在一个编码解码模型中,首先由编码器处理“I love you”,从而得到中间语义,比如我们称为C:

    C = Encoder("I love you")  

解码的时候,如果没有注意力机制,那序列输出则是:

    "我" = Decoder(C)  
    "爱" = Decoder(C)  
    "你" = Decoder(C)  

因为C相当于“I love you”三个单词共同的作用。那么解码的时候,每一个字的输出,都相当于3个单词共同作用的结果。这显然是不合理的,而且也不大可能得到一个理想、顺畅的结果。
一个理想的解码模型应当类似这样的方式:

    "我" = Decoder(C+"I")  
    "爱" = Decoder(C+"love")  
    "你" = Decoder(C+"you")  

当然,机器学习不是人。人通过大量的学习、经验的积累,一眼就能看出来“I”对应翻译成“我”,“love”翻译成“爱”。机器不可能提前知道这一切,所以我们比较切实的方法,只能是增加一套权重逻辑,在不同的翻译处理中,对应不同的权重属性。这就好像下面这样的方式:

    "我" = Decoder(C+0.8x"I"+0.1x"love"+0.2x"you")  
    "爱" = Decoder(C+0.1x"I"+0.7x"love"+0.1x"you")  
    "你" = Decoder(C+0.2x"I"+0.1x"love"+0.8x"you")  

没错了,这个权重值,比如翻译“我”的时候的权重序列:(0.8,0.1,0.2),就是注意力机制。在翻译某个目标单词输出的时候,通过注意力机制,模型集注在对应的某个输入单词。
当然,注意力机制还包含上面示意性的表达式没有显示出来的一个重要操作:结合解码器的当前状态、和编码器输入内容之后的状态,在每一次翻译解码操作中更新注意力的权重值。

翻译模型

回到上面的编解码模型示意图。编码器、解码器在我们的机器学习中,实际都是神经网络模型。那么把上面的示意图展开,一个没有注意力机制的编码、解码翻译模型是这个样子:

(图片来自谷歌NMT文档)

随后,我们为这个模型增加解码时候的权重机制。模型在处理每个单词输出的时候,会在权重的帮助下,把重点放在对应的输入单词上。示意图如下:

(图片来自谷歌NMT文档)

最终,结合权重生成的过程,成为完整的注意力机制。注意力机制主要作用于解码,在每一个输出步骤中都要重新计算注意力权重,并更新到解码模型从而对输出产生影响。模型的示意图如下:

(图片来自谷歌NMT文档)
图片中注意力权重的来源和去向箭头,要注意看清楚,这对你下面阅读实现的代码会很有帮助。

样本及样本预处理

前面的编解码模型示意图,还有模拟的表达式,当然都做了很多简化。实际上中间还有很多工作要做,首先是翻译样本库。

本例中使用http://www.manythings.org/anki/提供的英文对比西班牙文样本库,网站上还有很多其它语言的对比样本可以下载,有兴趣的读者不妨在做完这个练习后尝试一下其它语言的机器翻译。
这个样本是文本格式,包含很多行,每一行都是一个完整的句子,包含英文和西班牙文两部分,两种文字之间使用制表符隔开,比如:

May I borrow this book? ¿Puedo tomar prestado este libro?

对于样本库,我们要进行以下几项预处理:

  • 读取样本库,建立数据集。每一行的样本按语言分为两个部分。
  • 为每一句样本,增加开始标志<start>和结束标志<end>。看过《从锅炉工到AI专家(10)》的话,你应当理解这种做法。经过训练后,模型会根据这两个标志作为翻译的开始和结束。

做完上面的处理后,刚才的那行样本看起来会是这个样子:

<start> may i borrow this book ? <end>
<start> ¿ puedo tomar prestado este libro ? <end>

注意标点符号也是语言的组成部分,每个部分用空格隔开,都需要单独数字化。所以你能看到,上面的两行例句,标点符号之前也添加了空格。

  • 进行数据清洗,去掉不支持的字符。
  • 把单词数字化,建立从单词到数字和从数字到单词的对照表。
  • 设置一个句子的最大长度,把每个句子按照最大长度在句子的后端补齐。

一行句子数字化之后,编码同单词之间的对照关系可能类似下面的样子:

Input Language; index to word mapping
1 ----> <start>
8 ----> no
38 ----> puedo
804 ----> confiar
20 ----> en
1000 ----> vosotras
3 ----> .
2 ----> <end>

Target Language; index to word mapping
1 ----> <start>
4 ----> i
25 ----> can
12 ----> t
345 ----> trust
6 ----> you
3 ----> .
2 ----> <end>

你可能注意到了,“can't”中的单引号作为不支持的字符被过滤掉了,不过你放心,这并不会影响模型的训练。当然在一个完善的翻译系统中,这样的字符都应当单独处理,本例中就忽略了。

模型构建

本例中使用了编码器、解码器、注意力机制三个网络模型,都继承自keras.Model,属于三个自定义的Keras模型。
三个模型共同组成了完整的翻译模型。完整模型的组装,是在训练过程和翻译(预测)过程中,通过相应子程序把他们组装在一起的。这是因为它们三者之间的逻辑机制相对比较复杂。无法用前面常用的keras.models.Sequential方法直接耦合在一起。
自定义Keras模型在本系列中是第一次遇到,所以着重讲一下。实现自定义模型有三个基本要求:

  • 继承自keras.Model类。
  • 实现__init__方法,用于实现类的初始化,同所有面向对象的语言一样,这里主要完成基类和类成员的初始化工作。
  • 实现call方法,这是主要的计算逻辑。模型接入到神经网络之后,训练逻辑和预测逻辑,都通过逐层调用call方法来完成计算。方法中可以使用keras中原有的网络模型和自己的计算通过组合来完成工作。

自定义模型之所以有这些要求,主要是为了自定义的模型,可以跟Keras原生层一样,互相兼容,支持多种模型的组合、互联,从而共同形成更复杂的模型。

Encoder/Decoder主体都使用GRU网络,读起来应当比较容易理解。有需要的话,复习一下《从锅炉工到AI专家(10)》
注意力机制的BahdanauAttention模型就很令人费解了,困惑的关键在于其中的算法。算法的计算部分只有两行代码,代码本身都知道是在做什么,但完全不明白组合在一起是什么功能以及为什么这样做。其实阅读由数学公式推导、转换而来的程序代码都有这种感觉。所以现在很多的知识保护,根本不在于源代码,而在于公式本身。没有公式,很多源代码非常难以读懂。
这部分推荐阅读Dzmitry Bahdanau的论文《Neural Machine Translation by Jointly Learning to Align and Translate》和之后Minh-Thang Luong改进的算法《Effective Approaches to Attention-based Neural Machine Translation》。论文中对于理论做了详尽解释,也有公式的推导过程。
这里的BahdanauAttention模型实际就是公式的程序实现。如果精力不够的话,死记公式也算一种学习方法。

训练和预测

我们以往碰到的模型,训练和预测基本都是一行代码,几乎没有什么需要解释的。
今天的模型涉及了带有注意力机制的自定义模型,主要的逻辑,是通过程序代码,在训练和评估子程序中把模型组合起来完成的。
程序如果只是编码器和解码器串联的逻辑,完全可以同以前一样,一条keras.Sequential函数完成组装,那就一点难度没有了。而加上注意力机制,复杂度高了很多,也是最难理解的地方。做一个简单的分析:

  • 编码器Encoder是一次整句编码,得到一个enc_output。enc_output相当于模型对整句语义的理解。
  • 解码器Decoder是逐个单词输入,逐个单词输出的。训练时,输入序列由<start>起始标志开始,到<end>标志结束。预测时,没有人知道这一句翻译的结果是多少个单词,就是逐个获取Decoder的输出,直到得到一个<end>标志。
  • Encoder和Decoder都引出了隐藏层,用于计算注意力权重。keras.layers.GRU的state输出其实就是隐藏层,平时这个参数我们是用不到的。
  • 对于每一个翻译的输出词,注意力对其影响就是通过attention_weights * values,然后将结果跟前一个输出词一起作为Decoder的GRU输入,values实际就是编码器输出enc_output。
  • Decoder输出上一个词时候的隐藏层,跟enc_output一起通过公式计算,得到下一个词的注意力权重attention_weights。在第一次循环的时候Decoder还没有输出过隐藏层,这时候使用的是Encoder的隐藏层。
  • 注意力权重attention_weights从程序逻辑上并不需要引出,程序中在Decoder中输出这个值是为了绘制注意力映射图,帮助你更好的理解注意力机制。所以如果是在这个基础上做翻译系统,输出权重值到模型外部是不需要的。
  • 为了匹配各个网络的不同维度和不同形状,注意力机制的计算逻辑和注意力权重经过了各种维度变形。Decoder的输入虽然是一个词,但也需要扩展成一批词的第一个元素(也是唯一一个元素),这个跟我们以前的模型在预测时所做的是完全一样的。

完整源码

下面是完整的可执行源代码,请参考注释阅读:

#!/usr/bin/env python3

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import unicodedata
import re
import numpy as np
import os
import io
import time
import sys

# 如果命令行增加了参数'train'则进入训练模式,否则按照翻译模式执行
TRAIN = False
if len(sys.argv) == 2 and sys.argv[1] == 'train':
    TRAIN = True

# 下载样本集,下载后自动解压。数据保存在路径:~/.keras/datasets/
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip',
    origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)
# 指向解压后的样本文件
path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"

# 将文本从unicode编码转换为ascii编码
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn')

# 对所有的句子做预处理
def preprocess_sentence(w):
    w = unicode_to_ascii(w.lower().strip())

    # 在单词和标点之间增加空格
    # 比如: "he is a boy." => "he is a boy ."
    # 参考: https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
    w = re.sub(r"([?.!,¿])", r" \1 ", w)
    w = re.sub(r'[" "]+', " ", w)

    # 用空格替换掉除了大小写字母和"."/ "?"/ "!"/ ","之外的字符
    w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
    # 截断两端的空白
    w = w.rstrip().strip()

    # 在句子两端增加开始和结束标志
    # 这样经过训练后,模型知道什么时候开始和什么时候结束
    w = '<start> ' + w + ' <end>'
    return w

# 载入样本集,对句子进行预处理
# 最终返回(英文,西班牙文)这样的配对元组
def create_dataset(path, num_examples):
    lines = io.open(path, encoding='UTF-8').read().strip().split('\n')

    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')]  for l in lines[:num_examples]]

    return zip(*word_pairs)
# 至此的输出为:
# <start> go away ! <end>
# <start> salga de aqui ! <end>
# 这样的形式。

# 获取最长的句子长度
def max_length(tensor):
    return max(len(t) for t in tensor)

# 将单词数字化之后的数字<->单词双向对照表
def tokenize(lang):
    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        filters='')
    lang_tokenizer.fit_on_texts(lang)

    tensor = lang_tokenizer.texts_to_sequences(lang)

    tensor = tf.keras.preprocessing.sequence.pad_sequences(
        tensor,
        padding='post')

    return tensor, lang_tokenizer

def load_dataset(path, num_examples=None):
    # 载入样本,两种语言分别保存到两个数组
    targ_lang, inp_lang = create_dataset(path, num_examples)
    # 把句子数字化,两种语言是两套对照编码
    input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
    target_tensor, targ_lang_tokenizer = tokenize(targ_lang)

    return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer

# 训练的样本集数量,越大翻译效果越好,但训练耗时越长
num_examples = 80000
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)
# 至此,input_tensor/target_tensor 是数字化之后的样本(数字数组)
# inp_lang/targ_lang 是数字<->单词编码对照表
# 计算两种语言中最长句子的长度
max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)

# 将样本按照8:2分为训练集和验证集
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)

##############################################

BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train)//BATCH_SIZE
embedding_dim = 256
units = 1024
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1

dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

# 编码器模型
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
                                    self.enc_units, 
                                    return_sequences=True, 
                                    return_state=True, 
                                    recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)

# 注意力模型
class BahdanauAttention(tf.keras.Model):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query为上次的GRU隐藏层
        # values为编码器的编码结果enc_output
        hidden_with_time_axis = tf.expand_dims(query, 1)

        # 计算注意力权重值
        score = self.V(tf.nn.tanh(
            self.W1(values) + self.W2(hidden_with_time_axis)))

        attention_weights = tf.nn.softmax(score, axis=1)

        # 使用注意力权重*编码器输出作为返回值,将来会作为解码器的输入
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

# 解码器模型
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.dec_units, 
            return_sequences=True,
            return_state=True, 
            recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # 使用上次的隐藏层(第一次使用编码器隐藏层)、编码器输出计算注意力权重
        context_vector, attention_weights = self.attention(hidden, enc_output)

        x = self.embedding(x)

        # 将上一循环的预测结果跟注意力权重值结合在一起作为本次的GRU网络输入
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # state实际是GRU的隐藏层
        output, state = self.gru(x)

        output = tf.reshape(output, (-1, output.shape[2]))

        x = self.fc(output)

        return x, state, attention_weights

decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)


optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 损失函数
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

# 保存中间训练结果
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

# 一次训练
@tf.function
def train_step(inp, targ, enc_hidden):
    loss = 0

    with tf.GradientTape() as tape:
        # 输入源语言句子进行编码
        enc_output, enc_hidden = encoder(inp, enc_hidden)
        # 保留编码器隐藏层用于第一次的注意力权重计算
        dec_hidden = enc_hidden

        # 解码器第一次的输入必定是<start>,targ_lang.word_index['<start>']是转换为对应的数字编码
        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)       

        # 循环整个目标句子(用于对比每一次解码器输出同样本的对比)
        for t in range(1, targ.shape[1]):
            # 使用本单词、隐藏层、编码器输出共同预测下一个单词,同事保留本次的隐藏层作为下一次输入
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            # 计算损失值,最终的损失值是整个句子所有单词损失值的合计
            loss += loss_function(targ[:, t], predictions)

            # 在训练时,每次解码器的输入并不是上次解码器的输出,而是样本目标语言对应单词
            # 这称为teach forcing
            dec_input = tf.expand_dims(targ[:, t], 1)

    # 所有单词的平均损失值
    batch_loss = (loss / int(targ.shape[1]))
    # 最终的训练参量是编码器和解码的集合
    variables = encoder.trainable_variables + decoder.trainable_variables
    # 根据代价值计算下一次的参量值
    gradients = tape.gradient(loss, variables)
    # 将新的参量应用到模型
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

def training():
    EPOCHS = 10

    for epoch in range(EPOCHS):
        start = time.time()
        # 初始化隐藏层和损失值
        enc_hidden = encoder.initialize_hidden_state()
        total_loss = 0

        # 一个批次的训练
        for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = train_step(inp, targ, enc_hidden)
            total_loss += batch_loss

        # 每100次显示一下模型损失值
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                                                        epoch + 1,
                                                        batch,
                                                        batch_loss.numpy()))
        # 每两次迭代保存一次数据
        if (epoch + 1) % 2 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        # 显示每次迭代的损失值和消耗时间
        print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                            total_loss / steps_per_epoch))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

# 根据命令行参数选择本次是否进行训练
if TRAIN:
    training()
################################################

# 评估(翻译)一行句子
def evaluate(sentence):
    # 清空注意力图
    attention_plot = np.zeros((max_length_targ, max_length_inp))
    # 句子预处理
    sentence = preprocess_sentence(sentence)
    # 句子数字化
    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
    # 按照最长句子长度补齐
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], 
                                                           maxlen=max_length_inp, 
                                                           padding='post')
    inputs = tf.convert_to_tensor(inputs)

    result = ''

    # 句子做编码
    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)

    # 编码器隐藏层作为第一次解码器的隐藏层值
    dec_hidden = enc_hidden
    # 解码第一个单词必然是<start>,表示启动解码
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)

    # 假设翻译结果不超过最长的样本句子
    for t in range(max_length_targ):
        # 逐个单词翻译
        predictions, dec_hidden, attention_weights = decoder(dec_input,
                                                             dec_hidden,
                                                             enc_out)

        # 保留注意力权重用于绘制注意力图
        # 注意每次循环的每个单词注意力权重是不同的
        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()

        # 得到预测值
        predicted_id = tf.argmax(predictions[0]).numpy()

        # 从数字查表转换为对应单词,累加到上一次结果,最终组成句子
        result += targ_lang.index_word[predicted_id] + ' '

        # 如果是<end>表示翻译结束
        if targ_lang.index_word[predicted_id] == '<end>':
            return result, sentence, attention_plot

        # 上次的预测值,将作为下次解码器的输入
        dec_input = tf.expand_dims([predicted_id], 0)
    # 如果超过样本中最长的句子仍然没有翻译结束标志,则返回当前所有翻译结果
    return result, sentence, attention_plot

# 绘制注意力图
def plot_attention(attention, sentence, predicted_sentence):
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    plt.show()

# 翻译一句文本
def translate(sentence):
    result, sentence, attention_plot = evaluate(sentence)

    print('Input: %s' % (sentence))
    print('Predicted translation: {}'.format(result))

    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
    plot_attention(attention_plot, sentence.split(' '), result.split(' '))

# 恢复保存的训练结果
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# 测试以下翻译
translate(u'hace mucho frio aqui.')
translate(u'esta es mi vida.')
translate(u'¿todavia estan en casa?')
# 据说这句话的翻译结果不对,不懂西班牙文,不做评论
translate(u'trata de averiguarlo.')

第一次执行的时候要加参数tain:

$ ./translate_spa2en.py train
Epoch 1 Batch 0 Loss 4.5296
Epoch 1 Batch 100 Loss 2.2811
Epoch 1 Batch 200 Loss 1.7985
Epoch 1 Batch 300 Loss 1.6724
Epoch 1 Loss 2.0235
Time taken for 1 epoch 149.3063322815 sec
    ...训练过程略...
    
Input: <start> hace mucho frio aqui . <end>
Predicted translation: it s very cold here . <end> 
Input: <start> esta es mi vida . <end>
Predicted translation: this is my life . <end> 
Input: <start> ¿ todavia estan en casa ? <end>
Predicted translation: are you still at home ? <end> 
Input: <start> trata de averiguarlo . <end>
Predicted translation: try to figure it out . <end> 

以后如果只是想测试翻译效果,可以不带train参数执行,直接看翻译结果。
对于每一个翻译句子,程序都会绘制注意力矩阵图:

通常语法不是很复杂的句子,基本是顺序对应关系,所以注意力亮点基本落在对角线上。
图中X坐标是西班牙文单词,Y坐标是英文单词。每个英文单词,沿X轴看,亮点对应的X轴单词,表示对于翻译出这个英文单词,是哪一个西班牙文单词权重最大。

(待续...)


版权声明:本文为[俺踏月色而来]原创文章
转载请带上:http://copyfuture.com/blogs-details/b7caf95bdfff491ad23783b886e69847
或:https://www.cnblogs.com/andrewwang/p/10794420.html


  1. 关于坚持和发展中国特色社会主义的几个问题
  2. 7种意式西装穿搭法,散发出优雅男人味!
  3. 日本反对无效!俄军500士兵70辆坦克登岛,北海道在导弹射程内
  4. 乐视控股:韬藴至今未支付收购易到资金,将提起诉讼
  5. 海底捞员工“神”服务背后的“核武器“
  6. Docker 下载镜像
  7. Spotify痛斥苹果:就是个垄断头子
  8. 《机器学习实战》-机器学习基础
  9. 一个曾经带来感动的时代:你还记得你玩过的家用主机吗?
  10. nginx+uwsgi+flask+supervisor 项目部署
  11. javascript ES6 新特性之 let
  12. 喝了一杯果汁,测餐后2小时血糖8.6算高吗?高到多少就算升高了?
  13. 物联网卡在智能家居之中,主要有什么用?
  14. 第3个1亿台冰箱下线,海尔将“世界第一”进化到未来
  15. 只比霸道贵3万,最便宜的陆巡来了,低配53万起,配手动变速箱
  16. xshell连接到服务器代理上网
  17. css炫酷动画收藏
  18. 库里复出26分,勇士力克活塞返西部第一!汤神24分,追梦14+8
  19. 桑切斯效应?德赫亚后,曼联又有两主力索要高薪,不给就离队!
  20. 科学家将液态金属转化为等离子体!
  21. 一分钟充好电是什么概念?菲斯克声称将造1分钟充好电的电动汽车
  22. 每年追《权力的游戏》,都提前一个月做好挨虐准备
  23. 荐读|水不试,不知深浅;人不交,不知好坏。
  24. 签到抽奖功能——常见前端抽奖需求
  25. 人工智能时代:哪些行业迎来重大变革?
  26. 京东财报电话会议实录:刘强东称今年会更关注三四线城市
  27. LOL:不进季后赛的LGD才是真正的LGD!2:1击败OMG让出垫底的位置
  28. 《强奸日》引起众怒 苏格兰议员强烈谴责黄色暴力游戏
  29. 如何写工程代码——重新认识面向对象
  30. Java集合--TreeSet详细解析
  31. 为啥买车的人越来越少?知道原因后,你还敢买车吗
  32. java jdk动态代理模式举例浅析
  33. 【计算机网络】TCP/IP若干问题
  34. RSA2019创新沙盒 | Duality: 基于同态加密的数据分析和隐私保护方案
  35. 核潜艇通讯:决定全球海洋归属权的暗网系统
  36. 迪丽热巴已成立自己的工作室,看到名字后,网友:杨幂没白疼你
  37. 霉变食物不可惜,当心黄曲霉素成肝之殇!
  38. 货架半空四个月,全时便利店终于找到“接盘侠”
  39. 营养师提醒您:糖友别吃汤泡饭
  40. [Swift]LeetCode606. 根据二叉树创建字符串 | Construct String from Binary Tree
  41. 科学家在石墨中观察到“第二类声”,神奇特性再加一?
  42. 揭秘顺治皇帝壮年而毙的背后隐情
  43. 大连工业制造领域开始应用5G技术
  44. 在线制作数据库ER模型
  45. [Cake] 2. dotnet 全局工具 cake
  46. 中国无人机手起刀落,百万美元T72直接被打爆,沙特说“买值了”
  47. “高价回收驾照分”?那驾照分值不值钱,能不能进行买卖?
  48. 爆款20万元日系标杆SUV 动力媲美同级别轿跑
  49. Kubernetes集群搭建之Master配置篇
  50. 写得太好了!做餐饮不是一般人能干的,只要干好的都不是一般人!
  51. 歌名是晴天歌词是雨天,多少年后才懂:周杰伦《晴天》
  52. 宇宙起源:宇宙大爆炸(1)
  53. 联想上演碰瓷闹剧, PPT发布会被网友集体吐槽
  54. 10部结局极具争议的电影,直至今日,人们仍在争论
  55. Android 四大组件之broadcast的理解
  56. Elastic Search 安装和配置
  57. 12点聊电商:淘宝规范商务等行业 对严重违规者立即清退
  58. mysql实现主从备份
  59. “村与村的战斗”?日本战国时代的战争规模真的那么小吗?
  60. 三个关键点,确保以普惠多赢理念助推“一带一路”
  61. 吴京新片正式开拍,继成龙章子怡加盟后,如今又来了一位人气男星
  62. PYPL 二月榜单发布:最受欢迎的编程语言、IDE 和数据库都是哪些
  63. 评论:一个豆瓣差评没什么大不了
  64. 地球是唯一有表面液态水的星球,为啥会有水存在?没水会有啥变化
  65. Python爬虫入门教程 38-100 教育部高校名单数据爬虫 scrapy
  66. #Java干货分享:这五个网站能打通你的任督二脉,让你技术大增
  67. 一个亡国之君,那些击败、戏弄他的人,都被“他的名字”熬死了
  68. 解决下一个500亿快递!未来的千亿快递,需要智慧物流骨干网!
  69. 没想到吧,《绿皮书》的台词居然还能这样读?
  70. 外媒曝《复联4》主演片酬,钢铁侠赚5亿,黑寡妇1.3亿
  71. iate id generator错误
  72. 十年老策划把他想对萌新说的话做成了一款游戏
  73. 比亚迪F3终将被取代,新车型内外更漂亮,1.5L自吸5万将开卖!
  74. 拼多多发布2018扶贫助农年报,农产品销售653亿同比增233%
  75. 啥都涨,粮价不涨?该如何确保粮食安全?
  76. MATLAB 音响系统工具箱
  77. 一个新手程序员 2019 的九大尴尬瞬间
  78. 一季度31省消费支出榜出炉!10省花钱比赚钱“能干”,广东“超车”江苏,辽宁黑龙江名次跌(多图)
  79. 与河南巡抚田文镜斗法,直隶总督李绂却为何输了?
  80. 基于CAS实现单点登录(二)
  81. 魏建军们用实力“打脸”不懂中国汽车工业的人
  82. 前端知识分享.md
  83. 转载,汉语世界上最先进的语言(来自几年前的转发,如今重新转发)
  84. 硝苯地平缓释片、硝苯地平控释片,哪一个好?告诉你答案
  85. Sign Up Account In CloudAMQP
  86. 湖人绿军在总决赛相遇13次,湖人却只赢了4次?只因一人遭8连亚军
  87. 【品金庸】武馋仁“三绝”——洪七公
  88. 张庭背小三骂名委屈20年?林瑞阳前妻疑怒怼:该感谢我静默21年
  89. 揭开趣头条的“土味”流量生意经
  90. 华为有哪些“备胎”?这里有一份中国芯片企业权威榜单
  91. 华为高通就专利和解谈判:或每年支付高通超5亿美元
  92. 程序员:如何正确使用你的黄金时间
  93. python3入门教程之基本数据类型(一)
  94. 这十个人的背后,是2018年游戏业的悲欢离合
  95. "年轻人离开工厂送外卖"不应被误读,也不该有偏见
  96. 刀塔自走棋:野兽流不完全攻略
  97. 无状态点赞 王思聪骂Uzi,直播间惨遭爆破!官博秒回应!
  98. 齐王司马冏为何会兵败身亡?绝不是史书所说的“贪图享乐”
  99. 有些句子真的很美,这就是语言文字的魅力
  100. “团贷网”案:实控人近9亿转移隐匿资金被追缴冻结

  1. Python开发:部分第三方库无法在线安装解决方法(947)
  2. [Swift]LeetCode325. 最大子数组之和为k $ Maximum Size Subarray Sum Equals k(779)
  3. Matlab 2019a 安装包下载以及安装和激活(717)
  4. 仅限Edge和Chrome访问 全新网页端Skype应用上线(685)
  5. 前端笔记之NodeJS(一)初识NodeJS&内置模块&特点(682)
  6. C#读取excel文件提示未在本地计算机上注册“Microsoft.ACE.OLEDB.12.0”提供程序(663)
  7. 【预警通告】Weblogic反序列化远程代码执行漏洞(640)
  8. Visual Studio 2019 正式发布,重磅更新,支持live share(599)
  9. 【预警通告】Apache Tomcat远程代码执行漏洞CVE-2019-0232(573)
  10. 网上赌博平台维护审核提不了款怎么办?(535)
  11. React 与 React-Native 使用同一个 meteor 后台(525)
  12. Sublime Text3 最新版3207 安装及破解(458)
  13. Visual Studio 2019 正式发布(389)
  14. [翻译] Visual Studio 2019: 极速编码. 智能工作. 创造未来.(388)
  15. 刘强东身边的CXO还有谁“幸存”(373)
  16. 舍命生子产妇吴梦丈夫怒斥:没抢肺源不是精神分裂,网友断章取义(371)
  17. Confluence SSRF及远程代码执行漏洞处置手册(370)
  18. 机器学习 ML.NET 发布 1.0 RC(369)
  19. 阿里巴巴2018年纳税516亿元 同比增40%(368)
  20. K8s集群安装--最新版 Kubernetes 1.14.1(336)
  21. 雷军清华演讲实录:小米9年的创新、变革与未来(332)
  22. 小米手机卖不动了?(327)
  23. F#周报2019年第14期(317)
  24. 积分一样却选手下败将出战国际赛,《最强大脑》云队选手被坑了?(300)
  25. 《最强大脑》要垮?桑洁魏坤琳出轨细节被扒,戚薇才是神助攻(257)
  26. 华电教授孙玉兵被指与昔日同学共同学术造假,多所高校调查(250)
  27. F#周报2019年第15期(249)
  28. 日本明仁天皇退位,日本“平成”年代结束(246)
  29. linux系统安装cdcfordb2udb(241)
  30. Oracle甲骨文大规模裁员,你背离时代就会被淘汰(240)
  31. 他联系叙恐怖分子“卖军火”,称能搞到2000枚导弹,关键时刻中国警察出手(237)
  32. 魔兽世界:8.15搏击俱乐部坐骑获取流程 鳄鱼布鲁斯坐骑(236)
  33. 针对django2.2报错:UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 9737: ill....(235)
  34. 女友被曝插足许志安郑秀文婚姻 知情人透露马国明已下定决心分手(232)
  35. 视觉中国深夜道歉:全面配合监管部门彻底积极整改(228)
  36. 为什么国内汽车用沥青阻尼片,而欧洲主机厂却用树脂?(228)
  37. SQL简介及MySQL的安装目录详解(227)
  38. 谁是苏小明饭局爆粗偷拍者?知情人称另有其人(223)
  39. NodeJs之邮件(email)发送(222)
  40. 迪玛希好惨!昨晚《歌手》为声入人心男团帮帮唱,却再被指控侵权(222)
  41. 市值暴跌90%,世界零售巨头申请破产战胜了所有对手却输时代(221)
  42. 函数防抖,与函数节流(219)
  43. 机器学习基石笔记:01 The Learning Problem(217)
  44. 深度学习python的配置(Windows)(215)
  45. [深度应用]·实战掌握Dlib人脸识别开发教程(213)
  46. 许志安出轨视频系蓄谋偷拍?司机被曝收40万装红外摄像头(210)
  47. Google AI 系统 DeepMind 高中数学考试不及格(210)
  48. 干货!21部漫威电影观影顺序指南,在《复联4》之前赶紧补齐!(208)
  49. 赌命生子九个月后,吴梦离世:前半辈子任性了,我用生命买单(206)
  50. 威廉王子出轨凯特王妃闺蜜? 外媒称婚外情致兄弟反目(203)
  51. spring-cloud-sleuth+zipkin源码探究(203)
  52. WebGL three.js学习笔记 纹理贴图模拟太阳系运转(201)
  53. 新更新kb4493472导致无法正常开机(195)
  54. 杜敬谦死因疑曝光!或因他这一特殊的训练方式,泳迷高呼孙杨退役(190)
  55. 韦杰落网,金诚集团终局(188)
  56. AntDesign Form表单字段校验的三种方式(188)
  57. 华为推出方舟编译器 称可提升安卓系统效率(185)
  58. 山东庆云民企3000亩土地被贱卖 国企接盘拟转性(184)
  59. 《权力的游戏》龙妈有那么多爱她的人,为什么最终会选择琼恩雪诺(182)
  60. 张无忌为什么爱上她?陈钰琪版赵敏终于给答案了(181)
  61. “国防”靠美国? 韩国瑜=马英九2.0? 走着瞧(180)
  62. Python破解Wifi密码思路(180)
  63. 直认与老公感情淡了!27岁TVB上位女星:我们不是好熟(179)
  64. CUBA Studio 8.0 发布,企业级应用开发平台(179)
  65. 张丹峰出轨最新锤来了!毕滢的朋友圈简直刷新下限啊!(177)
  66. Github 上 Star 最多的个人 Spring Boot 开源学习项目(176)
  67. 使用 C 语言实现一个 HTTP GET 连接(175)
  68. 拿着普通员工超300倍的工资裁员800人,这家游戏公司CEO引发员工不满|一周新闻(175)
  69. AntD框架的upload组件上传图片时遇到的一些坑(175)
  70. 币安称 4000 万美元比特币被盗(174)
  71. 不要996!程序员创建955.WLB不加班公司名单,GitHub周榜第二(174)
  72. Weblogic CVE-2019-2647等相关XXE漏洞分析(173)
  73. Codejam Qualification Round 2019(173)
  74. simulink创建简单模型(172)
  75. 《跃迁-成为高手的技术》之联机学习(171)
  76. python爬虫重定向次数过多问题(171)
  77. [NewLife.XCode]高级查询(168)
  78. 强大的jQGrid的傻瓜式使用方法。以及一些注意事项,备有相应的引入文件。(167)
  79. 核心算法缺位,人工智能发展面临“卡脖子”窘境(165)
  80. Algolia使用教程 , 超详细傻子看都会(165)
  81. Delphi 开发微信公众平台 (二) 用户管理(164)
  82. 只需知道电话号码 即可监控任意一部手机,获取位置,太可怕(163)
  83. 范斯晶对祖母的称呼很意外,范志毅很心疼,缺少母爱的孩子不容易(162)
  84. 如何定位前端线上问题(如何排查前端生产问题)(162)
  85. 告诉你去越南芽庄必带回的好东西(161)
  86. 数学家发现完美的乘法(160)
  87. 百度网盘下载神器 PanDownload v2.0.9(破解版、不限速)(159)
  88. 双双出轨!许志安劈腿马国明港姐女友,二人被拍16分钟激吻超20次(159)
  89. 高管被警方带走背后:巧达科技操盘2亿人简历生意(159)
  90. 定义工作,解读自我——IT帮2019年2月线下活动回顾(159)
  91. 吹爆惠英红,《铁探》这位霸道总警司超带感!真乃港剧罕见大女主(156)
  92. vue生成图片验证码(155)
  93. 三国正史第一猛将:一人单挑数千人,不是吕布也不是关羽(154)
  94. 从0到1上线一个微信小程序(154)
  95. FreeSql 如何现实 Sqlite 跨库查询(154)
  96. 向佐的弟弟叫向佑,网友:那郭碧婷生的孩子叫什么?(153)
  97. spring-boot-2.0.3不一样系列之源码篇 - pageHelper分页,绝对有值得你看的地方(153)
  98. [Node.js] 3、搭建hexo博客(152)
  99. java基础(十五)----- Java 最全异常详解 ——Java高级开发必须懂的(152)
  100. TensorFlow从1到2(十)带注意力机制的神经网络机器翻译(151)

  1. 大数据技术之_24_电影推荐系统项目_07_工具环境搭建(具体实操)
  2. 前端限制显示的文本字数的几种方法——不换行与换行
  3. 【实验吧】该题不简单——writeup
  4. 反向传播算法
  5. =、==、===、equals()的区别
  6. GitHub 推出开发者赞助项目
  7. 剑指Offer的学习笔记(C#篇)-- 从上往下打印二叉树
  8. argparse 在深度学习中的应用
  9. 张云雷复出?西城区文旅局:德云社在辖区内演出未发现违法违规问题
  10. Maven安装与配置
  11. acWing 825. 排队购物
  12. ajax&&jquery
  13. 苏联攻击机的悲壮行动,明知德军战机拦截,仍在无护航状态下出击
  14. 华为的5G技术,源于这种数学方法
  15. 一站式自动化测试平台 http://www.Autotestplat.com
  16. RabbitMQ总结
  17. 第九组 通信3班 063 自反ACL
  18. 短线还有最后一跌?大V们表示:反弹近了!(5月23日)
  19. 第九组 通信3班 063 OSPFv2与OSPFv3综合实验
  20. C# IE选项 - 重置IE
  21. Spring_数据校验和自定义检验规则
  22. 谈谈Java的string为什么是不可变的
  23. OFFICE 365 A1 Plus账号注册
  24. 初学python—做一个数组的增删改查操作
  25. oc工程中oc、swift混编代码打包成静态framework踩坑笔记
  26. 阿里云推“智能秒停系统”:50秒内短信通知 再不怕吃罚单
  27. 今天购买了一个云服务器
  28. 神奇!乌鸦竟然会传达悲观和怀疑情绪 还会对同伴“冷嘲热讽”
  29. 数字IC设计入门必备——VIM自定义模板调用与VCS基本仿真操作示例
  30. 点击事件的坐标计算(client || offset) +(X || Width || Left) 各种排列组合别绕晕
  31. windows下dubbo-admin2.6.x之后版本的安装
  32. linux 之基本命令学习总结
  33. 传祺难续“传奇”?销量暴跌超4成 加价卖车被“断裂门”尽毁
  34. 香港豪门后宫持续曝光:他用选美比赛“选妃”,与几万女星交往
  35. 小窥React360——用React创建360全景VR体验
  36. Spring Boot 2 快速教程:WebFlux 集成 Mongodb(四)
  37. .Net Core下使用RabbitMQ比较完备的两种方案(虽然代码有点惨淡,不过我会完善)
  38. “80后”女博士已任团中央书记处书记
  39. 杨元庆:现在是联想的最好时刻 我们四大战役全部打了胜仗
  40. Java开发环境的搭建(JDK和Eclipse的安装)
  41. oracle学习笔记(十四) 数据库对象 索引 视图 序列 同义词
  42. 机构风向标:外资出逃超500亿 美的集团等白马股表现欠佳
  43. 跟踪记录ABAP对外部系统的RFC通信
  44. c++11多线程详解(一)
  45. 小蓝杯,跌破发行价了
  46. [NewLife.XCode]百亿级性能
  47. 33岁何洁商场走穴被曝光,路人镜头下的她与精修图差别好大
  48. 途牛第一季度净亏损2240万美元 同比亏损幅度扩大
  49. 00 | Two Sum
  50. 智能威胁分析之图数据构建
  51. 快速掌握RabbitMQ(二)——四种Exchange介绍及代码演示
  52. Neo4j 第六篇:Cypher语法
  53. Java微信公众平台开发(三)--接收消息的分类及实体的创建
  54. Java8 中的 Optional
  55. 如何显示超大图像(3)
  56. 贵州检察机关依法对袁仁国决定逮捕
  57. 有关xerospolit运行报错问题的有效解决方案
  58. ADO学途 one day
  59. Linux 中 ip netns 命令
  60. Python爬虫之设置selenium webdriver等待
  61. BSOJ1040 -- 【练习题目】美元DOLLARS
  62. 外媒:稀土是中国手中的一张王牌
  63. sql server添加sa用户和密码
  64. 深入理解JVM的类加载
  65. querySelector和getElementById之间的区别
  66. 简说设计模式——观察者模式
  67. 扰动函数和拉链法模拟HashMap的存储结构
  68. 东芝中国:“上海东芝公司”不存在 未停止与华为的合作
  69. 彭于晏马思纯主演张爱玲这部小说,却被说更适合演《骆驼祥子》?
  70. 停止向华为供货?东芝辟谣回应
  71. 云米第一季度净利润5310万元 同比增长68%
  72. 贪吃的古蛙,古生物学家发现亿年前两栖动物之间战争
  73. APICloud发布低代码开发平台 效率提升30%至60%
  74. Golang 读写锁RWMutex 互斥锁Mutex 源码详解
  75. shell初级-----数据呈现方式
  76. 白玉兰入围名单公布!《知否》《都挺好》上榜,还有这部豆瓣3分剧
  77. 深网 | 京东618接入快手、抖音 实现“即看即买”
  78. 解决 APP启动白屏黑屏问题
  79. Spring Cloud Hystrix理解与实践(一):搭建简单监控集群
  80. 浏览器与服务器通信技术——jsonp
  81. 【刷题笔记】LeetCode 606. Construct String from Binary Tree
  82. 央行副行长刘国强:应对汇率波动经验丰富,政策工具储备充足
  83. 部署Azure Log Analytics
  84. 计算机基础--http的基础整理和巩固
  85. 章子怡:女人四十,不止表面风光
  86. 直击|对话杨元庆:希望今年创最好盈利 要震慑住谣言
  87. 微软通过合作为美国270万农村退伍军人提供高速宽带服务
  88. Java进程占用内存过高,排查解决方法
  89. Go语言中使用切片(slice)实现一个Vector容器
  90. 商务部回应美宣布对13个中国企业或个人实施制裁:反对“长臂管辖”
  91. 优酷土豆的Redis服务平台化之路
  92. shell初级-----处理用户输入
  93. 感受lambada之美,推荐收藏,需要时查阅
  94. 美团点评发布2019年第一季度财报,营收192亿元超预期
  95. 任正非:Arm暂停合作对华为没影响
  96. redis和memcached的区别(总结)
  97. Spring Cloud与Duddo比较
  98. File类
  99. 朝鲜最强智能手机!人脸识别、无线充电、画质感人还支持无线耳机
  100. “断供”传闻屡遭反转,谁在制造恐慌?