机器学习入门之机器学习笔记-多类逻辑回归-使用gluon
小标 2018-12-18 来源 : 阅读 1343 评论 0

摘要:本文主要向大家介绍了机器学习入门之机器学习笔记-多类逻辑回归-使用gluon,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习入门之机器学习笔记-多类逻辑回归-使用gluon,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

这次使用gluon让代码更精减

from mxnet import gluon

from mxnet import ndarray as nd

import matplotlib.pyplot as plt

import mxnet as mx

from mxnet import autograd

   

def transform(data, label):

    return data.astype('float32')/255, label.astype('float32')

   

mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)

mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

   

def show_images(images):

    n = images.shape[0]

    _, figs = plt.subplots(1, n, figsize=(15, 15))

    for i in range(n):

        figs[i].imshow(images[i].reshape((28, 28)).asnumpy())

        figs[i].axes.get_xaxis().set_visible(False)

        figs[i].axes.get_yaxis().set_visible(False)

    plt.show()

 

def get_text_labels(label):

    text_labels = [

        'T 恤', '长 裤', '套头衫', '裙 子', '外 套',

        '凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'

    ]

    return [text_labels[int(i)] for i in label]

   

data, label = mnist_train[0:10]

   

print('example shape: ', data.shape, 'label:', label)

show_images(data)

print(get_text_labels(label))

   

batch_size = 256

train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)

test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)

   

#计算模型

net = gluon.nn.Sequential()

with net.name_scope():

    net.add(gluon.nn.Flatten())

    net.add(gluon.nn.Dense(256, activation="relu"))

    net.add(gluon.nn.Dense(10))

net.initialize()

   

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

 

#定义训练器

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})

  

def accuracy(output, label):

    return nd.mean(output.argmax(axis=1) == label).asscalar()

   

def _get_batch(batch):

    if isinstance(batch, mx.io.DataBatch):

        data = batch.data[0]

        label = batch.label[0]

    else:

        data, label = batch

    return data, label

   

def evaluate_accuracy(data_iterator, net):

    acc = 0.

    if isinstance(data_iterator, mx.io.MXDataIter):

        data_iterator.reset()

    for i, batch in enumerate(data_iterator):

        data, label = _get_batch(batch)

        output = net(data)

        acc += accuracy(output, label)

    return acc / (i+1)

   

for epoch in range(5):

    train_loss = 0.

    train_acc = 0.

    for data, label in train_data:

        with autograd.record():

            output = net(data)

            loss = softmax_cross_entropy(output, label)

        loss.backward()

        trainer.step(batch_size) #使用训练器,向"前"走一步

 

        train_loss += nd.mean(loss).asscalar()

        train_acc += accuracy(output, label)

 

    test_acc = evaluate_accuracy(test_data, net)

    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (

        epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))

 

data, label = mnist_test[0:10]

show_images(data)

print('true labels')

print(get_text_labels(label))

   

predicted_labels = net(data).argmax(axis=1)

print('predicted labels')

print(get_text_labels(predicted_labels.asnumpy()))

 有变化的地方,已经加上了注释。运行效果,跟一篇完全相同,就不重复贴图了

本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式AI+学习就业服务平台 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved