# Via: http://stackoverflow.com/questions/32331848/create-a-custom-transformer-in-pyspark-ml
import nltk
from pyspark import keyword_only ## < 2.0 -> pyspark.ml.util.keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
class NLTKWordPunctTokenizer(Transformer, HasInputCol, HasOutputCol):
    """Tokenize a string column with ``nltk.tokenize.wordpunct_tokenize``.

    Tokens whose lower-cased form is contained in the ``stopwords`` param are
    dropped. The surviving tokens are written to the output column as an
    array of strings; the input column is left in place.

    Params:
        inputCol: name of the column holding the text to tokenize.
        outputCol: name of the column that receives the token array.
        stopwords: collection of tokens to discard (defaults to an empty set).
            Tokens are lower-cased before the lookup, the stop words are not.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, stopwords=None):
        """Declare the ``stopwords`` param and apply the given arguments.

        Only arguments that were actually supplied by the caller are
        forwarded to :meth:`setParams`.
        """
        super(NLTKWordPunctTokenizer, self).__init__()
        self.stopwords = Param(self, "stopwords", "")
        self._setDefault(stopwords=set())
        kwargs = self._explicit_kwargs(self.__init__)
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, stopwords=None):
        """Set any of ``inputCol``, ``outputCol`` and ``stopwords``.

        Returns:
            Whatever ``self._set`` returns, so calls can be chained.
        """
        kwargs = self._explicit_kwargs(self.setParams)
        return self._set(**kwargs)

    def _explicit_kwargs(self, method):
        """Return the keyword arguments recorded by ``keyword_only``.

        Newer PySpark releases store them on the instance as
        ``self._input_kwargs``; older ones stored them on the decorated
        function. Checking the instance first keeps the class usable on both.
        """
        kwargs = getattr(self, "_input_kwargs", None)
        if kwargs is None:
            kwargs = method._input_kwargs
        return kwargs

    def setStopwords(self, value):
        """Set the ``stopwords`` param and return ``self``."""
        # Go through _set (as setParams does) instead of writing into
        # _paramMap directly, so both entry points behave the same way.
        self._set(stopwords=value)
        return self

    def getStopwords(self):
        """Return the configured stop words, or the default if unset."""
        return self.getOrDefault(self.stopwords)

    def _transform(self, dataset):
        """Return ``dataset`` with the tokenized output column added.

        Null input values yield a null output value instead of failing
        inside the UDF.
        """
        # Normalise once, outside the UDF: membership tests are then O(1)
        # per token regardless of the container type the caller supplied.
        stopwords = set(self.getStopwords() or ())

        def tokenize(text):
            # wordpunct_tokenize cannot handle None; propagate the null.
            if text is None:
                return None
            tokens = nltk.tokenize.wordpunct_tokenize(text)
            return [tok for tok in tokens if tok.lower() not in stopwords]

        token_type = ArrayType(StringType())
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(tokenize, token_type)(in_col))
Comments
Post a Comment