搜 索

Spark从入门到放弃⑥之MLlib

  • 15阅读
  • 2023年04月22日
  • 0评论
首页 / AI/大数据 / 正文

前言:大数据时代的机器学习

以前做机器学习,数据量小,sklearn一把梭,GPU一跑,齐活。

现在呢?

老板:我们有10亿条用户行为数据,能不能做个推荐系统?
你:10亿?sklearn表示内存不够...
老板:那用什么?
你:Spark MLlib!
老板:好,下周上线。
你:...(头发-10)

Spark MLlib是Spark的机器学习库,专门为大规模分布式机器学习设计。它能处理sklearn处理不了的数据量,同时与Spark SQL、Spark Streaming无缝集成。

本篇是Spark系列的第六篇,也是最后一篇。我们将从MLlib的基础概念讲起,一直讲到实战案例,让你能够用Spark做机器学习。

友情提示:本文假设你已经有机器学习的基础知识,不会从"什么是机器学习"讲起。


一、MLlib概述

1.1 MLlib vs ML

graph TB subgraph MLlib演进["📜 MLlib演进历史"] subgraph Old["spark.mllib(旧)"] O1["基于RDD"] O2["API不友好"] O3["已停止更新"] end subgraph New["spark.ml(新)"] N1["基于DataFrame"] N2["Pipeline API"] N3["持续更新"] end Old --> |"演进"| New end style Old fill:#95a5a6 style New fill:#2ecc71

划重点

  • spark.mllib:基于RDD的旧API,已不再更新,不要用了!
  • spark.ml:基于DataFrame的新API,用这个!

本文所有代码都基于spark.ml

1.2 MLlib能做什么?

mindmap root((Spark MLlib)) 分类 逻辑回归 决策树 随机森林 梯度提升树 朴素贝叶斯 线性SVM 回归 线性回归 决策树回归 随机森林回归 梯度提升树回归 聚类 K-Means 高斯混合模型 LDA主题模型 层次聚类 协同过滤 ALS推荐 特征工程 特征提取 特征转换 特征选择 模型评估 分类指标 回归指标 聚类指标 工具 线性代数 统计 数据生成

1.3 核心概念

