机器学习之机器学习:如何找到最优学习率
小标 2018-11-22 来源 : 阅读 1467 评论 0

摘要:本文主要向大家介绍了机器学习之机器学习:如何找到最优学习率,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习之机器学习:如何找到最优学习率,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。


 学习率的重要性


目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

[Math Processing Error]w:=w−α∂∂wloss(w)

这里[Math Processing Error]α就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。


学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

 一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。


从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。


 实现


上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重。

下面就是部分代码,这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码如下:

class ScheduledOptim(object):
    '''A wrapper class for learning rate scheduling'''

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.lr = self.optimizer.param_groups[0]['lr']
        self.current_steps = 0

    def step(self):
        "Step by the inner optimizer"
        self.current_steps += 1
        self.optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self.optimizer.zero_grad()

    def set_learning_rate(self, lr):
        self.lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    @property
    def learning_rate(self):
        return self.lr

def find_lr():
    pass

def train_model():
    pass


 整体代码如下

criterion = torch.nn.CrossEntropyLoss()  

net = model_zoo.resnet50(pretrained=True)  

net.fc = nn.Linear(2048, 120)  

with torch.cuda.device(0):  

    net = net.cuda()  

basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)  

optimizer = ScheduledOptim(basic_optim)  

lr_mult = (1 / 1e-5) ** (1 / 100)  

lr = []  

losses = []  

best_loss = 1e9  

for data, label in train_data:  

    with torch.cuda.device(0):  

        data = Variable(data.cuda())  

        label = Variable(label.cuda())  

    # forward  

    out = net(data)  

    loss = criterion(out, label)  

    # backward  

    optimizer.zero_grad()  

    loss.backward()  

    optimizer.step()  

    lr.append(optimizer.learning_rate)  

    losses.append(loss.data[0])  

    optimizer.set_learning_rate(optimizer.learning_rate * lr_mult)  

    if loss.data[0] < best_loss:  

        best_loss = loss.data[0]  

    if loss.data[0] > 4 * best_loss or optimizer.learning_rate > 1.:  

        break  

plt.figure()  

plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))  

plt.xlabel('learning rate')  

plt.ylabel('loss')  

plt.plot(np.log(lr), losses)  

plt.show()  

plt.figure()  

plt.xlabel('num iterations')  

plt.ylabel('learning rate')  

plt.plot(lr) 


本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!


本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式AI+学习就业服务平台 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved