关于spark的mllib学习总结(Java版)

hongduna 2018-08-06

本篇主要讲述如何利用spark的mliib构建机器学习模型并预测新的数据,具体的流程如下图所示:

关于spark的mllib学习总结(Java版)

加载数据

对于数据的加载或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的数据是采用spark中提供的数据sample_libsvm_data.txt,其有一百个数据样本,658个特征。具体的数据形式如图所示:

关于spark的mllib学习总结(Java版)

加载libsvm

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();

  • 1

LabeledPoint数据类型是对应与libsvmfile格式文件, 具体格式为:

Lable(double类型),vector(Vector类型)

转化dataFrame数据类型

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());

StructType schema = new StructType(new StructField[]{

new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),

new StructField("features", new VectorUDT(), false, Metadata.empty()),

});

SQLContext jsql = new SQLContext(sc);

DataFrame df = jsql.createDataFrame(jrow, schema);

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

DataFrame:DataFrame是一个以命名列方式组织的分布式数据集。在概念上,它跟关系型数据库中的一张表或者1个Python(或者R)中的data frame一样,但是比他们更优化。DataFrame可以根据结构化的数据文件、hive表、外部数据库或者已经存在的RDD构造。

SQLContext:spark sql所有功能的入口是SQLContext类,或者SQLContext的子类。为了创建一个基本的SQLContext,需要一个SparkContext。

特征提取

特征归一化处理

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);

DataFrame scalerDF = scaler.fit(df).transform(df);

scaler.save(this.scalerModelPath);

  • 1
  • 2
  • 3

利用卡方统计做特征提取

ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");

ChiSqSelectorModel chiModel = selector.fit(scalerDF);

DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

chiModel.save(this.featureSelectedModelPath);

  • 1
  • 2
  • 3
  • 4

训练机器学习模型(以SVM为例)

//转化为LabeledPoint数据类型, 训练模型

JavaRDD<Row> selectedrows = selectedDF.javaRDD();

JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());

//训练SVM模型, 并保存

int numIteration = 200;

SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);

model.clearThreshold();

model.save(sc, this.mlModelPath);

// LabeledPoint数据类型转化为Row

static class LabeledPointToRow implements Function<LabeledPoint, Row> {

public Row call(LabeledPoint p) throws Exception {

double label = p.label();

Vector vector = p.features();

return RowFactory.create(label, vector);

}

}

//Rows数据类型转化为LabeledPoint

static class RowToLabel implements Function<Row, LabeledPoint> {

public LabeledPoint call(Row r) throws Exception {

Vector features = r.getAs(1);

double label = r.getDouble(0);

return new LabeledPoint(label, features);

}

}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

测试新的样本

测试新的样本前,需要将样本做数据的转化和特征提取的工作,所有刚刚训练模型的过程中,除了保存机器学习模型,还需要保存特征提取的中间模型。具体代码如下:

//初始化spark

SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");

conf.set("spark.testing.memory", "2147480000");

SparkContext sc = new SparkContext(conf);

//加载测试数据

JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();

//转化DataFrame数据类型

JavaRDD<Row> jrow =testData.map(new LabeledPointToRow());

StructType schema = new StructType(new StructField[]{

new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),

new StructField("features", new VectorUDT(), false, Metadata.empty()),

});

SQLContext jsql = new SQLContext(sc);

DataFrame df = jsql.createDataFrame(jrow, schema);

//数据规范化

StandardScaler scaler = StandardScaler.load(this.scalerModelPath);

DataFrame scalerDF = scaler.fit(df).transform(df);

//特征选取

ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath);

DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

测试数据集

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);

JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ;

predictResult.collect();

static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {

SVMModel model;

public Prediction(SVMModel model){

this.model = model;

}

public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {

Double score = model.predict(p.features());

return new Tuple2<Double , Double>(score, p.label());

}

}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

计算准确率

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();

System.out.println(accuracy);

static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {

public Boolean call(Tuple2<Double, Double> t) throws Exception {

double score = t._1();

double label = t._2();

System.out.print("score:" + score + ", label:"+ label);

if(score >= 0.0 && label >= 0.0) return true;

else if(score < 0.0 && label < 0.0) return true;

else return false;

}

}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

具体的代码,放在我的github上:https://github.com/Quincy1994/MachineLearning/

相关推荐