graph LR subgraph 核心概念["🎯 MLlib核心概念"] A[DataFrame] --> B[Transformer
转换器] B --> C[DataFrame] D[DataFrame] --> E[Estimator
估计器] E --> |"fit()训练"|F[Model] F --> |"transform()预测"|G[DataFrame] H[Pipeline
流水线] --> I["Transformer + Estimator
的组合"] end style B fill:#3498db style E fill:#e74c3c style F fill:#2ecc71 style H fill:#9b59b6
概念说明例子
DataFrame数据载体带Schema的分布式数据集
Transformer转换器,transform()方法VectorAssembler、StandardScaler
Estimator估计器,fit()方法返回ModelLogisticRegression、RandomForest
Model训练好的模型(Transformer)LogisticRegressionModel
Pipeline多个阶段的流水线特征处理 + 模型训练
Param算法参数maxIter、regParam

二、特征工程

2.1 特征工程流程

graph LR subgraph 特征工程流程["🔧 特征工程流程"] A[原始数据] --> B[特征提取
Extraction] B --> C[特征转换
Transformation] C --> D[特征选择
Selection] D --> E[特征向量
Vector] end style A fill:#95a5a6 style E fill:#2ecc71

2.2 常用特征转换器

graph TB subgraph 特征转换器["🛠️ 常用特征转换器"] subgraph 数值处理["数值处理"] N1[VectorAssembler
向量组装] N2[StandardScaler
标准化] N3[MinMaxScaler
归一化] N4[Normalizer
正则化] N5[Bucketizer
分桶] N6[QuantileDiscretizer
分位数分桶] end subgraph 类别处理["类别处理"] C1[StringIndexer
字符串索引化] C2[IndexToString
索引转字符串] C3[OneHotEncoder
独热编码] C4[VectorIndexer
向量索引化] end subgraph 文本处理["文本处理"] T1[Tokenizer
分词] T2[HashingTF
词频哈希] T3[IDF
逆文档频率] T4[Word2Vec
词向量] T5[CountVectorizer
词频统计] end end style N1 fill:#e74c3c style C1 fill:#e74c3c style C3 fill:#e74c3c

2.3 代码实战:特征处理

import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.Vectors

// 示例数据
val data = Seq(
  (0, "male", 25, 50000.0, "Beijing"),
  (1, "female", 30, 60000.0, "Shanghai"),
  (2, "male", 35, 80000.0, "Beijing"),
  (3, "female", 28, 55000.0, "Guangzhou")
).toDF("id", "gender", "age", "salary", "city")

// ========== 1. StringIndexer:字符串 → 数字索引 ==========
val genderIndexer = new StringIndexer()
  .setInputCol("gender")
  .setOutputCol("genderIndex")
  .setHandleInvalid("keep")  // 处理未见过的值

val cityIndexer = new StringIndexer()
  .setInputCol("city")
  .setOutputCol("cityIndex")

// fit + transform
val indexed = genderIndexer.fit(data).transform(data)
val indexed2 = cityIndexer.fit(indexed).transform(indexed)

// ========== 2. OneHotEncoder:独热编码 ==========
val encoder = new OneHotEncoder()
  .setInputCols(Array("genderIndex", "cityIndex"))
  .setOutputCols(Array("genderVec", "cityVec"))
  .setDropLast(false)  // 是否删除最后一个类别

val encoded = encoder.fit(indexed2).transform(indexed2)

// ========== 3. VectorAssembler:组装特征向量 ==========
val assembler = new VectorAssembler()
  .setInputCols(Array("age", "salary", "genderVec", "cityVec"))
  .setOutputCol("features")
  .setHandleInvalid("skip")  // 跳过无效值

val assembled = assembler.transform(encoded)

// ========== 4. StandardScaler:标准化 ==========
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithMean(true)   // 减去均值
  .setWithStd(true)    // 除以标准差

val scalerModel = scaler.fit(assembled)
val scaled = scalerModel.transform(assembled)

// ========== 5. MinMaxScaler:归一化到[0,1] ==========
val minMaxScaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("normalizedFeatures")
  .setMin(0.0)
  .setMax(1.0)

val normalized = minMaxScaler.fit(assembled).transform(assembled)

// ========== 6. Bucketizer:自定义分桶 ==========
val ageBucketizer = new Bucketizer()
  .setInputCol("age")
  .setOutputCol("ageBucket")
  .setSplits(Array(Double.NegativeInfinity, 20, 30, 40, Double.PositiveInfinity))

val bucketed = ageBucketizer.transform(data)

// ========== 7. QuantileDiscretizer:分位数分桶 ==========
val discretizer = new QuantileDiscretizer()
  .setInputCol("salary")
  .setOutputCol("salaryBucket")
  .setNumBuckets(4)  // 四分位

val discretized = discretizer.fit(data).transform(data)

2.4 文本特征处理

// 文本数据
val textData = Seq(
  (0, "spark is great for big data"),
  (1, "machine learning with spark"),
  (2, "spark streaming is awesome")
).toDF("id", "text")

// ========== 1. Tokenizer:分词 ==========
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

val tokenized = tokenizer.transform(textData)
// words: ["spark", "is", "great", "for", "big", "data"]

// ========== 2. StopWordsRemover:去停用词 ==========
val remover = new StopWordsRemover()
  .setInputCol("words")
  .setOutputCol("filtered")
  .setStopWords(Array("is", "for", "with"))

val filtered = remover.transform(tokenized)

// ========== 3. HashingTF + IDF:TF-IDF ==========
val hashingTF = new HashingTF()
  .setInputCol("filtered")
  .setOutputCol("rawFeatures")
  .setNumFeatures(1000)

val featurized = hashingTF.transform(filtered)

val idf = new IDF()
  .setInputCol("rawFeatures")
  .setOutputCol("tfidfFeatures")

val tfidf = idf.fit(featurized).transform(featurized)

// ========== 4. Word2Vec:词向量 ==========
val word2Vec = new Word2Vec()
  .setInputCol("filtered")
  .setOutputCol("word2vecFeatures")
  .setVectorSize(100)
  .setMinCount(1)
  .setMaxIter(10)

val word2VecModel = word2Vec.fit(filtered)
val word2VecResult = word2VecModel.transform(filtered)

// 查找相似词
word2VecModel.findSynonyms("spark", 5).show()

// ========== 5. CountVectorizer:词频统计 ==========
val countVectorizer = new CountVectorizer()
  .setInputCol("filtered")
  .setOutputCol("countFeatures")
  .setVocabSize(1000)
  .setMinDF(1)

val cvModel = countVectorizer.fit(filtered)
val countResult = cvModel.transform(filtered)

// 查看词汇表
println(cvModel.vocabulary.mkString(", "))

三、Pipeline:机器学习流水线

3.1 Pipeline概念

graph LR subgraph Pipeline流程["🔗 Pipeline流程"] A[原始数据] --> B[Stage 1
StringIndexer] B --> C[Stage 2
OneHotEncoder] C --> D[Stage 3
VectorAssembler] D --> E[Stage 4
StandardScaler] E --> F[Stage 5
LogisticRegression] F --> G[预测结果] end style B fill:#3498db style C fill:#3498db style D fill:#3498db style E fill:#3498db style F fill:#e74c3c

3.2 Pipeline代码

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression

// 准备数据
val data = spark.read.parquet("training_data.parquet")
val Array(training, test) = data.randomSplit(Array(0.8, 0.2), seed = 42)

// 定义各个Stage
val genderIndexer = new StringIndexer()
  .setInputCol("gender")
  .setOutputCol("genderIndex")

val cityIndexer = new StringIndexer()
  .setInputCol("city")
  .setOutputCol("cityIndex")

val encoder = new OneHotEncoder()
  .setInputCols(Array("genderIndex", "cityIndex"))
  .setOutputCols(Array("genderVec", "cityVec"))

val assembler = new VectorAssembler()
  .setInputCols(Array("age", "salary", "genderVec", "cityVec"))
  .setOutputCol("features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

val lr = new LogisticRegression()
  .setFeaturesCol("scaledFeatures")
  .setLabelCol("label")
  .setMaxIter(100)
  .setRegParam(0.01)

// 构建Pipeline
val pipeline = new Pipeline()
  .setStages(Array(
    genderIndexer,
    cityIndexer,
    encoder,
    assembler,
    scaler,
    lr
  ))

// 训练Pipeline
val pipelineModel = pipeline.fit(training)

// 预测
val predictions = pipelineModel.transform(test)

// 保存Pipeline模型
pipelineModel.write.overwrite().save("path/to/pipeline_model")

// 加载Pipeline模型
val loadedModel = PipelineModel.load("path/to/pipeline_model")

3.3 Pipeline优势

graph TB subgraph Pipeline优势["✨ Pipeline优势"] A["代码整洁
所有步骤串联"] B["防止数据泄露
fit只在训练集上"] C["易于保存和加载
整体序列化"] D["支持交叉验证
与CrossValidator集成"] E["易于复现
完整的处理流程"] end style A fill:#2ecc71 style B fill:#2ecc71 style C fill:#2ecc71

四、分类算法

4.1 支持的分类算法

graph TB subgraph 分类算法["📊 MLlib分类算法"] subgraph 线性模型["线性模型"] L1[LogisticRegression
逻辑回归] L2[LinearSVC
线性SVM] end subgraph 树模型["树模型"] T1[DecisionTreeClassifier
决策树] T2[RandomForestClassifier
随机森林] T3[GBTClassifier
梯度提升树] end subgraph 其他["其他"] O1[NaiveBayes
朴素贝叶斯] O2[MultilayerPerceptronClassifier
多层感知机] end end style T2 fill:#e74c3c style T3 fill:#e74c3c

4.2 逻辑回归

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}

// 准备数据(需要features和label列)
val data = spark.read.format("libsvm").load("sample_libsvm_data.txt")
val Array(training, test) = data.randomSplit(Array(0.7, 0.3))

// 创建逻辑回归模型
val lr = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)              // 最大迭代次数
  .setRegParam(0.01)            // 正则化参数(L2)
  .setElasticNetParam(0.0)      // 弹性网络参数(0=L2, 1=L1)
  .setThreshold(0.5)            // 分类阈值
  .setFamily("binomial")        // binomial(二分类)或multinomial(多分类)

// 训练
val lrModel = lr.fit(training)

// 查看模型系数
println(s"Coefficients: ${lrModel.coefficients}")
println(s"Intercept: ${lrModel.intercept}")

// 预测
val predictions = lrModel.transform(test)
predictions.select("label", "prediction", "probability").show(10)

// 评估(二分类)
val binaryEvaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")  // 或areaUnderPR

val auc = binaryEvaluator.evaluate(predictions)
println(s"AUC: $auc")

// 评估(多分类)
val multiEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")  // accuracy, f1, weightedPrecision, weightedRecall

val accuracy = multiEvaluator.evaluate(predictions)
println(s"Accuracy: $accuracy")

// 训练过程摘要
val trainingSummary = lrModel.binarySummary
println(s"Area Under ROC: ${trainingSummary.areaUnderROC}")
trainingSummary.roc.show()  // ROC曲线数据
trainingSummary.pr.show()   // PR曲线数据

4.3 随机森林

import org.apache.spark.ml.classification.{RandomForestClassifier, RandomForestClassificationModel}

// 创建随机森林分类器
val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setNumTrees(100)             // 树的数量
  .setMaxDepth(10)              // 最大深度
  .setMaxBins(32)               // 最大分箱数
  .setMinInstancesPerNode(1)    // 叶子节点最小样本数
  .setMinInfoGain(0.0)          // 最小信息增益
  .setSubsamplingRate(0.8)      // 样本采样率
  .setFeatureSubsetStrategy("sqrt")  // 特征采样:auto, all, sqrt, log2, onethird
  .setSeed(42)

// 训练
val rfModel = rf.fit(training)

// 查看特征重要性
println(s"Feature Importances: ${rfModel.featureImportances}")

// 查看决策树
rfModel.trees.foreach(tree => println(tree.toDebugString))

// 预测
val predictions = rfModel.transform(test)

// 评估
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

println(s"Accuracy: ${evaluator.evaluate(predictions)}")

4.4 梯度提升树(GBT)

import org.apache.spark.ml.classification.{GBTClassifier, GBTClassificationModel}

// 创建GBT分类器(只支持二分类)
val gbt = new GBTClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)              // 迭代次数(树的数量)
  .setMaxDepth(5)               // 最大深度
  .setStepSize(0.1)             // 学习率
  .setSubsamplingRate(0.8)      // 样本采样率
  .setFeatureSubsetStrategy("sqrt")
  .setSeed(42)

