
A Complete Keras Implementation of a Transformer for Text Classification


Background

At the time of writing, the top-ranked Keras implementation that comes up when searching CSDN is:

/xiaosongshine/article/details/86595847

However, that article has quite a few problems. Its implementation is actually incomplete (the LayerNorm + residual part is missing), and its multi-head attention lacks the final $W_o$ projection (one more fully connected mapping). The hyperparameters it uses also differ a lot from the original paper, so it does not really correspond to the encoder block described there; if you try to read that code against the paper, it can be somewhat misleading. So I gathered the missing pieces from various places and put together a reasonably complete Keras implementation of the transformer encoder that basically matches the paper. Along the way, I also revisit the principles of the transformer together with the code (I will try to keep the comments clear).

Keras version

To stay compatible with the code found on CSDN, this post uses Keras 2.2.4 (standalone Keras, not tf.keras). If you need a newer version or a tf.keras variant, some changes will probably be required; CyberZHG's code on GitHub is a good reference for making them.
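As a quick sanity check (a minimal sketch, not part of the original post), you can confirm which Keras build you are running:

import keras
print(keras.__version__)  # this post assumes 2.2.4, standalone Keras rather than tf.keras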

Main reference links

Main references for the theory:

/p/44121378
/p/44731789
/u012526436/article/details/86295971
Original paper: /pdf/1706.03762.pdf

Main references for the code:

/CyberZHG/keras-transformer
/xiaosongshine/article/details/86595847
/qq_40742298/article/details/115011147

Overall model structure

Since we are doing text classification, we only discuss the encoder on the left side of the architecture diagram.

The encoder starts with the input + embedding part, followed by an encoding stack of N blocks; in the original paper N is 6. Each block in turn consists of multi-head attention, add & norm (residual connection + LayerNorm), and a feed-forward sub-layer. We will break these down one step at a time.

Input layer

The original input is word embeddings + position embeddings, just like ordinary text input. Assume the input shape is (batch_size, seq_len, embedding_size). Note that for the residual connections later on to work, this embedding_size must stay the same throughout the network; in the original paper, the embedding size and the dimensions of all sub-layers are the same, 512, denoted $d_{model} = 512$.
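As a minimal sketch of the input stage (the vocabulary size and sequence length below are placeholder values, not from the original post), every tensor keeps its last dimension at $d_{model} = 512$:

from keras.layers import Input, Embedding

seq_len, vocab_size = 512, 20000                 # placeholder values
words = Input(shape=(seq_len,), dtype='int32')   # (batch_size, seq_len) of token ids
embeddings = Embedding(vocab_size, 512)(words)   # (batch_size, seq_len, 512) = (bs, seq_len, d_model)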

Position embedding layer

Unlike an RNN, a transformer has no built-in notion of word order, so to preserve position information, the positions are first passed through a position embedding and then summed with the word embeddings to form the input of the subsequent blocks. Note that the original "Attention is All You Need" paper mentions both the sin/cos formulation and a learned position embedding; the authors found no difference between the two in their experiments and used the sin/cos version, whereas the position embeddings in BERT are learned.

Without belaboring the derivation, the encoding can be written roughly as follows:
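Restated from the original paper (pos is the token position, i indexes the embedding dimension, and d_model is the model dimension):

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)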

The concrete code implementation, with comments, is below:

#! -*- coding: utf-8 -*-
from __future__ import print_function
import tensorflow as tf  # tf.newaxis below assumes the TensorFlow backend
from keras import backend as K
from keras.engine.topology import Layer


class Position_Embedding(Layer):
    def __init__(self, size=None, mode='sum', **kwargs):
        self.size = size  # must be even
        self.mode = mode
        super(Position_Embedding, self).__init__(**kwargs)

    def call(self, x):
        # The previous layer is usually the embedding layer: (batch_size, seq_len, model_dim).
        if (self.size is None) or (self.mode == 'sum'):
            self.size = int(x.shape[-1])  # d_model, e.g. 512
        batch_size, seq_len = K.shape(x)[0], K.shape(x)[1]
        # K.arange(self.size / 2) generates 0, 1, ..., size/2 - 1, i.e. i in the formula.
        # 2 * K.arange(self.size / 2) gives 2i = 0, 2, 4, ..., size - 2.
        # Dividing by d_model and raising 10000 to that power gives 1 / 10000^(2i/d_model).
        position_j = 1. / K.pow(10000., 2 * K.arange(self.size / 2, dtype='float32') / self.size)
        position_j = K.expand_dims(position_j, 0)  # (1, size/2), e.g. (1, 256)
        # Generate the position indices:
        # x[:, :, 0] takes the first component of each embedding -> (bs, seq_len)
        # ones_like  -> (bs, seq_len) of ones
        # cumsum     -> [[1, 2, 3, ...], ...]
        # cumsum - 1 -> [[0, 1, 2, ...], ...]
        position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1  # K.arange does not support variable length
        position_i = K.expand_dims(position_i, 2)  # (bs, seq_len, 1)
        position_ij = K.dot(position_i, position_j)  # (bs, seq_len, size/2) = pos / 10000^(2i/d_model)
        # The original implementation concatenated the cos and sin halves directly;
        # they should be interleaved instead.
        position_ij_2i = K.sin(position_ij)[..., tf.newaxis]    # (bs, seq_len, model_dim/2, 1)
        position_ij_2i_1 = K.cos(position_ij)[..., tf.newaxis]  # (bs, seq_len, model_dim/2, 1)
        position_ij = K.concatenate([position_ij_2i, position_ij_2i_1])  # (bs, seq_len, model_dim/2, 2)
        position_ij = K.reshape(position_ij, (batch_size, seq_len, self.size))  # (bs, seq_len, model_dim)
        # position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)
        # The line above does not interleave: the first half would be all cos and the second half all sin.
        if self.mode == 'sum':
            return position_ij + x
        elif self.mode == 'concat':
            return K.concatenate([position_ij, x], 2)

    def compute_output_shape(self, input_shape):
        if self.mode == 'sum':
            return input_shape
        elif self.mode == 'concat':
            return (input_shape[0], input_shape[1], input_shape[2] + self.size)
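A minimal usage sketch, continuing the tensors from the input-layer sketch above (mode='sum' adds the positional encoding to the word embeddings):

embeddings = Position_Embedding(mode='sum')(embeddings)  # still (batch_size, seq_len, 512)
# or, to concatenate instead of summing:
# embeddings = Position_Embedding(size=512, mode='concat')(embeddings)  # (batch_size, seq_len, 1024)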

Implementing each part of a single block

multi-head attention

First, we need to implement a single attention head. If you do not want to implement attention head by head, you can refer to the attention layer in /xiaosongshine/article/details/86595847, which computes multiple heads at once, but a $W_o$ projection has to be added for it to match the paper exactly. To stay consistent with the paper and keep the breakdown clearer, we implement a single attention first.

scaled dot attention

Let's look at the scaled dot attention computation:

(1) Define three matrices $W_q$, $W_k$, $W_v$; multiplying the input by each of them gives Q, K, V.
(2) Take the dot product of Q and K to get the scores, and turn them into softmax weights.
(3) Multiply the weights by the V matrix to get the final weighted V matrix (the H matrix).
(4) Importantly, the scores are divided by $\sqrt{d_k}$ before the softmax; see /qq_37430422/article/details/105042303 for the reason.
The formula and a small NumPy sketch of this computation follow below.
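In formula form, with a minimal NumPy sketch (shapes only, no masking; the function name is just for illustration):

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V

import numpy as np

def scaled_dot_attention_np(Q, K_, V):
    # Q, K_, V: (batch, seq_len, d_k)
    d_k = Q.shape[-1]
    scores = Q @ K_.transpose(0, 2, 1) / np.sqrt(d_k)  # (batch, seq_len, seq_len)
    scores -= scores.max(axis=-1, keepdims=True)       # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                                  # (batch, seq_len, d_k)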

Code implementation:

class ScaledDotProductAttention(Layer):
    r"""The attention layer that takes three inputs representing queries, keys and values.

    \text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}}) V

    See: /pdf/1706.03762.pdf
    """

    def __init__(self, return_attention=False, history_only=False, **kwargs):
        """Initialize the layer.

        :param return_attention: Whether to return attention weights.
        :param history_only: Whether to only use history data.
        :param kwargs: Arguments for parent class.
        """
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_attention = return_attention
        self.history_only = history_only
        self.intensity = self.attention = None

    def get_config(self):
        config = {
            'return_attention': self.return_attention,
            'history_only': self.history_only,
        }
        base_config = super(ScaledDotProductAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            query_shape, key_shape, value_shape = input_shape
        else:
            query_shape = key_shape = value_shape = input_shape
        output_shape = query_shape[:-1] + value_shape[-1:]
        if self.return_attention:
            attention_shape = query_shape[:2] + (key_shape[1],)
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    def call(self, inputs, mask=None, **kwargs):
        if isinstance(inputs, list):
            query, key, value = inputs
        else:
            query = key = value = inputs
        if isinstance(mask, list):
            mask = mask[1]
        feature_dim = K.shape(query)[-1]  # e.g. 512
        # query = (bs, seq_len, dim), key = (bs, seq_len, dim)
        # after batch_dot: (bs, seq_len, seq_len)
        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))
        if self.history_only:
            query_len, key_len = K.shape(query)[1], K.shape(key)[1]
            indices = K.expand_dims(K.arange(0, key_len), axis=0)
            upper = K.expand_dims(K.arange(0, query_len), axis=-1)
            e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0)
        if mask is not None:
            e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx()))
        self.intensity = e
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        self.attention = e / K.sum(e, axis=-1, keepdims=True)
        # self.attention: (bs, seq_len, seq_len); value: (bs, seq_len, dim); v: (bs, seq_len, dim)
        v = K.batch_dot(self.attention, value)
        if self.return_attention:
            return [v, self.attention]
        return v

multi-head attention

This part is comparatively simple: project Q, K, V once, split them into num_head chunks, run each chunk through the scaled dot attention implemented above, merge the results, and apply one more projection. Taking Q as an example, the steps are:

(1) Assume Q (bs=1, seq_len=10, dim=512) has already passed through a projection layer, giving Q_.

(2) Obtain K_ in the same way, and compute the dot attention matrix between Q_ and K_.

(3) Obtain V_ in the same way, and take the weighted sum to get the outputs.

(4) Reshape back.

(5) Finally, pass the result through $W_o$ one more time. (A NumPy shape walkthrough of the head splitting follows below.)
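To make the head splitting concrete, here is a hedged NumPy sketch of the reshape-to-batches trick used by the layer below (bs=1, seq_len=10, dim=512, 8 heads, so head_dim=64):

import numpy as np

bs, seq_len, dim, head_num = 1, 10, 512, 8
head_dim = dim // head_num                              # 64

Q_ = np.random.rand(bs, seq_len, dim)                   # Q after the Wq projection
x = Q_.reshape(bs, seq_len, head_num, head_dim)         # (1, 10, 8, 64)
x = x.transpose(0, 2, 1, 3)                             # (1, 8, 10, 64): heads next to the batch axis
x = x.reshape(bs * head_num, seq_len, head_dim)         # (8, 10, 64): heads folded into the batch axis
# ... scaled dot attention runs independently per "batch" entry, i.e. per head ...
y = x.reshape(bs, head_num, seq_len, head_dim).transpose(0, 2, 1, 3).reshape(bs, seq_len, dim)  # back to (1, 10, 512)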

Code implementation:

import keras


class MultiHeadAttention(Layer):
    """Multi-head attention layer.

    See: /pdf/1706.03762.pdf
    """

    def __init__(self,
                 head_num,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.

        :param head_num: Number of heads.
        :param activation: Activations for linear mappings.
        :param use_bias: Whether to use bias term.
        :param kernel_initializer: Initializer for linear mappings.
        :param bias_initializer: Initializer for linear mappings.
        :param kernel_regularizer: Regularizer for linear mappings.
        :param bias_regularizer: Regularizer for linear mappings.
        :param kernel_constraint: Constraints for linear mappings.
        :param bias_constraint: Constraints for linear mappings.
        :param history_only: Whether to only use history in attention layer.
        """
        self.supports_masking = True
        self.head_num = head_num
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.history_only = history_only
        self.Wq = self.Wk = self.Wv = self.Wo = None
        self.bq = self.bk = self.bv = self.bo = None
        self.intensity = self.attention = None
        super(MultiHeadAttention, self).__init__(**kwargs)

    def get_config(self):
        config = {
            'head_num': self.head_num,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'history_only': self.history_only,
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
            return q[:-1] + (v[-1],)
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        if isinstance(input_mask, list):
            return input_mask[0]
        return input_mask

    def build(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        feature_dim = int(v[-1])
        if feature_dim % self.head_num != 0:
            raise IndexError('Invalid head number %d with the given input dim %d' % (self.head_num, feature_dim))
        self.Wq = self.add_weight(
            shape=(int(q[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wq' % self.name,
        )
        if self.use_bias:
            self.bq = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bq' % self.name,
            )
        self.Wk = self.add_weight(
            shape=(int(k[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wk' % self.name,
        )
        if self.use_bias:
            self.bk = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bk' % self.name,
            )
        self.Wv = self.add_weight(
            shape=(int(v[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wv' % self.name,
        )
        if self.use_bias:
            self.bv = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bv' % self.name,
            )
        self.Wo = self.add_weight(
            shape=(feature_dim, feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wo' % self.name,
        )
        if self.use_bias:
            self.bo = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bo' % self.name,
            )
        super(MultiHeadAttention, self).build(input_shape)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        # Split into head_num heads.
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        # To feed scaled dot attention with (bs, seq_len, head_dim), transpose so the head
        # dimension is next to the batch axis, then fold it into the batch axis (the batch
        # axis does not take part in the attention computation).
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))

    @staticmethod
    def _reshape_attention_from_batches(x, head_num):
        # Recover the attention score matrices from the folded batch axis.
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        return K.permute_dimensions(x, [0, 2, 1, 3])

    @staticmethod
    def _reshape_from_batches(x, head_num):
        # Recover the attended vectors: (bs*head_num, seq_len, head_dim) -> (bs, seq_len, model_dim).
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))

    @staticmethod
    def _reshape_mask(mask, head_num):
        if mask is None:
            return mask
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))

    def call(self, inputs, mask=None):
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs  # (bs, seq_len, model_dim)
        if isinstance(mask, list):
            q_mask, k_mask, v_mask = mask
        else:
            q_mask = k_mask = v_mask = mask
        # Projecting first and then splitting into 8 heads uses the same number of
        # parameters (512 * 512) as splitting into 8 * 64 chunks first and projecting each.
        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)
        scaled_dot_product_attention = ScaledDotProductAttention(
            history_only=self.history_only,
            name='%s-Attention' % self.name,
        )
        y = scaled_dot_product_attention(
            inputs=[
                self._reshape_to_batches(q, self.head_num),  # query: (bs*head_num, seq_len, head_dim)
                self._reshape_to_batches(k, self.head_num),  # key
                self._reshape_to_batches(v, self.head_num),  # value
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
        )
        # Similarity matrices (optional, for inspection):
        # self.intensity = self._reshape_attention_from_batches(scaled_dot_product_attention.intensity, self.head_num)
        # self.attention = self._reshape_attention_from_batches(scaled_dot_product_attention.attention, self.head_num)
        y = self._reshape_from_batches(y, self.head_num)  # merge the heads
        y = K.dot(y, self.Wo)  # final output projection
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)
        # Add shape information to the tensor.
        input_shape = [K.int_shape(q), K.int_shape(k), K.int_shape(v)]
        output_shape = self.compute_output_shape(input_shape)
        if output_shape[1] is not None:
            output_shape = (-1,) + output_shape[1:]
            y = K.reshape(y, output_shape)
        return y

LayerNorm
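Layer normalization normalizes each position's feature vector across the model dimension (rather than across the batch, as batch normalization does). For a feature vector x of length d_model with mean \mu and variance \sigma^2 over its features:

\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

where \gamma and \beta are optional scale and shift parameters; the implementation below keeps them as constants rather than trainable weights.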

Code:

class LayerNorm(Layer):
    def __init__(self,
                 center=True,
                 scale=False,
                 epsilon=None,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 gamma_constraint=None,
                 beta_constraint=None,
                 **kwargs):
        super(LayerNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center
        self.scale = scale
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        # Note: gamma and beta are kept as constants here rather than trainable weights,
        # so this layer only normalizes; with the defaults scale=False / center=True the
        # affine part is effectively a no-op.
        self.gamma, self.beta = 0., 0.

    def call(self, inputs, **kwargs):
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs

Adding the residual Add and the FFN: a complete transformer block
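In the post-LN arrangement used here, which matches the original paper, each block computes

X_1 = \text{LayerNorm}(X + \text{MultiHead}(X)), \qquad X_2 = \text{LayerNorm}(X_1 + \text{FFN}(X_1))

with \text{FFN}(x) = \max(0,\, x W_1 + b_1) W_2 + b_2, where the inner dimension is 4 \times d_{model} as in the paper.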

from keras.layers import Add, Dense


def transformer_block(x, prefix):
    # model_dim is assumed to be defined globally (e.g. model_dim = 512), matching d_model.
    O_seq = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x)  # (bs, words_len, dim)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x, O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1)  # X = LayerNorm(X + multihead(X))
    O_seq_fc1 = Dense(model_dim * 4, activation='relu', name=f'{prefix}_fc1')(O_seq_LN1)  # FFN
    O_seq_fc2 = Dense(model_dim, name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1, O_seq_fc2])
    # O_seq_Add2 = add([O_seq_LN1, O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)
    return O_seq_LN2

Full model definition

import numpy as np
from tqdm import tqdm
import keras
from keras.layers import Input, Embedding, Dropout, Dense, Add, GlobalAveragePooling1D
from keras.models import Model

MAX_LEN = 512
MODEL_DIM = 512


def load_word_embedding(filepath):
    embeddings_index = {}
    f = open(filepath, encoding='utf8')
    for line in tqdm(f):
        values = line.split()
        word = ''.join(values[:-MODEL_DIM])
        coefs = np.asarray(values[-MODEL_DIM:], dtype='float32')
        embeddings_index[word] = coefs
    f.close()
    return embeddings_index


def build_matrix(word_index, path):
    embedding_index = load_word_embedding(path)
    embedding_matrix = np.zeros((len(word_index) + 1, MODEL_DIM))
    for word, i in word_index.items():
        if word in embedding_index:
            embedding_matrix[i] = embedding_index[word]
    return embedding_matrix


def transformer_block(x, prefix):
    O_seq = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x)  # (bs, words_len, dim)
    O_seq = Dropout(0.1, name=f'{prefix}_do1')(O_seq)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x, O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1)  # X = LayerNorm(X + multihead(X))
    O_seq_fc1 = Dense(MODEL_DIM * 4, activation='relu', name=f'{prefix}_fc1')(O_seq_LN1)  # FFN
    O_seq_fc2 = Dense(MODEL_DIM, name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_fc2 = Dropout(0.1, name=f'{prefix}_do2')(O_seq_fc2)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1, O_seq_fc2])
    # O_seq_Add2 = add([O_seq_LN1, O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)
    return O_seq_LN2


def build_model(embedding_matrix, num_class=2):
    words = Input(shape=(MAX_LEN,), name='inputs', dtype='int32')
    embeddings = Embedding(*embedding_matrix.shape, weights=[embedding_matrix], trainable=True)(words)
    embeddings = Position_Embedding()(embeddings)  # adding Position_Embedding slightly improves accuracy
    embeddings = Dropout(0.1)(embeddings)
    seq_len = K.shape(words)[1]
    # model_dim = K.int_shape(embeddings)[-1]
    O_seq1 = transformer_block(embeddings, prefix='t1')
    O_seq2 = transformer_block(O_seq1, prefix='t2')
    O_seq3 = transformer_block(O_seq2, prefix='t3')
    O_seq4 = transformer_block(O_seq3, prefix='t4')
    O_seq5 = transformer_block(O_seq4, prefix='t5')
    O_seq6 = transformer_block(O_seq5, prefix='t6')
    # O_seq7 = transformer_block(O_seq6, prefix='t7')
    # O_seq8 = transformer_block(O_seq7, prefix='t8')
    O_seq = Add()([O_seq4, O_seq5, O_seq6])  # everything after the encoder stack is a free design choice
    O_seq = GlobalAveragePooling1D()(O_seq)
    O_seq = Dropout(0.1)(O_seq)
    # The original paper used warm-up for the learning rate; it is omitted here.
    result = Dense(num_class, activation='softmax', name='outputs')(O_seq)
    model = Model(inputs=words, outputs=result)
    opt = keras.optimizers.Adam(lr=5e-5)
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc'])
    model.summary()
    return model
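A hedged usage sketch: `tokenizer`, the embedding file path, and the padded integer matrices `X_train` / `X_val` are hypothetical placeholders that depend on your own preprocessing, not part of the original post:

# hypothetical: a fitted keras Tokenizer and integer sequences padded to MAX_LEN
embedding_matrix = build_matrix(tokenizer.word_index, 'path/to/embeddings.txt')
model = build_model(embedding_matrix, num_class=2)
model.fit(X_train, y_train,                     # y_* are one-hot encoded labels
          validation_data=(X_val, y_val),
          batch_size=32, epochs=5)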

Side notes

If you train the model with only the code above, you may find that it converges with great difficulty, because there is no learning rate warm-up, and that actually matters a lot. If the model does not converge, you can try moving LayerNorm before the attention and FFN sub-layers (the pre-LN variant, sketched below), lowering the learning rate first (to 5e-5 or below), and adding a warm-up schedule.
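As a sketch of the pre-LN variant mentioned above (normalize before attention and FFN), reusing the layers defined earlier; this is one common arrangement, not the post's original code:

def transformer_block_pre_ln(x, prefix):
    # Pre-LN: LayerNorm first, then the sub-layer, then the residual add.
    x_norm1 = LayerNorm(name=f'{prefix}_LN1')(x)
    att = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x_norm1)
    x = Add(name=f'{prefix}_add1')([x, att])
    x_norm2 = LayerNorm(name=f'{prefix}_LN2')(x)
    ffn = Dense(MODEL_DIM * 4, activation='relu', name=f'{prefix}_fc1')(x_norm2)
    ffn = Dense(MODEL_DIM, name=f'{prefix}_fc2')(ffn)
    return Add(name=f'{prefix}_add2')([x, ffn])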

Reference: /p/84614490

A Keras implementation of warm-up is attached below, from:

/yangyin/keras_classfication/blob/master/warmup_cosine_decay_scheduler.py

You can adapt it to your own needs:

import numpy as np
from tensorflow import keras
from keras import backend as K


# Cosine learning rate schedule with warm-up.
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR. /abs/1608.03983
    In this schedule, the learning rate grows linearly from warmup_learning_rate
    to learning_rate_base for warmup_steps, then transitions to a cosine decay
    schedule.

    Arguments:
        global_step {int} -- global step.
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
            before decaying. (default: {0})

    Returns:
        a float representing learning rate.

    Raises:
        ValueError: if warmup_learning_rate is larger than learning_rate_base,
            or if warmup_steps is larger than total_steps.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to warmup_steps.')
    learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
        np.pi * (global_step - warmup_steps - hold_base_rate_steps) /
        float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)


class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Cosine decay with warmup learning rate scheduler."""

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """Constructor for cosine decay with warmup learning rate scheduler.

        Arguments:
            learning_rate_base {float} -- base learning rate.
            total_steps {int} -- total number of training steps.

        Keyword Arguments:
            global_step_init {int} -- initial global step, e.g. from previous checkpoint.
            warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
            warmup_steps {int} -- number of warmup steps. (default: {0})
            hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                before decaying. (default: {0})
            verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """
        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = cosine_decay_with_warmup(global_step=self.global_step,
                                      learning_rate_base=self.learning_rate_base,
                                      total_steps=self.total_steps,
                                      warmup_learning_rate=self.warmup_learning_rate,
                                      warmup_steps=self.warmup_steps,
                                      hold_base_rate_steps=self.hold_base_rate_steps)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning rate to %s.' % (self.global_step + 1, lr))


if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import Dense

    # Create a model.
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Number of training samples.
    sample_count = 12608
    # Total epochs to train.
    epochs = 50
    # Number of warmup epochs.
    warmup_epoch = 10
    # Training batch size, set small value here for demonstration purpose.
    batch_size = 16
    # Base learning rate after warmup.
    learning_rate_base = 0.0001
    total_steps = int(epochs * sample_count / batch_size)
    # Compute the number of warmup batches.
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    # Generate dummy data.
    data = np.random.random((sample_count, 100))
    labels = np.random.randint(10, size=(sample_count, 1))
    # Convert labels to categorical one-hot encoding.
    one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

    # Compute the number of warmup batches.
    warmup_batches = warmup_epoch * sample_count / batch_size

    # Create the learning rate scheduler.
    warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                            total_steps=total_steps,
                                            warmup_learning_rate=4e-06,
                                            warmup_steps=warmup_steps,
                                            hold_base_rate_steps=5)

    # Train the model, iterating on the data in batches.
    model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
              verbose=0, callbacks=[warm_up_lr])

    import matplotlib.pyplot as plt
    plt.plot(warm_up_lr.learning_rates)
    plt.xlabel('Step', fontsize=20)
    plt.ylabel('lr', fontsize=20)
    plt.axis([0, total_steps, 0, learning_rate_base * 1.1])
    plt.xticks(np.arange(0, epochs, 1))
    plt.grid()
    plt.title('Cosine decay with warmup', fontsize=20)
    plt.show()
