API Serving

Use MMLSpark

Load in required libraries

```python
from pyspark.ml.tuning import CrossValidatorModel
from pyspark.ml import PipelineModel
from pyspark.sql.types import *
from pyspark.sql.functions import col, round, from_json

import sys
import uuid

import numpy as np
import pandas as pd

import mmlspark
from mmlspark import request_to_string, string_to_response
```

Load in the transformation pipeline and trained model

```python
## Load in the transformation pipeline
mypipeline = PipelineModel.load("/mnt/trainedmodels/pipeline/")

## Load in the trained model
mymodel = CrossValidatorModel.load("/mnt/trainedmodels/lr")
```

Define username, key, and IP address

```python
username = "admin"
ip = "10.0.0.4"  # Internal IP of the VM
sas_url = ""  # SAS token for your VM's private key in Blob storage
```

Define input schema

```python
input_schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("x1", IntegerType(), True),
    StructField("x2", DoubleType(), True),
    StructField("x3", StringType(), True),
])
```
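Incoming requests are parsed against this schema, so it can help to sanity-check a payload before posting it. A minimal sketch in plain Python (no Spark needed); `expected` and `check_payload` are illustrative helpers, not part of MMLSpark:

```python
import json

# Fields and expected Python types, mirroring input_schema above
expected = {"id": int, "x1": int, "x2": float, "x3": str}

def check_payload(payload):
    """Return True if the JSON payload matches the serving schema."""
    row = json.loads(payload)
    return all(isinstance(row.get(name), t) for name, t in expected.items())

print(check_payload('{"id":0,"x1":1,"x2":2.0,"x3":"3"}'))  # True
```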

Set up streaming DataFrame

```python
serving_inputs = spark.readStream.continuousServer() \
    .option("numPartitions", 1) \
    .option("name", "http://10.0.0.4:8898/my_api") \
    .option("forwarding.enabled", True) \
    .option("forwarding.username", username) \
    .option("forwarding.sshHost", ip) \
    .option("forwarding.keySas", sas_url) \
    .address("localhost", 8898, "my_api") \
    .load() \
    .parseRequest(input_schema)

mydataset = mypipeline.transform(serving_inputs)

serving_outputs = mymodel.bestModel.transform(mydataset) \
    .makeReply("prediction")

# display(serving_inputs)
```

Set up server

```python
server = serving_outputs.writeStream \
    .continuousServer() \
    .trigger(continuous="1 second") \
    .replyTo("my_api") \
    .queryName("my_query") \
    .option("checkpointLocation", "file:///tmp/checkpoints-{}".format(uuid.uuid1())) \
    .start()
```

Test the webservice

```python
import requests

data = u'{"id":0,"x1":1,"x2":2.0,"x3":"3"}'

# r = requests.post(data=data, url="http://localhost:8898/my_api")  # Locally
r = requests.post(data=data, url="http://102.208.216.32:8902/my_api")  # Via the VM IP

print("Response {}".format(r.text))
```
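Hand-writing JSON strings is error-prone; building the payload with `json.dumps` keeps it in sync with the schema. A small sketch (the endpoint URL is the same local one as above; the post is left commented out since it requires a live server):

```python
import json

row = {"id": 0, "x1": 1, "x2": 2.0, "x3": "3"}
data = json.dumps(row)  # produces the same shape as the hand-written string above

# import requests
# r = requests.post(data=data, url="http://localhost:8898/my_api")
print(data)
```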
If you're running inside Databricks, you may need to run `sudo netstat -tulpn` to see which ports are open, and look for the port that was opened by the server.

Resources:

Microsoft MMLSpark on GitHub: https://github.com/Azure/mmlspark