摘要:本文主要向大家介绍了机器学习入门之机器学习算法:SVM(支持向量机),通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。
本文主要向大家介绍了机器学习入门之机器学习算法:SVM(支持向量机),通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。
SVM算法(Support Vector Machine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优,使得分类器尽可能健壮;2、如果数据线性不可分,通过核函数将低维样本映射到高维空间,使其线性可分。注意:和AdaBoost类似,基本的SVM是二分类器;多分类问题需要把多个二分类器组合起来(例如一对多或一对一)才能解决。
SVM的算法在数学上实在是太复杂了,我还没有研究明白。建议还是直接使用现成的第三方组件,比如libsvm的C#版本,推荐这个:http://www.matthewajohnson.org/software/svm.html。
虽然没研究明白,不过这几天照着Python版本的代码试着用C#改写了一下,算是研究SVM过程中唯一的收获吧。此版本基于SMO(序列最小优化)算法求解,核函数使用的是比较常用的径向基函数(RBF)。别问我为什么没有注释,我只是从Python移植过来的,我也没看懂,等我看懂了再来补注释吧。
using System;
using System.Collections.Generic;
using System.Linq;
namespace MachineLearning
{
/// <summary>
/// 支持向量机(SMO算法,RBF核)
/// </summary>
public class SVM
{
// Random number generator created in the constructor; where it is consumed is not visible in this part of the file.
private Random m_Rand;
// Precomputed kernel matrix: InitKernel fills m_Kernel[i][k] with the kernel value between training sample k and training sample i.
private double[][] m_Kernel;
// One multiplier per training sample, reset to 0 by Init; Classify treats samples with alpha > 0 as support vectors.
private double[] m_Alpha;
// Upper bound used by Train to select samples with 0 < alpha < m_C; overwritten by Train's C argument.
private double m_C = 1.0;
// Bias term added to the weighted kernel sum in Classify and E.
private double m_B = 0.0;
// Tolerance supplied to Train; where it is consumed is not visible in this part of the file.
private double m_Toler = 0.0;
// Two zero-initialised slots per training sample (see Init); presumably the SMO error cache - TODO confirm against UpdateE.
private double[][] m_Cache;
// Private copy of the training features, indexed as m_Data[sample][dimension].
private double[][] m_Data;
// RBF width: KernelTrans evaluates exp(-squaredDistance / m_Reach^2).
private double m_Reach;
// Label of each training sample, copied from the training set by Init.
private int[] m_Label;
// Number of training samples (set by Train).
private int m_Count;
// Number of features per sample, taken from the first training sample in Train.
private int m_Dimension;
/// <summary>
/// Creates a new instance and its random number generator.
/// No model data is allocated until Train is called.
/// </summary>
public SVM()
{
    this.m_Rand = new Random();
}
/// <summary>
/// Trains the classifier on the given labelled samples.
/// Copies the samples, precomputes the kernel matrix, then alternates between
/// passes over every sample and passes over only the samples whose alpha lies
/// strictly between 0 and C. Training ends after <paramref name="iterateCount"/>
/// passes, or as soon as a pass over every sample changes nothing.
/// </summary>
/// <param name="trainingSet">Labelled samples; must contain at least one element. The Dimension of the first sample is used for all samples.</param>
/// <param name="C">Stored in m_C; used here as the exclusive upper bound when selecting samples for the partial passes.</param>
/// <param name="toler">Stored in m_Toler; not used directly by this method.</param>
/// <param name="reach">Width of the RBF kernel. The kernel divides by reach squared, so it must not be zero.</param>
/// <param name="iterateCount">Maximum number of passes over the data.</param>
/// <exception cref="ArgumentNullException"><paramref name="trainingSet"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="trainingSet"/> is empty.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="reach"/> is zero.</exception>
public void Train(List<DataVector<double, int>> trainingSet, double C, double toler, double reach, int iterateCount = 10)
{
    // Validate before touching any field so a bad call leaves an existing model intact.
    if (trainingSet == null)
    {
        throw new ArgumentNullException("trainingSet");
    }
    if (trainingSet.Count == 0)
    {
        throw new ArgumentException("The training set must contain at least one sample.", "trainingSet");
    }
    if (reach == 0.0)
    {
        // exp(-d / 0) evaluates to NaN for identical points, which would poison the whole kernel matrix.
        throw new ArgumentOutOfRangeException("reach", "reach must be non-zero because the RBF kernel divides by reach squared.");
    }

    // Initialisation
    m_Count = trainingSet.Count;
    m_Dimension = trainingSet[0].Dimension;
    m_Toler = toler;
    m_C = C;
    m_Reach = reach;
    this.Init(trainingSet);
    this.InitKernel();

    int iter = 0;
    int alphaChanged = 0;
    bool entireSet = true;
    while (iter < iterateCount && (alphaChanged > 0 || entireSet))
    {
        alphaChanged = 0;
        for (int i = 0; i < m_Count; ++i)
        {
            // Full pass: visit every sample. Partial pass: only samples with 0 < alpha < C.
            if (entireSet || (m_Alpha[i] > 0 && m_Alpha[i] < m_C))
            {
                alphaChanged += InnerL(i);
            }
        }
        iter++;

        // A full pass is always followed by a partial pass; a partial pass that
        // changed nothing hands control back to a full pass.
        if (entireSet)
        {
            entireSet = false;
        }
        else if (alphaChanged == 0)
        {
            entireSet = true;
        }
    }
}
/// <summary>
/// Classifies a sample with the trained model.
/// Sums kernel(supportVector, vector) * alpha * label over every training sample
/// whose alpha is greater than zero, adds the bias m_B and returns the sign of the result.
/// </summary>
/// <param name="vector">Sample to classify; its Data is compared against each support vector.</param>
/// <returns>Math.Sign of the decision value: 1, -1, or 0 when the decision value is exactly zero.</returns>
/// <exception cref="InvalidOperationException">The model has not been trained yet.</exception>
public int Classify(DataVector<double, int> vector)
{
    if (m_Alpha == null || m_Data == null || m_Label == null)
    {
        throw new InvalidOperationException("The model has not been trained; call Train before Classify.");
    }

    // Collect the indexes of the support vectors (alpha > 0) in a single pass.
    var supportIndexes = new List<int>();
    for (int i = 0; i < m_Count; ++i)
    {
        if (m_Alpha[i] > 0)
        {
            supportIndexes.Add(i);
        }
    }

    int svCnt = supportIndexes.Count;
    var supportVectors = new double[svCnt][];
    for (int s = 0; s < svCnt; ++s)
    {
        supportVectors[s] = m_Data[supportIndexes[s]];
    }

    var kernelEval = KernelTrans(supportVectors, vector.Data);

    // Accumulate in the same order as the training samples to keep the floating-point result stable.
    double predict = 0.0;
    for (int s = 0; s < svCnt; ++s)
    {
        int i = supportIndexes[s];
        predict += kernelEval[s] * m_Alpha[i] * m_Label[i];
    }
    predict += m_B;
    return Math.Sign(predict);
}
/// <summary>
/// Copies the training set into the internal arrays used during training:
/// features into m_Data, labels into m_Label, and zeroed entries into m_Alpha and m_Cache.
/// Relies on m_Count and m_Dimension having been set by the caller.
/// </summary>
/// <param name="trainingSet">Samples to copy; the first m_Count entries are read.</param>
private void Init(List<DataVector<double, int>> trainingSet)
{
    m_Data = new double[m_Count][];
    m_Label = new int[m_Count];
    m_Alpha = new double[m_Count];
    m_Cache = new double[m_Count][];

    for (int row = 0; row < m_Count; row++)
    {
        m_Label[row] = trainingSet[row].Label;
        m_Alpha[row] = 0.0;
        m_Cache[row] = new double[] { 0.0, 0.0 };

        // Copy the features element by element so the model owns its own data.
        double[] features = new double[m_Dimension];
        m_Data[row] = features;
        for (int col = 0; col < m_Dimension; col++)
        {
            features[col] = trainingSet[row].Data[col];
        }
    }
}
/// <summary>
/// Precomputes the RBF kernel matrix for the training data.
/// Row i of m_Kernel holds the kernel value between every training sample and sample i.
/// </summary>
private void InitKernel()
{
    m_Kernel = new double[m_Count][];
    for (int row = 0; row < m_Count; row++)
    {
        double[] target = new double[m_Count];
        m_Kernel[row] = target;

        double[] values = KernelTrans(m_Data, m_Data[row]);
        Array.Copy(values, target, values.Length);
    }
}
/// <summary>
/// Evaluates the RBF kernel between each row of <paramref name="X"/> and the vector <paramref name="A"/>:
/// exp(-squaredDistance / m_Reach^2).
/// </summary>
/// <param name="X">Rows to compare; the number of columns is taken from the first row.</param>
/// <param name="A">Vector compared against every row.</param>
/// <returns>One kernel value per row of <paramref name="X"/>; empty when X has no rows.</returns>
private double[] KernelTrans(double[][] X, double[] A)
{
    double reachSquared = Math.Pow(m_Reach, 2);
    double[] result = new double[X.Length];

    for (int row = 0; row < X.Length; row++)
    {
        double[] sample = X[row];
        double squaredDistance = 0.0;

        // X[0] is only read inside the row loop, so an empty X never touches it.
        for (int col = 0; col < X[0].Length; col++)
        {
            double difference = sample[col] - A[col];
            squaredDistance += Math.Pow(difference, 2);
        }

        result[row] = Math.Exp(squaredDistance * -1.0 / reachSquared);
    }

    return result;
}
/// <summary>
/// Computes the error for training sample <paramref name="k"/>: the current decision value
/// (sum of alpha * label * kernel over all training samples, plus the bias m_B) minus the sample's label.
/// </summary>
/// <param name="k">Index of the training sample.</param>
/// <returns>Decision value for sample k minus its label.</returns>
private double E(int k)
{
    double decision = 0.0;
    int sample = 0;
    while (sample < m_Count)
    {
        decision += m_Alpha[sample] * m_Label[sample] * m_Kernel[sample][k];
        sample++;
    }

    decision = decision + m_B;
    return decision - m_Label[k];
}
private void UpdateE(int k)
{
double Ek = E<span class="token punctuati
本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!
您输入的评论内容中包含违禁敏感词
我知道了
请输入正确的手机号码
请输入正确的验证码
您今天的短信下发次数太多了,明天再试试吧!
我们会在第一时间安排职业规划师联系您!
您也可以联系我们的职业规划师咨询:
版权所有 职坐标-一站式IT培训就业服务领导者 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
沪公网安备 31011502005948号