我们在一个深度模型的训练中经常会用到回调函数来对训练过程进行监测,使得训练过程更加智能化。

例如,我们经常使用的早停机制:

from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', 
                           mode='min', 
                           patience=10, 
                           restore_best_weights=True, 
                           verbose=1)

通过监测验证误差的变化趋势,我们可以在验证误差不再增长的时候提前结束训练。

另一个与 EarlyStopping 常常配合使用的是 ReduceLROnPlateau,当指定的训练误差或者验证误差在指定的轮次以内不再增长的时候,我们将学习率根据设置的衰减系数 factor 自动降低:

from keras.callbacks import ReduceLROnPlateau

learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', 
                                            patience=3, 
                                            verbose=1, 
                                            factor=0.5, 
                                            min_lr=0.00001)

keras 的回调 API 包含许多不同功能用途的回调函数,通常这些回调就可以满足我们的需求了。但如果我们想要更加精细的控制训练过程,可能需要写一个自己的回调。

我们接下来就实现一个,可以在训练过程中自由控制训练轮次和学习率的回调。这个回调的功能主要是:

  • 在指定轮次之后询问使用者,是否继续训练,如果继续训练,键入继续训练的轮次,并选择保持或者改变当前学习率
  • 如果验证误差增加,则自动调整学习率,且模型加载当前最优的权重
  • 训练结束后,直接让模型加载最优权重

我们需要定义一个类,这个类继承 keras.callbacks.Callback,然后做一些初始化:

class My_ASK(keras.callbacks.Callback):
    def __init__(self, model, epochs, ask_epoch, dwell=True, factor=.4):
        super(My_ASK, self).__init__()
        self.model = model
        
        """
			模型在训练 ask_epoch 之后,会让使用者选择是暂停训练还是继续训练,
			如果继续训练,则直接输入一个整数,表明继续训练的轮次,且会给我们
			修改学习率的机会
		"""
        
        self.ask_epoch = ask_epoch
        
        self.epochs = epochs
        self.ask = True # 将 ask 设为 True 才会有上面 ask_epoch 描述的询问
        self.lowest_vloss = np.inf
        self.lowest_loss = np.inf
        self.best_weights = self.model.get_weights() # 最优权重初始化为模型的初始权重
        self.best_epoch = 1
        self.vlist = [] # 存储验证误差变化的列表
        self.tlist = [] # 存储训练误差变化的列表
        self.dwell = dwell
        self.factor = factor # 学习率衰减系数
        

通常一个回调中的方法有 on_train_begin, on_train_end, on_epoch_end, on_epoch_begin 等,它们并不是需要全部定义,我们可以根据自己的实际需求进行选择。我们定义的这个类就只使用了 on_train_begin, on_train_end, on_epoch_end 三种方法。我们来看看这三种方法都具体做了些什么。

训练开始时,会给我们报告一些参数设置的情况,提示我们模型的训练流程,同时启动计时器。

    def on_train_begin(self, logs = None):
        if self.ask_epoch == 0:
            print('You set ask_epoch = 0, ask_epoch will be set to 1', flush = True)
            self.ask_epoch = 1
        if self.ask_epoch >= self.epochs: # 如果设置的 ask_epoch 比 epochs 还大,那就没有意义了
            print('ask_epoch >= epochs, will train for ', epochs, ' epochs', flush=True)
            self.ask = False
        if self.epochs == 1:
            self.ask = False
        else:
             
            print(f'Training will proceed until epoch {ask_epoch} then you will be asked to')
            print('enter H to halt training or enter an integer for how many more epochs to run then be asked again')
            
            if self.dwell:
                print('\n Learning rate will be automatically adjusted during training')
                
        self.start_time = time.time() # 开始计时

训练结束后,模型会加载最优权重,并返回训练的总时间。

    def on_train_end(self, logs=None):
        print(f'Loading model with weights from epoch {self.best_epoch}')
        
        self.model.set_weights(self.best_weights)
        train_duration = time.time() - self.start_time
        hours = train_duration // 3600
        minutes = (train_duration - hours * 3600) // 60
        seconds = train_duration - hours * 3600 - minutes * 60

        print(f'Training using {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds')

可以看到,训练开始和训练结束的方法内容非常简单,如果不考虑可读性,那么省略不写也不会有太大影响。重点是下面的 on_epoch_end 方法。注释以及代码打印的内容已经很详细了,一行一行看下去肯定是没有问题的,这里不再过多解释。

    def on_epoch_end(self, epoch, logs=None):
        val_loss = logs.get('val_loss')
        loss = logs.get('loss')
        if epoch > 0:
            delta_v = self.lowest_vloss - val_loss # 该轮次的验证损失和最低验证损失的差值
            vimprov = (delta_v / self.lowest_vloss) * 100 # percentage of improvement,当然也有可能是负数,表示误差增高了
            self.vlist.append(vimprov)
            
            delta_t = self.lowest_loss - loss
            timprov = (delta_t / self.lowest_loss) * 100
            self.tlist.append(timprov)
        else:
            vimprov = 0.0
            timprov = 0.0
        
        if val_loss < self.lowest_vloss:
            self.lowest_vloss = val_loss # 更新最低验证误差
            self.best_weights = self.model.get_weights() # 以及相应的权重
            self.best_epoch = epoch + 1
            print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % below lowest loss, saving weights from epoch {str(epoch + 1):3s} as best weights')
        else:
            vimprov = abs(vimprov)
            print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % above lowest loss of {self.lowest_vloss:7.4f}. Keeping weights from epoch {str(self.best_epoch)} as best weights')
            
            if self.dwell:
                lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
                new_lr = lr * self.factor
                print(f'\n Learning rate was automatically adjusted from {lr:8.6f} to {new_lr:8.6f}, model weights set to best weights')
                
                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                self.model.set_weights(self.best_weights) # 在新的学习率基础上,看模型在最优权重上表现如何
        
        if loss < self.lowest_loss:
            self.lowest_loss = loss
            
        if self.ask:
            if epoch + 1 == self.ask_epoch:
                print('\n Enter H to end training or an integer for the number of additional epochs to run then ask again')
                ans = input()
                
                if ans == 'H' or ans == 'h' or ans == '0': # 放弃训练
                    self.model.stop_training = True
                else:
                    self.ask_epoch += int(ans) # 在第 ask_epoch+ans 轮次再次询问
                    if self.ask_epoch > self.epochs:
                        print('\n Your specification exceeds ', self.epochs, ' cannot train for ', self.ask_epoch, flush =True)
                    else:
                        print(f'\n You entered {ans}. Training will continue to epoch {self.ask_epoch}')
                        
                        if self.dwell == False:
                            lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) 
                            print(f'\n Current LR is  {lr:8.6f}  hit enter to keep  this LR or enter a new LR')
                            
                            ans = input(' ')
                            if ans == '':
                                print(f'\n Keeping current LR of {lr:7.5f}')
                                
                            else:
                                new_lr = float(ans)
                                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                                print(f'\n Changing LR to {ans}')

事实上,这个回调实现的功能与 keras 本身含有的回调可能有相似部分,但重点在于理解一个 callback 的自定义过程。

最后,我们实例化这个回调,并添加到回调列表中。

epochs = 50
ask_epoch = 10
ask = My_ASK(model, epochs, ask_epoch)
callbacks = [ask]
11-27 10:40