// 训练
val gbtModel = gbt.fit(training)

// 查看特征重要性
println(s"Feature Importances: ${gbtModel.featureImportances}")

// 预测
val predictions = gbtModel.transform(test)

4.5 多层感知机(MLP)

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// 定义网络层
// 输入层神经元数 = 特征数
// 输出层神经元数 = 类别数
val layers = Array[Int](4, 10, 8, 3)  // 4输入 -> 10隐藏 -> 8隐藏 -> 3输出

val mlp = new MultilayerPerceptronClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setLayers(layers)
  .setMaxIter(100)
  .setBlockSize(128)  // 批次大小
  .setSeed(42)

val mlpModel = mlp.fit(training)
val predictions = mlpModel.transform(test)

五、回归算法

5.1 支持的回归算法

graph TB subgraph 回归算法["📈 MLlib回归算法"] R1[LinearRegression
线性回归] R2[DecisionTreeRegressor
决策树回归] R3[RandomForestRegressor
随机森林回归] R4[GBTRegressor
梯度提升回归] R5[GeneralizedLinearRegression
广义线性回归] R6[IsotonicRegression
保序回归] end style R1 fill:#3498db style R3 fill:#e74c3c style R4 fill:#e74c3c

5.2 线性回归

import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.evaluation.RegressionEvaluator

