机器学习入门之机器学习~用于机器学习中的分类边界、决策树等可视化的模块
小标 2018-11-20 来源 : 阅读 1897 评论 0

摘要:本文主要向大家介绍了机器学习入门之机器学习~用于机器学习中的分类边界、决策树等可视化的模块,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习入门之机器学习~用于机器学习中的分类边界、决策树等可视化的模块,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import graphviz
from sklearn.tree import export_graphviz
import matplotlib.patches as mpatches


def plot_decision_tree(clf, feature_names, class_names):
   """
       决策树结果可视化
       需要安装
       1. graphviz程序(已提供在代码目录下),并将安装目录下的bin目录添加到环境变量中,重启jupyter或系统生效
          如:C:\Program Files (x86)\Graphviz2.38\bin 添加到系统PATH环境变量中
       2. graphviz模块, pip install graphviz
   """

   tmp_dot_file = 'decision_tree_tmp.dot'
   export_graphviz(clf, out_file=tmp_dot_file, feature_names=feature_names, class_names=class_names,
                   filled=True, impurity=False)
   with open(tmp_dot_file) as f:
       dot_graph = f.read()
   # Alternate method using pydotplus, if installed.
   # graph = pydotplus.graphviz.graph_from_dot_data(dot_graph)
   # return graph.create_png()
   return graphviz.Source(dot_graph)


def plot_feature_importances(clf, feature_names):
   """
       可视化分类器中特征的重要性
   """
   c_features = len(feature_names)
   plt.barh(range(c_features), clf.feature_importances_)
   plt.xlabel('Feature importance')
   plt.ylabel('Feature name')
   plt.yticks(np.arange(c_features), feature_names)


def plot_class_regions_for_classifier(clf, X, y, X_test=None, y_test=None, title=None,
                                     target_names=None, plot_decision_regions=True):
   """
       根据分类器可视化数据分类的结果
       只能用于二维特征的数据
   """

   num_classes = np.amax(y) + 1
   color_list_light = ['#FFFFAA', '#EFEFEF', '#AAFFAA', '#AAAAFF']
   color_list_bold = ['#EEEE00', '#000000', '#00CC00', '#0000CC']
   cmap_light = ListedColormap(color_list_light[0:num_classes])
   cmap_bold = ListedColormap(color_list_bold[0:num_classes])

   h = 0.03
   k = 0.5
   x_plot_adjust = 0.1
   y_plot_adjust = 0.1
   plot_symbol_size = 50

   x_min = X[:, 0].min()
   x_max = X[:, 0].max()
   y_min = X[:, 1].min()
   y_max = X[:, 1].max()
   x2, y2 = np.meshgrid(np.arange(x_min-k, x_max+k, h), np.arange(y_min-k, y_max+k, h))

   P = clf.predict(np.c_[x2.ravel(), y2.ravel()])
   P = P.reshape(x2.shape)
   plt.figure()
   if plot_decision_regions:
       plt.contourf(x2, y2, P, cmap=cmap_light, alpha=0.8)

   plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, s=plot_symbol_size, edgecolor='black')
   plt.xlim(x_min - x_plot_adjust, x_max + x_plot_adjust)
   plt.ylim(y_min - y_plot_adjust, y_max + y_plot_adjust)

   if X_test is not None:
       plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cmap_bold, s=plot_symbol_size,
                   marker='^', edgecolor='black')
       train_score = clf.score(X, y)
       test_score = clf.score(X_test, y_test)
       title = title + "\nTrain score = {:.2f}, Test score = {:.2f}".format(train_score, test_score)

   if target_names is not None:
       legend_handles = []
       for i in range(0, len(target_names)):
           patch = mpatches.Patch(color=color_list_bold[i], label=target_names[i])
           legend_handles.append(patch)
       plt.legend(loc=0, handles=legend_handles)

   if title is not None:
       plt.title(title)
   plt.show()


def plot_fruit_knn(X, y, n_neighbors):
   """
       在“水果数据集”上对 height 和 width 二维数据进行kNN训练
       并绘制出结果
   """
   X_mat = X[['height', 'width']].as_matrix()
   y_mat = y.as_matrix()

   # Create color maps
   cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
   cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

   clf = neighbors.KNeighborsClassifier(n_neighbors)
   clf.fit(X_mat, y_mat)

   # Plot the decision boundary by assigning a color in the color map
   # to each mesh point.
   
   mesh_step_size = .01  # step size in the mesh
   plot_symbol_size = 50
   
   x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
   y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
   xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                        np.arange(y_min, y_max, mesh_step_size))
   Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

   # Put the result into a color plot
   Z = Z.reshape(xx.shape)
   plt.figure()
   plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

   # Plot training points
   plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y, cmap=cmap_bold,
               edgecolor='black')
   plt.xlim(xx.min(), xx.max())
   plt.ylim(yy.min(), yy.max())

   patch0 = mpatches.Patch(color='#FF0000', label='apple')
   patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
   patch2 = mpatches.Patch(color='#0000FF', label='orange')
   patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
   plt.legend(handles=[patch0, patch1, patch2, patch3])

   plt.xlabel('height (cm)')
   plt.ylabel('width (cm)')
   
   plt.show()

本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式AI+学习就业服务平台 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved