机器学习入门之机器学习算法:SVM(支持向量机)
小标 2019-03-26 来源 : 阅读 975 评论 0

摘要:本文主要向大家介绍了机器学习入门之机器学习算法:SVM(支持向量机),通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习入门之机器学习算法:SVM(支持向量机),通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

机器学习入门之机器学习算法:SVM(支持向量机)

SVM算法(Support Vector Machine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优,使得分类器尽可能健壮;2、如果数据线性不可分,通过核函数将低维样本转化为高维样本使其线性可分。注意和AdaBoost类似,SVM只能解决二分类问题。


SVM的算法在数学上实在是太复杂了,没研究明白。建议还是直接使用现成的第三方组件吧,比如libsvm的C#版本,推荐这个://www.matthewajohnson.org/software/svm.html。


虽然没研究明白,不过这几天照着Python版本的代码试着用C#改写了一下,算是研究SVM过程中唯一的收获吧。此版本基于SMO(序列最小优化)算法求解,核函数使用的是比较常用的径向基函数(RBF)。别问我为什么没有注释,我只是从Python移植过来的,我也没看懂,等我看懂了再来补注释吧。

using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning
{
    /// <summary>
    /// 支持向量机(SMO算法,RBF核)
    /// </summary>
    public class SVM
    {
        private Random m_Rand;
        private double[][] m_Kernel;
        private double[] m_Alpha;
        private double m_C = 1.0;
        private double m_B = 0.0;
        private double m_Toler = 0.0;
        private double[][] m_Cache;
        private double[][] m_Data;
        private double m_Reach;
        private int[] m_Label;
        private int m_Count;
        private int m_Dimension;
       
        public SVM()
        {
            m_Rand = new Random();
        }
       
        /// <summary>
        /// 训练
        /// </summary>
        /// <param name="trainingSet"></param>
        /// <param name="C"></param>
        /// <param name="toler"></param>
        /// <param name="reach"></param>
        /// <param name="iterateCount"></param>
        public void Train(List<DataVector<double, int>> trainingSet, double C, double toler, double reach, int iterateCount = 10)
        {
            //初始化
            m_Count = trainingSet.Count;
            m_Dimension = trainingSet[0].Dimension;
            m_Toler = toler;
            m_C = C;
            m_Reach = reach;
            this.Init(trainingSet);
            this.InitKernel();
           
            int iter = 0;
            int alphaChanged = 0;
            bool entireSet = true;
            while(iter < iterateCount && (alphaChanged > 0 || entireSet))
            {
                alphaChanged = 0;
                if(entireSet)
                {
                    for(int i = 0;i < m_Count;++i)
                        alphaChanged += InnerL(i);
                    iter++;
                }
                else
                {
                    for(int i = 0;i < m_Count;++i)
                    {
                        if(m_Alpha[i] > 0 && m_Alpha[i] < m_C)
                            alphaChanged += InnerL(i);
                    }
                    iter += 1;
                }
               
                if(entireSet)
                    entireSet = false;
                else if(alphaChanged == 0)
                    entireSet = true;
            }
        }
       
        /// <summary>
        /// 分类
        /// </summary>
        /// <param name="vector"></param>
        /// <returns></returns>
        public int Classify(DataVector<double, int> vector)
        {
            double predict = 0.0;
           
            int svCnt = m_Alpha.Count(a => a > 0);
            var supportVectors = new double[svCnt][];
            var supportLabels = new int[svCnt];
            var supportAlphas = new double[svCnt];
            int index = 0;
            for(int i = 0;i < m_Count;++i)
            {
                if(m_Alpha[i] > 0)
                {
                    supportVectors[index] = m_Data[i];
                    supportLabels[index] = m_Label[i];
                    supportAlphas[index] = m_Alpha[i];
                    index++;
                }
            }
           
            var kernelEval = KernelTrans(supportVectors, vector.Data);
            for(int i = 0;i < svCnt;++i)
                predict += kernelEval[i] * supportAlphas[i] * supportLabels[i];
            predict += m_B;
           
            return Math.Sign(predict);
        }
       
        /// <summary>
        /// 将原始数据转化成方便使用的形式
        /// </summary>
        /// <param name="trainingSet"></param>
        private void Init(List<DataVector<double, int>> trainingSet)
        {
            m_Data = new double[m_Count][];
            m_Label = new int[m_Count];
            m_Alpha = new double[m_Count];
            m_Cache = new double[m_Count][];
           
            for(int i = 0;i < m_Count;++i)
            {
                m_Label[i] = trainingSet[i].Label;
                m_Alpha[i] = 0.0;
                m_Cache[i] = new double[2];
                m_Cache[i][0] = 0.0;
                m_Cache[i][1] = 0.0;
                m_Data[i] = new double[m_Dimension];
                for(int j = 0;j < m_Dimension;++j)
                    m_Data[i][j] = trainingSet[i].Data[j];
            }
        }
       
        /// <summary>
        /// 初始化RBF核
        /// </summary>
        private void InitKernel()
        {
            m_Kernel = new double[m_Count][];
           
            for(int i = 0;i < m_Count;++i)
            {
                m_Kernel[i] = new double[m_Count];
                var kernels = KernelTrans(m_Data, m_Data[i]);
                for(int k = 0;k < kernels.Length;++k)
                    m_Kernel[i][k] = kernels[k];
            }
        }
       
        private double[] KernelTrans(double[][] X, double[] A)
        {
            var kernel = new double[X.Length];
           
            for(int i = 0;i < X.Length;++i)
            {
                double delta = 0.0;
                for(int k = 0;k < X[0].Length;++k)
                    delta += Math.Pow(X[i][k] - A[k], 2);
                kernel[i] = Math.Exp(delta * -1.0 / Math.Pow(m_Reach, 2));
            }
           
            return kernel;
        }
       
        private double E(int k)
        {
            double x = 0.0;
            for(int i = 0;i < m_Count;++i)
                x += m_Alpha[i] * m_Label[i] * m_Kernel[i][k];
            x += m_B;
           
            return x - m_Label[k];
        }
       
        private void UpdateE(int k)
        {
            double Ek = E<span class="token punctuati    

本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式IT培训就业服务领导者 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved

208小时内训课程