// 创建线性回归模型
val lr = new LinearRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)
  .setRegParam(0.01)            // L2正则化
  .setElasticNetParam(0.0)      // 弹性网络(0=L2, 1=L1)
  .setStandardization(true)     // 是否标准化
  .setSolver("auto")            // auto, normal, l-bfgs

// 训练
val lrModel = lr.fit(training)

// 查看模型参数
println(s"Coefficients: ${lrModel.coefficients}")
println(s"Intercept: ${lrModel.intercept}")

// 训练摘要
val trainingSummary = lrModel.summary
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"R2: ${trainingSummary.r2}")
println(s"MAE: ${trainingSummary.meanAbsoluteError}")

// 预测
val predictions = lrModel.transform(test)

// 评估
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")  // rmse, mse, mae, r2

val rmse = evaluator.evaluate(predictions)
println(s"Test RMSE: $rmse")

5.3 随机森林回归

import org.apache.spark.ml.regression.{RandomForestRegressor, RandomForestRegressionModel}

val rf = new RandomForestRegressor()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setNumTrees(100)
  .setMaxDepth(10)
  .setMinInstancesPerNode(1)
  .setSubsamplingRate(0.8)
  .setFeatureSubsetStrategy("sqrt")
  .setSeed(42)

val rfModel = rf.fit(training)

