# Databricks with Python demo

This notebook is based on the following official Databricks notebook: https://docs.databricks.com/spark/latest/mllib/binary-classification-mllib-pipelines.html

## Getting and transforming the data

In this demo, we will use the adult dataset, a sample dataset included in Databricks. We will read the data in with SQL, using the CSV data source for Spark.

In [3]:
%sql
CREATE TABLE adult (
  age DOUBLE,
  workclass STRING,
  fnlwgt DOUBLE,
  education STRING,
  education_num DOUBLE,
  marital_status STRING,
  occupation STRING,
  relationship STRING,
  race STRING,
  sex STRING,
  capital_gain DOUBLE,
  capital_loss DOUBLE,
  hours_per_week DOUBLE,
  native_country STRING,
  income STRING)
USING CSV
OPTIONS (path "/databricks-datasets/adult/adult.data", header "true")

In [4]:
sparkdf = spark.table("adult")
display(sparkdf)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
50.0,Self-emp-not-inc,83311.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,13.0,United-States,<=50K
38.0,Private,215646.0,HS-grad,9.0,Divorced,Handlers-cleaners,Not-in-family,White,Male,0.0,0.0,40.0,United-States,<=50K
53.0,Private,234721.0,11th,7.0,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0.0,0.0,40.0,United-States,<=50K
28.0,Private,338409.0,Bachelors,13.0,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0.0,0.0,40.0,Cuba,<=50K
37.0,Private,284582.0,Masters,14.0,Married-civ-spouse,Exec-managerial,Wife,White,Female,0.0,0.0,40.0,United-States,<=50K
49.0,Private,160187.0,9th,5.0,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0.0,0.0,16.0,Jamaica,<=50K
52.0,Self-emp-not-inc,209642.0,HS-grad,9.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,45.0,United-States,>50K
31.0,Private,45781.0,Masters,14.0,Never-married,Prof-specialty,Not-in-family,White,Female,14084.0,0.0,50.0,United-States,>50K
42.0,Private,159449.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178.0,0.0,40.0,United-States,>50K
37.0,Private,280464.0,Some-college,10.0,Married-civ-spouse,Exec-managerial,Husband,Black,Male,0.0,0.0,80.0,United-States,>50K


We should perform a number of transformations before we can fit a machine learning model to our data. We will group these transformations in a pipeline. 

First, we should transform the categorical variables in our data set, as many machine learning algorithms can't handle categorical features directly. This can be done using one-hot encoding, where we create a new binary column for each category of the categorical variable.

We will combine StringIndexer, which assigns an index to every unique value in a column, with OneHotEncoderEstimator, which maps each index to a binary vector with at most a single one-value marking which category is present. Note that in Spark 3.0 and later, OneHotEncoderEstimator is named OneHotEncoder.

In [6]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler

# These are all the categorical columns that need to be transformed
categoricalColumns = ["workclass", "education", "marital_status", "occupation", "relationship", "race", "sex", "native_country"]

stages = []  # The stages of our Pipeline are collected here
for categoricalCol in categoricalColumns:  # Repeat for each categorical column
    # Category indexing with StringIndexer
    stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol + "Index")
    # Use OneHotEncoderEstimator to convert the category indices into binary SparseVectors
    # (on Spark 3.0+ this class is named OneHotEncoder)
    encoder = OneHotEncoderEstimator(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + "classVec"])
    # Add the stages to the pipeline. They are not run here; they all run later, when the pipeline is fitted.
    stages += [stringIndexer, encoder]
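
To make the two steps concrete, here is a minimal plain-Python sketch of the idea, not Spark's implementation. The helper names `fit_string_index` and `one_hot` are hypothetical. We assume indices are assigned by descending frequency (ties broken alphabetically, an assumption made only to keep the sketch deterministic) and that the last category is dropped and encoded as an all-zero vector. This is consistent with the prediction output later in this notebook, where `sexclassVec` has size 1 and is empty for rows with `sexIndex` 1.0.

```python
from collections import Counter


def fit_string_index(values):
    """Map each distinct value to an index, most frequent value first."""
    counts = Counter(values)
    # Sort by descending frequency; ties are broken alphabetically
    # (an assumption of this sketch, chosen for determinism).
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {value: float(i) for i, value in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """Encode an index as a 0/1 list.

    With drop_last=True the vector has one slot fewer than there are
    categories, and the last category becomes the all-zero vector.
    """
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if int(index) < size:
        vec[int(index)] = 1.0
    return vec


sex = ["Male", "Male", "Male", "Female", "Female"]
index_of = fit_string_index(sex)
encoded = [one_hot(index_of[v], len(index_of)) for v in sex]

print(index_of)  # {'Male': 0.0, 'Female': 1.0}
print(encoded)   # [[1.0], [1.0], [1.0], [0.0], [0.0]]
```

With two categories, a single slot is enough: a one means the most frequent category and an all-zero vector means the other one.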
    

Next, we should encode our labels as label indices. We can again use StringIndexer to do so.

In [8]:
# Convert label into label indices using the StringIndexer
label_stringIdx = StringIndexer(inputCol="income", outputCol="label")
stages += [label_stringIdx]

Next, we should merge all feature columns into a single vector column, because Spark's machine learning algorithms expect all features in one vector column.

In [10]:
# Transform all features into a vector using VectorAssembler
numericCols = ["age", "fnlwgt", "education_num", "capital_gain", "capital_loss", "hours_per_week"]
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]

Now, let's chain all the transformations together to specify our workflow in a pipeline.

In [12]:
partialPipeline = Pipeline().setStages(stages)
pipelineModel = partialPipeline.fit(sparkdf)
preppedDataDF = pipelineModel.transform(sparkdf)
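
Conceptually, fitting a pipeline walks through the stages in order: a stage that needs to learn something from the data is fitted first, and its output is passed on to the next stage. The following plain-Python sketch illustrates that chaining on a list of dictionaries. All classes and functions in it (`MeanCenterer`, `Doubler`, `fit_pipeline`, `transform_pipeline`) are hypothetical toy stand-ins, not Spark classes.

```python
class MeanCenterer:
    """Toy estimator: learns the mean of a column during fit()."""

    def __init__(self, col):
        self.col = col

    def fit(self, rows):
        mean = sum(r[self.col] for r in rows) / len(rows)
        return MeanCentererModel(self.col, mean)


class MeanCentererModel:
    """Fitted result of MeanCenterer: subtracts the learned mean."""

    def __init__(self, col, mean):
        self.col, self.mean = col, mean

    def transform(self, rows):
        return [{**r, self.col + "_centered": r[self.col] - self.mean} for r in rows]


class Doubler:
    """Toy transformer: needs no fitting."""

    def __init__(self, col):
        self.col = col

    def transform(self, rows):
        return [{**r, self.col + "_x2": r[self.col] * 2} for r in rows]


def fit_pipeline(stages, rows):
    """Fit each stage in order, feeding it the output of the previous stages."""
    fitted = []
    for stage in stages:
        model = stage.fit(rows) if hasattr(stage, "fit") else stage
        rows = model.transform(rows)
        fitted.append(model)
    return fitted


def transform_pipeline(fitted, rows):
    """Apply the fitted stages one after another."""
    for model in fitted:
        rows = model.transform(rows)
    return rows


data = [{"age": 30.0}, {"age": 50.0}]
# The second stage depends on a column produced by the first one.
fitted = fit_pipeline([MeanCenterer("age"), Doubler("age_centered")], data)
result = transform_pipeline(fitted, data)
print(result)
```

The second stage reads a column that only exists after the first stage has run, just as each encoder above reads the output column of its StringIndexer.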

The only input data we need to fit an ML model is the vector column with all the features and the label column. Let's select those two columns.

In [14]:
selectedcols = ["label", "features"] 
dataset = preppedDataDF.select(selectedcols)
display(dataset)

label,features
0.0,"List(0, 100, List(1, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 50.0, 83311.0, 13.0, 13.0))"
0.0,"List(0, 100, List(0, 8, 25, 38, 44, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 38.0, 215646.0, 9.0, 40.0))"
0.0,"List(0, 100, List(0, 13, 23, 38, 43, 49, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 53.0, 234721.0, 7.0, 40.0))"
0.0,"List(0, 100, List(0, 10, 23, 29, 47, 49, 62, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 28.0, 338409.0, 13.0, 40.0))"
0.0,"List(0, 100, List(0, 11, 23, 31, 47, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 37.0, 284582.0, 14.0, 40.0))"
0.0,"List(0, 100, List(0, 18, 28, 34, 44, 49, 64, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 49.0, 160187.0, 5.0, 16.0))"
1.0,"List(0, 100, List(1, 8, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 52.0, 209642.0, 9.0, 45.0))"
1.0,"List(0, 100, List(0, 11, 24, 29, 44, 48, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 31.0, 45781.0, 14.0, 14084.0, 50.0))"
1.0,"List(0, 100, List(0, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 42.0, 159449.0, 13.0, 5178.0, 40.0))"
1.0,"List(0, 100, List(0, 9, 23, 31, 43, 49, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 37.0, 280464.0, 10.0, 80.0))"
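
The `features` column is displayed in a compact form. Assuming the four entries are a type flag (0 for sparse), the vector size, the non-zero indices, and the matching values, the first row above can be expanded as in the sketch below. Because the assembler inputs list the encoded categorical columns first and the numeric columns last, the final six slots should hold the numeric features. The helper `to_dense` is a hypothetical name used only for illustration.

```python
def to_dense(size, indices, values):
    """Expand a sparse (size, indices, values) triple into a plain list."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


# First row of the displayed output, read as (type, size, indices, values).
size = 100
indices = [1, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99]
values = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 50.0, 83311.0, 13.0, 13.0]

dense = to_dense(size, indices, values)

# The numeric columns were appended last, so they fill the last six slots.
numeric_cols = ["age", "fnlwgt", "education_num", "capital_gain", "capital_loss", "hours_per_week"]
numeric_part = dict(zip(numeric_cols, dense[-len(numeric_cols):]))
print(numeric_part)
```

The recovered values (age 50, fnlwgt 83311, education_num 13, hours_per_week 13) match the first row of the raw table shown earlier, and the eight ones in the first 94 slots correspond to the eight categorical columns.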


## Fitting the machine learning models
Before fitting a machine learning model, we should split our data into a training set and a test set. Then we can set up the model and fit it on the training data.
Let's fit a logistic regression model, a decision tree and a random forest.

In [16]:
# Randomly split the data into training and test sets; set a seed for reproducibility
(trainingData, testData) = preppedDataDF.randomSplit([0.7, 0.3], seed=100)

from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier

# Create the initial logistic regression, decision tree, and random forest models
lr = LogisticRegression(labelCol="label", featuresCol="features", maxIter=10)
dt = DecisionTreeClassifier(labelCol="label", featuresCol="features", maxDepth=3)
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=100)

# Train the models on the training data
lrModel = lr.fit(trainingData)
dtModel = dt.fit(trainingData)
rfModel = rf.fit(trainingData)
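
A weighted random split like the one above can be pictured as drawing one random number per row and assigning the row to a split based on where that number falls. The sketch below is a plain-Python illustration of this idea under that assumption, not Spark's actual implementation. The function `random_split` is a hypothetical helper. It shows why the split sizes are only approximately 70/30 and why a fixed seed makes the result repeatable.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to one split, with probability proportional to its weight."""
    total = sum(weights)
    # Cumulative upper bounds of the normalized weights, e.g. [0.7, 1.0].
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)

    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        x = rng.random()
        for k, upper in enumerate(bounds):
            # The last split also catches any floating-point rounding gap.
            if x < upper or k == len(bounds) - 1:
                splits[k].append(row)
                break
    return splits


train, test = random_split(list(range(1000)), [0.7, 0.3], seed=100)
print(len(train), len(test))
```

Every row lands in exactly one split, and rerunning with the same seed reproduces the same assignment.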

In [17]:
# Make predictions on test data using the transform() method for the three models
predictions_LR = lrModel.transform(testData)
predictions_DT = dtModel.transform(testData)
predictions_RF = rfModel.transform(testData)


## Evaluating the models
Let's evaluate the three models in terms of their area under the ROC curve.

In [19]:
display(predictions_LR)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,workclassIndex,workclassclassVec,educationIndex,educationclassVec,marital_statusIndex,marital_statusclassVec,occupationIndex,occupationclassVec,relationshipIndex,relationshipclassVec,raceIndex,raceclassVec,sexIndex,sexclassVec,native_countryIndex,native_countryclassVec,label,features,rawPrediction,probability,prediction
17.0,?,48703.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,30.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",5.0,"List(0, 15, List(5), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 48703.0, 7.0, 30.0))","List(1, 2, List(), List(8.361640192296251, -8.361640192296251))","List(1, 2, List(), List(0.999766393813901, 2.3360618609902863E-4))",0.0
17.0,?,67808.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",0.0,"List(0, 1, List(0), List(1.0))",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 67808.0, 6.0, 40.0))","List(1, 2, List(), List(7.666454361904844, -7.666454361904844))","List(1, 2, List(), List(0.999531943960885, 4.6805603911499803E-4))",0.0
17.0,?,80077.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",5.0,"List(0, 15, List(5), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 80077.0, 7.0, 20.0))","List(1, 2, List(), List(8.684617598867284, -8.684617598867284))","List(1, 2, List(), List(0.9998308605021862, 1.6913949781383598E-4))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(8.220432560972577, -8.220432560972577))","List(1, 2, List(), List(0.9997309737564378, 2.690262435622275E-4))",0.0
17.0,?,112942.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",0.0,"List(0, 1, List(0), List(1.0))",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 112942.0, 6.0, 40.0))","List(1, 2, List(), List(7.610679099376499, -7.610679099376499))","List(1, 2, List(), List(0.9995051095397631, 4.948904602367223E-4))",0.0
17.0,?,138507.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,20.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",0.0,"List(0, 1, List(0), List(1.0))",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 138507.0, 6.0, 20.0))","List(1, 2, List(), List(8.30258356436842, -8.30258356436842))","List(1, 2, List(), List(0.9997521858305097, 2.4781416949023273E-4))",0.0
17.0,?,139183.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,15.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 139183.0, 6.0, 15.0))","List(1, 2, List(), List(9.0600530678223, -9.0600530678223))","List(1, 2, List(), List(0.9998837966957285, 1.1620330427157354E-4))",0.0
17.0,?,145258.0,11th,7.0,Never-married,?,Other-relative,White,Female,0.0,0.0,25.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",5.0,"List(0, 15, List(5), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",5.0,"List(0, 5, List(), List())",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 13, 24, 36, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 145258.0, 7.0, 25.0))","List(1, 2, List(), List(6.622593610621037, -6.622593610621037))","List(1, 2, List(), List(0.9986717894635513, 0.0013282105364488199))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",2.0,"List(0, 5, List(2), List(1.0))",0.0,"List(0, 4, List(0), List(1.0))",1.0,"List(0, 1, List(), List())",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(8.854983684598722, -8.854983684598722))","List(1, 2, List(), List(0.9998573514069589, 1.4264859304110482E-4))",0.0
17.0,?,161259.0,10th,6.0,Never-married,?,Other-relative,White,Male,0.0,0.0,12.0,United-States,<=50K,3.0,"List(0, 8, List(3), List(1.0))",7.0,"List(0, 15, List(7), List(1.0))",1.0,"List(0, 6, List(1), List(1.0))",7.0,"List(0, 14, List(7), List(1.0))",5.0,"List(0, 5, List(), List())",0.0,"List(0, 4, List(0), List(1.0))",0.0,"List(0, 1, List(0), List(1.0))",0.0,"List(0, 41, List(0), List(1.0))",0.0,"List(0, 100, List(3, 15, 24, 36, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161259.0, 6.0, 12.0))","List(1, 2, List(), List(6.763265088041763, -6.763265088041763))","List(1, 2, List(), List(0.9988458831636885, 0.0011541168363114961))",0.0
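
For the logistic regression model, the `probability` column can be derived from the `rawPrediction` column: applying the logistic (sigmoid) function to each raw margin gives the class probability. The sketch below checks this on the first displayed row. It is a small illustration for binary logistic regression only; the tree-based models define their raw predictions differently.

```python
import math


def sigmoid(margin):
    """Logistic function: maps a raw margin to a probability."""
    return 1.0 / (1.0 + math.exp(-margin))


# rawPrediction of the first displayed row (one margin per class).
raw_prediction = [8.361640192296251, -8.361640192296251]

probability = [sigmoid(m) for m in raw_prediction]
# The predicted class is the one with the highest probability.
predicted = float(probability.index(max(probability)))

print(probability)
print(predicted)
```

The two probabilities sum to one, and the first one is far larger, so the row is assigned label 0.0, in line with the `prediction` column.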


In [20]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Evaluate each model; the evaluator's default metric is the area under the ROC curve
evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction")
print(evaluator.evaluate(predictions_LR))
print(evaluator.evaluate(predictions_DT))
print(evaluator.evaluate(predictions_RF))
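
The area under the ROC curve can be read as the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one. The plain-Python sketch below computes the metric from that definition on four made-up scores. It illustrates what the number means and is not how Spark computes it; the function `area_under_roc` and the toy data are hypothetical.

```python
def area_under_roc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1.0]
    negatives = [s for y, s in zip(labels, scores) if y == 0.0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Made-up labels and scores, only for illustration.
labels = [0.0, 0.0, 1.0, 1.0]
scores = [0.1, 0.4, 0.35, 0.8]
print(area_under_roc(labels, scores))  # 0.75
```

A value of 1.0 means every positive is ranked above every negative, while 0.5 corresponds to random ranking, so the model with the highest value separates the two income classes best.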
