image
VincentWei

天地间,浩然正气长存,为天地立心,为生民立命,为往圣继绝学,为万世开太平!

seq2seq+attention+beamsearch构建闲聊对话

VincentWei    2019-07-05 16:52

首先介绍大家熟悉的Seq2Seq模型。



接下来介绍Beam Search。预测的时候使用,训练的时候不使用。



引入Attention的初衷,越长的距离关系,LSTM的能力在下降。


核心代码如下:(注释的很清楚了)
"""
    Sequence To Sequence模型
    定义了模型编码器、解码器、优化器、训练、预测
"""

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell, GRUCell, MultiRNNCell, LSTMStateTuple, DropoutWrapper, ResidualWrapper
from tensorflow.contrib.seq2seq import BahdanauAttention, AttentionWrapper,TrainingHelper,\
        BasicDecoder, BeamSearchDecoder
from tensorflow import layers
from data_unit import DataUnit
from tensorflow.python.ops import array_ops
import os

class Seq2Seq(object):

    def __init__(self, hidden_size, cell_type,
                 layer_size, batch_size,
                 encoder_vocab_size, decoder_vocab_size,
                 embedding_dim, share_embedding,
                 max_decode_step,max_gradient_norm,
                 learning_rate,decay_step,
                 min_learning_rate,bidirection,
                 beam_width,
                 mode
                 ):
        """
        初始化函数,参数意义可查看config.py中对应的注释
        """
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.layer_size = layer_size
        self.batch_size = batch_size
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.embedding_dim = embedding_dim
        self.share_embedding = share_embedding
        self.max_decode_step = max_decode_step
        self.max_gradient_norm = max_gradient_norm
        self.learning_rate = learning_rate
        self.decay_step = decay_step
        self.min_learning_rate = min_learning_rate
        self.bidirection = bidirection
        self.beam_width = beam_width
        self.mode = mode
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.build_model()

    def build_model(self):
        """
        构建完整的模型
        :return:
        """
        self.init_placeholder()
        self.embedding()
        encoder_outputs, encoder_state = self.build_encoder()
        self.build_decoder(encoder_outputs, encoder_state)
        if self.mode == 'train':
            self.build_optimizer()
        self.saver = tf.train.Saver()

    def init_placeholder(self):
        """
        定义各个place_holder
        :return:
        """
        self.encoder_inputs = tf.placeholder(tf.int32, shape=[self.batch_size,None], name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(tf.int32, shape=[self.batch_size,], name='encoder_inputs_length')
        self.keep_prob = tf.placeholder(tf.float32, shape=(), name='keep_prob')
        if self.mode == 'train':
            self.decoder_inputs = tf.placeholder(tf.int32, shape=[self.batch_size,None], name='decoder_inputs')
            self.decoder_inputs_length = tf.placeholder(tf.int32, shape=[self.batch_size, ],
                                                        name='decoder_inputs_length')
            self.decoder_start_token = tf.ones(shape=(self.batch_size, 1),dtype=tf.int32) * DataUnit.START_INDEX
            self.decoder_inputs_train = tf.concat([self.decoder_start_token,self.decoder_inputs], axis=1)

    def embedding(self):
        """
        词嵌入操作
        :param share:编码器和解码器是否共用embedding
        :return:
        """
        with tf.variable_scope('embedding'):
            encoder_embedding = tf.Variable(tf.truncated_normal(shape=[self.encoder_vocab_size, self.embedding_dim], stddev=0.1),
                                            name='encoder_embeddings')
            if not self.share_embedding:
                decoder_embedding = tf.Variable(tf.truncated_normal(shape=[self.decoder_vocab_size, self.embedding_dim], stddev=0.1),
                                            name='decoder_embeddings')
                self.encoder_embeddings = encoder_embedding
                self.decoder_embeddings = decoder_embedding
            else:
                self.encoder_embeddings = encoder_embedding
                self.decoder_embeddings = encoder_embedding

    def one_cell(self, hidden_size, cell_type):
        """
        一个神经元
        :return:
        """
        if cell_type == 'gru':
            c = GRUCell
        else:
            c = LSTMCell
        cell = c(hidden_size)
        cell = DropoutWrapper(
            cell,
            dtype=tf.float32,
            output_keep_prob=self.keep_prob,
        )
        cell = ResidualWrapper(cell)
        return cell

    def build_encoder_cell(self, hidden_size, cell_type, layer_size):
        """
        构建编码器所有层
        :param hidden_size:
        :param cell_type:
        :param layer_size:
        :return:
        """
        # 此处一定不能用[LSTMCell] * layer_size,否则后续会报错
        cells = [self.one_cell(hidden_size, cell_type) for _ in range(layer_size)]
        return MultiRNNCell(cells)

    def build_encoder(self):
        """
        构建完整编码器
        :return:
        """
        with tf.variable_scope('encoder'):
            encoder_cell = self.build_encoder_cell(self.hidden_size, self.cell_type, self.layer_size)
            encoder_inputs_embedded = tf.nn.embedding_lookup(self.encoder_embeddings, self.encoder_inputs)
            encoder_inputs_embedded = layers.dense(encoder_inputs_embedded,
                                                        self.hidden_size,
                                                        use_bias=False,
                                                        name='encoder_residual_projection')
            initial_state = encoder_cell.zero_state(self.batch_size, dtype=tf.float32)
            if self.bidirection:
                encoder_cell_bw = self.build_encoder_cell(self.hidden_size, self.cell_type, self.layer_size)
                (
                    (encoder_fw_outputs, encoder_bw_outputs),
                    (encoder_fw_state, encoder_bw_state)
                ) = tf.nn.bidirectional_dynamic_rnn(
                    cell_bw=encoder_cell_bw,
                    cell_fw=encoder_cell,
                    inputs=encoder_inputs_embedded,
                    sequence_length=self.encoder_inputs_length,
                    dtype=tf.float32,
                    swap_memory=True)

                encoder_outputs = tf.concat(
                    (encoder_bw_outputs, encoder_fw_outputs), 2)
                encoder_final_state = []
                for i in range(self.layer_size):
                    c_fw, h_fw = encoder_fw_state[i]
                    c_bw, h_bw = encoder_bw_state[i]
                    c = tf.concat((c_fw, c_bw), axis=-1)
                    h = tf.concat((h_fw, h_bw), axis=-1)
                    encoder_final_state.append(LSTMStateTuple(c = c, h = h))
                encoder_final_state = tuple(encoder_final_state)
            else:
                encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
                    cell=encoder_cell,
                    inputs=encoder_inputs_embedded,
                    sequence_length=self.encoder_inputs_length,
                    dtype=tf.float32,
                    initial_state=initial_state,
                    swap_memory=True)

            return encoder_outputs, encoder_final_state

    def build_decoder_cell(self, encoder_outputs, encoder_final_state,
                            hidden_size, cell_type, layer_size):
        """
        构建解码器所有层
        :param encoder_outputs:
        :param encoder_state:
        :param hidden_size:
        :param cell_type:
        :param layer_size:
        :return:
        """
        sequence_length = self.encoder_inputs_length
        if self.mode == 'decode':
            encoder_outputs = tf.contrib.seq2seq.tile_batch(
                encoder_outputs, multiplier=self.beam_width)
            encoder_final_state = tf.contrib.seq2seq.tile_batch(
                encoder_final_state, multiplier=self.beam_width)
            sequence_length = tf.contrib.seq2seq.tile_batch(
                sequence_length, multiplier=self.beam_width)

        if self.bidirection:
            cell = MultiRNNCell([self.one_cell(hidden_size * 2, cell_type) for _ in range(layer_size)])
        else:
            cell = MultiRNNCell([self.one_cell(hidden_size, cell_type) for _ in range(layer_size)])
        # 使用attention机制
        self.attention_mechanism = BahdanauAttention(
            num_units=self.hidden_size,
            memory=encoder_outputs,
            memory_sequence_length=sequence_length
        )
        def cell_input_fn(inputs, attention):
            mul = 2 if self.bidirection else 1
            attn_projection = layers.Dense(self.hidden_size * mul,
                                           dtype=tf.float32,
                                           use_bias=False,
                                           name='attention_cell_input_fn')
            return attn_projection(array_ops.concat([inputs, attention], -1))
        cell = AttentionWrapper(
            cell=cell,
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_size,
            cell_input_fn=cell_input_fn,
            name='Attention_Wrapper'
        )
        if self.mode == 'decode':
            decoder_initial_state = cell.zero_state(batch_size=self.batch_size * self.beam_width, dtype=tf.float32).clone(
                cell_state=encoder_final_state)
        else:
            decoder_initial_state = cell.zero_state(batch_size=self.batch_size,
                                                    dtype=tf.float32).clone(
                cell_state=encoder_final_state)
        return cell, decoder_initial_state

    def build_decoder(self, encoder_outputs, encoder_final_state):
        """
        构建完整解码器
        :return:
        """
        with tf.variable_scope("decode"):
            decoder_cell, decoder_initial_state= self.build_decoder_cell(encoder_outputs, encoder_final_state,
                                                                         self.hidden_size, self.cell_type,
                                                                         self.layer_size)
            #输出层投影
            decoder_output_projection = layers.Dense(self.decoder_vocab_size,dtype=tf.float32,
                                                          use_bias=False,
                                                          kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1),
                                                          name='decoder_output_projection')
            if self.mode == 'train':
                # 训练模式
                decoder_inputs_embdedded = tf.nn.embedding_lookup(self.decoder_embeddings,self.decoder_inputs_train)
                training_helper = TrainingHelper(
                    inputs=decoder_inputs_embdedded,
                    sequence_length=self.decoder_inputs_length,
                    name='training_helper'
                )
                training_decoder = BasicDecoder(decoder_cell, training_helper,
                                                                   decoder_initial_state, decoder_output_projection)
                max_decoder_length = tf.reduce_max(self.decoder_inputs_length)
                training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                                  maximum_iterations=max_decoder_length)
                self.masks = tf.sequence_mask(self.decoder_inputs_length, maxlen = max_decoder_length, dtype=tf.float32, name='masks')
                self.loss = tf.contrib.seq2seq.sequence_loss(logits = training_decoder_output.rnn_output,
                                                             targets = self.decoder_inputs,
                                                             weights=self.masks,
                                                             average_across_timesteps=True,
                                                             average_across_batch=True
                                                             )
            else:
                # 预测模式
                start_token = [DataUnit.START_INDEX] * self.batch_size
                end_token = DataUnit.END_INDEX
                inference_decoder = BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=lambda x : tf.nn.embedding_lookup(self.decoder_embeddings, x),
                        start_tokens=start_token,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=self.beam_width,
                        output_layer=decoder_output_projection
                    )
                inference_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(inference_decoder,
                                                                                  maximum_iterations=self.max_decode_step)
                self.decoder_pred_decode = inference_decoder_output.predicted_ids
                self.decoder_pred_decode = tf.transpose(
                    self.decoder_pred_decode,
                    perm=[0, 2, 1]
                )

    def check_feeds(self, encoder_inputs, encoder_inputs_length,
                    decoder_inputs, decoder_inputs_length, keep_prob, decode):
        """
            检查输入,返回输入字典
        """
        input_batch_size = encoder_inputs.shape[0]
        assert input_batch_size == encoder_inputs_length.shape[0],'encoder_inputs 和 encoder_inputs_length的第一个维度必须一致'
        if not decode:
            target_batch_size = decoder_inputs.shape[0]
            assert target_batch_size == input_batch_size,'encoder_inputs 和 decoder_inputs的第一个维度必须一致'
            assert target_batch_size == decoder_inputs_length.shape[0],'decoder_inputs 和 decoder_inputs_length的第一个维度必须一致'
        input_feed = {self.encoder_inputs.name: encoder_inputs,
                      self.encoder_inputs_length.name: encoder_inputs_length}
        input_feed[self.keep_prob.name] = keep_prob
        if not decode:
            input_feed[self.decoder_inputs.name] = decoder_inputs
            input_feed[self.decoder_inputs_length.name] = decoder_inputs_length
        return input_feed

    def build_optimizer(self):
        """
        构建优化器
        :return:
        """
        learning_rate = tf.train.polynomial_decay(self.learning_rate,self.global_step,
                                        self.decay_step,self.min_learning_rate,power=0.5)
        self.current_learning_rate = learning_rate
        trainable_params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, trainable_params)
        # 优化器
        self.opt = tf.train.AdamOptimizer(
            learning_rate=learning_rate
        )
        # 梯度裁剪
        clip_gradients, _ = tf.clip_by_global_norm(
            gradients, self.max_gradient_norm
        )
        # 更新梯度
        self.update = self.opt.apply_gradients(
            zip(clip_gradients, trainable_params),
            global_step=self.global_step
        )

    def train(self, sess, encoder_inputs, encoder_inputs_length,
              decoder_inputs, decoder_inputs_length, keep_prob):
        """
        训练模型
        :param sess:
        :return:
        """
        input_feed = self.check_feeds(encoder_inputs, encoder_inputs_length,
                                      decoder_inputs, decoder_inputs_length,keep_prob,
                                      False)
        output_feed = [
            self.update, self.loss,
            self.current_learning_rate
        ]
        _, cost, lr = sess.run(output_feed, input_feed)
        return cost, lr

    def predict(self, sess, encoder_inputs, encoder_inputs_length):
        """
        预测
        :return:
        """
        input_feed = self.check_feeds(encoder_inputs, encoder_inputs_length,
                                      None, None, 1, True)
        pred = sess.run(self.decoder_pred_decode, input_feed)
        return pred[0]


    def save(self, sess, save_path='model.ckpt'):
        """
        保存模型
        :return:
        """
        self.saver.save(sess, save_path=save_path)

    def load(self, sess, save_path='model/chatbot_model.ckpt'):
        """
        加载模型
        """
        self.saver.restore(sess, save_path)
 
Last Modified: 2019-07-05 16:55
Views: 2.0K

[[total]] comments

Post your comment
  1. [[item.time]]
    [[item.user.username]] [[item.floor]]Floor
  2. Click to load more...
  3. Post your comment