println(s"Feature Importances: ${rfModel.featureImportances}")

val predictions = rfModel.transform(test)

5.4 GBT回归

import org.apache.spark.ml.regression.{GBTRegressor, GBTRegressionModel}

val gbt = new GBTRegressor()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)
  .setMaxDepth(5)
  .setStepSize(0.1)
  .setSubsamplingRate(0.8)
  .setLossType("squared")  // squared, absolute
  .setSeed(42)

val gbtModel = gbt.fit(training)

val predictions = gbtModel.transform(test)

六、聚类算法

6.1 K-Means聚类

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// 创建K-Means模型
val kmeans = new KMeans()
  .setFeaturesCol("features")
  .setPredictionCol("cluster")
  .setK(5)                      // 聚类数
  .setMaxIter(100)
  .setInitMode("k-means||")     // k-means||(默认)或random
  .setSeed(42)

// 训练
val model = kmeans.fit(data)

// 查看聚类中心
println("Cluster Centers:")
model.clusterCenters.foreach(println)

// 预测
val predictions = model.transform(data)
predictions.select("features", "cluster").show()

// 评估(轮廓系数)
val evaluator = new ClusteringEvaluator()
  .setFeaturesCol("features")
  .setPredictionCol("cluster")
  .setMetricName("silhouette")  // silhouette

val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette Score: $silhouette")

// 计算WSSSE(Within Set Sum of Squared Errors)
val wssse = model.summary.trainingCost
println(s"WSSSE: $wssse")

6.2 选择最佳K值

// 肘部法则:计算不同K值的WSSSE
val kValues = (2 to 10).toArray
val wssseValues = kValues.map { k =>
  val kmeans = new KMeans()
    .setFeaturesCol("features")
    .setK(k)
    .setMaxIter(100)
    .setSeed(42)
  val model = kmeans.fit(data)
  (k, model.summary.trainingCost)
}

wssseValues.foreach { case (k, wssse) =>
  println(s"K=$k, WSSSE=$wssse")
}

// 可视化:选择"肘部"位置的K值

6.3 高斯混合模型(GMM)

import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel}

val gmm = new GaussianMixture()
  .setFeaturesCol("features")
  .setPredictionCol("cluster")
  .setProbabilityCol("probability")
  .setK(3)
  .setMaxIter(100)
  .setSeed(42)

val model = gmm.fit(data)

// 查看高斯分布参数
for (i <- 0 until model.getK) {
  println(s"Cluster $i:")
  println(s"  Weight: ${model.weights(i)}")
  println(s"  Mean: ${model.gaussians(i).mean}")
  println(s"  Covariance: ${model.gaussians(i).cov}")
}

val predictions = model.transform(data)
predictions.select("features", "cluster", "probability").show()

6.4 LDA主题模型

import org.apache.spark.ml.clustering.{LDA, LDAModel}

// 假设已经有词频向量
val lda = new LDA()
  .setFeaturesCol("features")  // 词频向量
  .setK(10)                    // 主题数
  .setMaxIter(100)
  .setOptimizer("online")      // online或em
  .setSeed(42)

val model = lda.fit(data)

// 查看主题
val topics = model.describeTopics(10)  // 每个主题的前10个词
topics.show(false)

// 主题分布
val transformed = model.transform(data)
transformed.select("topicDistribution").show(false)

// 模型困惑度(越低越好)
val perplexity = model.logPerplexity(data)
println(s"Perplexity: $perplexity")

七、推荐系统:ALS

7.1 ALS原理

graph TB subgraph ALS协同过滤["🎬 ALS协同过滤"] A[用户-物品评分矩阵] --> B[矩阵分解] B --> C[用户特征矩阵 U] B --> D[物品特征矩阵 V] C --> E["预测评分 = U × V^T"] D --> E end style A fill:#95a5a6 style C fill:#3498db style D fill:#e74c3c style E fill:#2ecc71

7.2 ALS实战

import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.evaluation.RegressionEvaluator

// 准备评分数据
// 格式:userId, itemId, rating
val ratings = spark.read
  .option("header", "true")
  .csv("ratings.csv")
  .selectExpr(
    "cast(userId as int) as userId",
    "cast(movieId as int) as movieId",
    "cast(rating as float) as rating"
  )

val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

