Spark的抽取、转换和选择操作
特征抽取
TF-IDF
CountVectorizer
CountVectorizer
的目的是为了将文档集合转换成关于词语数量的向量。当没有预先定义词典的时候,CountVectorizer
可以作为一个Estimator来抽取词汇,并产生CountVectorizerModel
。该模型将文档表示成稀疏向量的形式,其结果可以作为参数传递给其他算法,例如LDA。
在模型的适配阶段,CountVectorizer
将根据词频选择最频繁出现的词语(vocabSize
)。可选参数minDF指的就是单词在文档中最小出现次数。另一个可选的二元参数是控制输出向量的,如果为true
,那么非零的结果会变成1。这个是用于那些只需要0-1向量的模型中。
例子
id | texts |
---|---|
0 | Array(“a”,”b”,”c”) |
1 | Array(“a”,”b”,”b”,”c”,”a”) |
文本中每一行是一个字符数组,调用CountVectorizer
会产生词汇为(a,b,c)
的CountVectorizerModel
。其输出结果包含了一个vector
列:
id | texts | vector |
---|---|---|
0 | Array(“a”,”b”,”c”) | (3,[0,1,2],[1.0,1.0,1.0]) |
1 | Array(“a”,”b”,”b”,”c”,”a”) | (3,[0,1,2],[2.0,2.0,1.0]) |
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
// Input data: Each row is a bag of words from a sentence or document.
List<Row> data = Arrays.asList(
RowFactory.create(Arrays.asList("a", "b", "c")),
RowFactory.create(Arrays.asList("a", "b", "b", "c", "a"))
);
StructType schema = new StructType(new StructField [] {
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
Dataset<Row> df = spark.createDataFrame(data, schema);
// fit a CountVectorizerModel from the corpus
CountVectorizerModel cvModel = new CountVectorizer()
.setInputCol("text")
.setOutputCol("feature")
.setVocabSize(3)
.setMinDF(2)
.fit(df);
// alternatively, define CountVectorizerModel with a-priori vocabulary
CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"})
.setInputCol("text")
.setOutputCol("feature");
cvModel.transform(df).show(false);
特征转换
StringIndexer
StringIndexer
是用来将字符类型的列下的标签转换成索引的转换器,索引的范围是[0, numLabels],按照标签出现的频率排序。因此,最频繁出现的标签的索引是0。如果用户选择保留宝课件的标签在索引中,其索引可以跟在numLabels后面。如果输入的列是数字类型,我们将把它们转换成字符,并对字符串的值(string-indexed)进行索引。当流(Pipeline)组件用到了这个字符值的标签的时候,你必须将列设置成string-indexed。很多情况下,你都需要使用setInputCol来设置输入列。
例子
id | category |
---|---|
0 | a |
1 | b |
2 | c |
3 | a |
4 | a |
5 | c |
最后我们可以得到一个新列categoryIndex
:
id | category | categoryIndex |
---|---|---|
0 | a | 0.0 |
1 | b | 2.0 |
2 | c | 1.0 |
3 | a | 0.0 |
4 | a | 0.0 |
5 | c | 1.0 |
a
的索引是0
的原因是因为0出现的次数最多。
当StringIndexer
遇到未见过的标签的时候,有三证策略来处理:
- 抛出异常(默认)
- 跳过一整行数据
- 将未见过的标签作为一个特殊的存储,放在numLabels的索引处
例子
id | category |
---|---|
0 | a |
1 | b |
2 | c |
3 | d |
4 | e |
如果你没有设置如何处理未见过的标签,那么StringIndexer
将会抛出异常,如果你设置了方法setHandleInvalid("skip")
,将会产生如下数据集:
id | category | categoryIndex |
---|---|---|
0 | a | 0.0 |
1 | b | 2.0 |
2 | c | 1.0 |
如果行包含了d
或者e
,那么这些行将不会出现。
如果你设置了setHandleInvalid("keep")
,那么会产生如下结果:
id | category | categoryIndex |
---|---|---|
0 | a | 0.0 |
1 | b | 2.0 |
2 | c | 1.0 |
3 | d | 3.0 |
4 | e | 3.0 |
也就是说,d
和e
被映射成了索引3.0
。
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
List<Row> data = Arrays.asList(
RowFactory.create(0, "a"),
RowFactory.create(1, "b"),
RowFactory.create(2, "c"),
RowFactory.create(3, "a"),
RowFactory.create(4, "a"),
RowFactory.create(5, "c")
);
StructType schema = new StructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("category", StringType, false)
});
Dataset<Row> df = spark.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex");
Dataset<Row> indexed = indexer.fit(df).transform(df);
indexed.show();
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