// 创建ALS模型
val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setRank(50)                  // 隐因子数(特征维度)
  .setMaxIter(20)
  .setRegParam(0.1)             // 正则化参数
  .setAlpha(1.0)                // 隐式反馈的置信度参数
  .setImplicitPrefs(false)      // 是否使用隐式反馈
  .setColdStartStrategy("drop") // 处理冷启动:drop或nan
  .setSeed(42)

// 训练
val alsModel = als.fit(training)

// 预测
val predictions = alsModel.transform(test)

// 评估
val evaluator = new RegressionEvaluator()
  .setLabelCol("rating")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

val rmse = evaluator.evaluate(predictions)
println(s"RMSE: $rmse")

// 为用户推荐Top N物品
val userRecs = alsModel.recommendForAllUsers(10)
userRecs.show(false)

// 为物品推荐Top N用户
val itemRecs = alsModel.recommendForAllItems(10)

// 为特定用户推荐
val userSubset = ratings.select("userId").distinct().limit(3)
val userSubsetRecs = alsModel.recommendForUserSubset(userSubset, 10)

// 为特定物品推荐
val itemSubset = ratings.select("movieId").distinct().limit(3)
val itemSubsetRecs = alsModel.recommendForItemSubset(itemSubset, 10)

7.3 推荐结果处理

// 推荐结果是数组格式,需要展开
import org.apache.spark.sql.functions._

// 原始格式:userId, recommendations: [{movieId, rating}, ...]
// 展开为:userId, movieId, rating

val flatRecs = userRecs
  .select(
    col("userId"),
    explode(col("recommendations")).as("rec")
  )
  .select(
    col("userId"),
    col("rec.movieId"),
    col("rec.rating")
  )

flatRecs.show()

// 关联物品信息
val movies = spark.read.option("header", "true").csv("movies.csv")

val recsWithInfo = flatRecs
  .join(movies, Seq("movieId"))
  .select("userId", "movieId", "title", "rating")
  .orderBy(col("userId"), col("rating").desc)

recsWithInfo.show()

八、模型选择与调优

8.1 交叉验证

graph TB subgraph 交叉验证["🔄 K折交叉验证"] A[数据集] --> B[分成K份] B --> C1[Fold 1: 验证集] B --> C2[Fold 2] B --> C3[Fold 3] B --> C4[Fold K] C1 --> D["轮流作为验证集
其他作为训练集"] C2 --> D C3 --> D C4 --> D D --> E[计算平均指标] end style E fill:#2ecc71

8.2 CrossValidator

import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// 创建模型
val lr = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")

// 构建参数网格
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.maxIter, Array(50, 100, 200))
  .addGrid(lr.regParam, Array(0.001, 0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// 创建评估器
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setMetricName("areaUnderROC")

// 创建交叉验证器
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)               // K折
  .setParallelism(4)            // 并行度
  .setSeed(42)

// 训练(会尝试所有参数组合)
val cvModel = cv.fit(training)

// 查看最佳参数
println(s"Best Params: ${cvModel.bestModel.extractParamMap()}")

// 查看所有参数组合的平均指标
val avgMetrics = cvModel.avgMetrics
paramGrid.zip(avgMetrics).foreach { case (params, metric) =>
  println(s"$params => $metric")
}

// 使用最佳模型预测
val predictions = cvModel.transform(test)

8.3 TrainValidationSplit

比交叉验证更快,但不如交叉验证稳定。

import org.apache.spark.ml.tuning.TrainValidationSplit

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)  // 80%训练,20%验证
  .setParallelism(4)
  .setSeed(42)

val tvsModel = tvs.fit(training)
val predictions = tvsModel.transform(test)

8.4 Pipeline + CrossValidator

// 将Pipeline和CrossValidator结合使用
val pipeline = new Pipeline()
  .setStages(Array(indexer, encoder, assembler, scaler, lr))

// 参数网格(引用Pipeline中的Stage)
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.maxIter, Array(50, 100))
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(scaler.withMean, Array(true, false))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

val cvModel = cv.fit(training)

九、模型持久化

9.1 保存和加载模型

// 保存模型
lrModel.write.overwrite().save("path/to/lr_model")

// 加载模型
val loadedModel = LogisticRegressionModel.load("path/to/lr_model")

// 保存Pipeline模型
pipelineModel.write.overwrite().save("path/to/pipeline_model")

// 加载Pipeline模型
val loadedPipeline = PipelineModel.load("path/to/pipeline_model")

// 保存CrossValidator模型
cvModel.write.overwrite().save("path/to/cv_model")

// 加载CrossValidator模型
val loadedCvModel = CrossValidatorModel.load("path/to/cv_model")

9.2 模型导出为PMML/ONNX

// 导出为PMML(需要额外依赖)
// import org.jpmml.sparkml.PMMLBuilder

// val pmml = new PMMLBuilder(schema, pipelineModel).build()
// JAXBUtil.marshalPMML(pmml, new FileOutputStream("model.pmml"))

// 或者导出为ONNX
// 需要使用第三方工具如onnxmltools

十、实战案例

10.1 案例一:用户流失预测

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature._
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// 1. 加载数据
val data = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("churn_data.csv")

// 2. 数据探索
data.printSchema()
data.describe().show()
data.groupBy("churn").count().show()

// 3. 特征工程
val categoricalCols = Array("gender", "partner", "dependents", "phoneService", 
  "internetService", "contract", "paymentMethod")
val numericCols = Array("tenure", "monthlyCharges", "totalCharges")

// 字符串索引化
val indexers = categoricalCols.map { col =>
  new StringIndexer()
    .setInputCol(col)
    .setOutputCol(s"${col}Index")
    .setHandleInvalid("keep")
}

// 独热编码
val encoder = new OneHotEncoder()
  .setInputCols(categoricalCols.map(_ + "Index"))
  .setOutputCols(categoricalCols.map(_ + "Vec"))

// 组装特征
val assembler = new VectorAssembler()
  .setInputCols(numericCols ++ categoricalCols.map(_ + "Vec"))
  .setOutputCol("features")
  .setHandleInvalid("skip")

// 标签处理
val labelIndexer = new StringIndexer()
  .setInputCol("churn")
  .setOutputCol("label")

// 4. 模型
val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setNumTrees(100)

// 5. 构建Pipeline
val pipeline = new Pipeline()
  .setStages(indexers ++ Array(encoder, assembler, labelIndexer, rf))

// 6. 划分数据
val Array(training, test) = data.randomSplit(Array(0.8, 0.2), seed = 42)

// 7. 交叉验证调参
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(50, 100, 200))
  .addGrid(rf.maxDepth, Array(5, 10, 15))
  .build()

val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setMetricName("areaUnderROC")

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
  .setParallelism(4)

// 8. 训练
val cvModel = cv.fit(training)

// 9. 评估
val predictions = cvModel.transform(test)

val auc = evaluator.evaluate(predictions)
println(s"Test AUC: $auc")

// 混淆矩阵
predictions.groupBy("label", "prediction").count().show()

// 10. 特征重要性
val rfModel = cvModel.bestModel.asInstanceOf[PipelineModel]
  .stages.last.asInstanceOf[RandomForestClassificationModel]
println(s"Feature Importances: ${rfModel.featureImportances}")

// 11. 保存模型
cvModel.bestModel.write.overwrite().save("churn_model")

10.2 案例二:房价预测

import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator

// 1. 加载数据
val data = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("house_prices.csv")

// 2. 特征处理(简化版)
val featureCols = Array("bedrooms", "bathrooms", "sqft_living", "sqft_lot", 
  "floors", "waterfront", "view", "condition", "grade", "yr_built")

val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")

val assembled = assembler.transform(data)
  .select("features", "price")
  .withColumnRenamed("price", "label")

// 3. 划分数据
val Array(training, test) = assembled.randomSplit(Array(0.8, 0.2))

// 4. GBT回归
val gbt = new GBTRegressor()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)
  .setMaxDepth(5)
  .setStepSize(0.1)

// 5. 训练
val model = gbt.fit(training)

// 6. 预测
val predictions = model.transform(test)

// 7. 评估
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")

println(s"RMSE: ${evaluator.setMetricName("rmse").evaluate(predictions)}")
println(s"MAE: ${evaluator.setMetricName("mae").evaluate(predictions)}")
println(s"R2: ${evaluator.setMetricName("r2").evaluate(predictions)}")

// 8. 特征重要性
println(s"Feature Importances: ${model.featureImportances}")
featureCols.zip(model.featureImportances.toArray).sortBy(-_._2).foreach {
  case (feature, importance) => println(s"$feature: $importance")
}

10.3 案例三:电影推荐系统

import org.apache.spark.ml.recommendation.ALS

// 1. 加载评分数据
val ratings = spark.read
  .option("header", "true")
  .csv("ratings.csv")
  .selectExpr(
    "cast(userId as int)",
    "cast(movieId as int)",
    "cast(rating as float)",
    "timestamp"
  )

val movies = spark.read
  .option("header", "true")
  .csv("movies.csv")

// 2. 数据探索
println(s"用户数: ${ratings.select("userId").distinct().count()}")
println(s"电影数: ${ratings.select("movieId").distinct().count()}")
println(s"评分数: ${ratings.count()}")

// 3. 划分数据
val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

// 4. ALS模型
val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setRank(50)
  .setMaxIter(20)
  .setRegParam(0.1)
  .setColdStartStrategy("drop")

// 5. 训练
val model = als.fit(training)

// 6. 评估
val predictions = model.transform(test)
val evaluator = new RegressionEvaluator()
  .setLabelCol("rating")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

println(s"RMSE: ${evaluator.evaluate(predictions)}")

// 7. 为所有用户推荐Top 10电影
val userRecs = model.recommendForAllUsers(10)

// 8. 展示某用户的推荐
val userId = 1
val userRecommendations = userRecs
  .filter(col("userId") === userId)
  .select(explode(col("recommendations")).as("rec"))
  .select(col("rec.movieId"), col("rec.rating"))
  .join(movies, Seq("movieId"))
  .select("title", "genres", "rating")

println(s"为用户 $userId 推荐的电影:")
userRecommendations.show(false)

// 9. 保存模型
model.write.overwrite().save("movie_recommendation_model")

十一、性能调优

11.1 数据并行度

// 增加分区数
val data = spark.read.parquet("data").repartition(200)

// 设置默认并行度
spark.conf.set("spark.default.parallelism", "200")
spark.conf.set("spark.sql.shuffle.partitions", "200")

11.2 缓存中间结果

// 缓存训练数据
training.cache()
training.count()  // 触发缓存

// 训练完成后释放
training.unpersist()

11.3 调整模型参数

// 随机森林:减少树的数量和深度可以加速
val rf = new RandomForestClassifier()
  .setNumTrees(50)    // 减少树数量
  .setMaxDepth(5)     // 减少深度
  .setSubsamplingRate(0.5)  // 减少采样率

11.4 使用近似算法

// K-Means使用||初始化(更快)
val kmeans = new KMeans()
  .setInitMode("k-means||")  // 比random更好更快

// 逻辑回归使用LBFGS(大数据更快)
val lr = new LogisticRegression()
  .setSolver("l-bfgs")

十二、写在最后

Spark MLlib让大规模机器学习成为可能,但它也有自己的局限:

优点

  • 分布式,能处理超大数据
  • 与Spark生态无缝集成
  • Pipeline API很优雅
  • 支持交叉验证和参数调优

局限

  • 算法没有sklearn丰富
  • 深度学习支持有限
  • 调参不如专门的AutoML工具
  • 小数据量反而更慢

使用建议

  • 数据量小(< 1GB):用sklearn
  • 数据量大(> 10GB):用Spark MLlib
  • 深度学习:用TensorFlow/PyTorch,Spark只做数据处理
  • 特征工程:Spark MLlib很好用

最后送大家一句话:

模型只是手段,理解业务才是目的。再好的算法,也抵不过对数据的深入理解。

本文作者:一个从sklearn转型到MLlib的数据工程师

最惨经历:训练了3天的模型,发现特征工程错了

系列完结,感谢阅读!


附录:面试高频题

  1. Spark MLlib和sklearn的区别?

    MLlib基于分布式计算,适合大数据;sklearn单机运行,适合小数据。MLlib API是Transformer/Estimator/Pipeline模式。
  2. 什么是Pipeline?有什么好处?

    Pipeline是多个Transformer和Estimator的串联。好处:代码整洁、防止数据泄露、易于保存和复用、支持交叉验证。
  3. ALS推荐算法的原理?

    将用户-物品评分矩阵分解为用户特征矩阵和物品特征矩阵的乘积。通过交替最小二乘法优化。
  4. 如何处理类别特征?

    StringIndexer转为数字索引,然后OneHotEncoder进行独热编码,最后VectorAssembler组装成特征向量。
  5. CrossValidator和TrainValidationSplit的区别?

    CrossValidator是K折交叉验证,更稳定但更慢;TrainValidationSplit只划分一次训练集和验证集,更快但方差更大。
  6. 如何解决MLlib中的数据倾斜?

    增加分区数、采样处理、特征分桶、使用近似算法。
评论区
暂无评论
avatar