Deep Learning with Spark in Deep Java Library in 10 minutes

Qing Lan
Towards Data Science
6 min read · Jun 10, 2020


Introduction

Apache Spark is a widely used technology for data processing and is used heavily by machine learning practitioners. Spark can be used to classify products, forecast demand, and personalize recommendations. While Spark supports a variety of programming languages, its preferred SDK is implemented in Scala, which most deep learning frameworks do not support well. Most machine learning frameworks ship Python SDKs, leaving Spark developers with two suboptimal options: port their code to Python or implement a custom Scala wrapper. Both options slow developer velocity and threaten production environments with brittle code.

In this blog, we demonstrate how users can execute deep learning workloads directly from Scala using the Deep Java Library (DJL). DJL is a framework-agnostic library developed to provide deep learning directly in Spark jobs developed with Java. In the following tutorial, we will walk through an image classification scenario using MXNet, though PyTorch and TensorFlow are also supported. For the full code, see the DJL Spark Image Classification Example.

Example: Image Classification with DJL and Spark

In this tutorial, we use ResNet-50, a pre-trained model, to run inference. For the tutorial, we will use a single cluster with three worker nodes for classification. The workflow is shown in the following diagram:

Sample Image Classification workflow on Spark with 3 Worker nodes

Our example will create several Executors in the process and assign tasks to each of them. Each Executor contains one or more cores that execute tasks in different threads. This provides each worker node with a balanced workload for big data processing.

Step 1: Creating the Spark project

We use sbt, a popular open source tool, to build this Spark project in Scala. You can find more resources on how to get started with sbt here. We define our project in sbt using the following code block:

name := "sparkExample"
version := "0.1"
scalaVersion := "2.11.12"
scalacOptions += "-target:jvm-1.8"
resolvers += Resolver.mavenLocal
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.3.0"
libraryDependencies += "ai.djl" % "api" % "0.5.0"
libraryDependencies += "ai.djl" % "repository" % "0.5.0"
// Using MXNet Engine
libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.5.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.6.0"

This tutorial uses MXNet as its underlying engine. Switching to another framework is trivial, as shown in the example below:

// Using PyTorch Engine
libraryDependencies += "ai.djl.pytorch" % "pytorch-model-zoo" % "0.5.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-auto" % "1.5.0"

Step 2: Configuring Spark

In this tutorial, we run this example on the local machine. The Spark application will use the following configuration:

// Spark configuration
val conf = new SparkConf()
.setAppName("Simple Image Classification")
.setMaster("local[*]")
.setExecutorEnv("MXNET_ENGINE_TYPE", "NaiveEngine")
val sc = new SparkContext(conf)

The NaiveEngine argument is required for multi-threaded inference in MXNet. If using PyTorch or TensorFlow, the following line can be removed:

.setExecutorEnv("MXNET_ENGINE_TYPE", "NaiveEngine")

Step 3: Identify the input data

In this tutorial, the input data is represented as a folder containing the images to classify. Spark will load these binary files and divide them into partitions; each partition is processed by one Executor. The following statement distributes the images in the folder evenly across the partitions.

val partitions = sc.binaryFiles("images/*")
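If the default partitioning does not match your cluster, `binaryFiles` also accepts a `minPartitions` hint. A minimal sketch, assuming the three worker nodes from this example (the value 3 is an assumption, not from the original code, and Spark treats it as a hint rather than a hard guarantee):

```scala
// Request at least 3 partitions so the images are spread across
// all three worker nodes in this example
val partitions = sc.binaryFiles("images/*", minPartitions = 3)
```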

Step 4: Define the Spark job

Next, we create the execution graph for this job using the partitions created in the previous step. In Spark, each Executor executes tasks in a multi-threaded fashion. As a result, we need to load the model in each Executor before running inference. We set this up with the following code:

// Start assigning work to each worker node
val result = partitions.mapPartitions( partition => {
  // Locate the model and create a predictor before classification
  val criteria = Criteria.builder
    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
    .setTypes(classOf[BufferedImage], classOf[Classifications])
    .optFilter("dataset", "imagenet")
    .optFilter("layers", "50")
    .optProgress(new ProgressBar)
    .build
  val model = ModelZoo.loadModel(criteria)
  val predictor = model.newPredictor()
  // Run classification on each image in the partition
  partition.map(streamData => {
    val img = ImageIO.read(streamData._2.open())
    predictor.predict(img).toString
  })
})

The ModelZoo criteria must be specified in each partition so that each Executor can locate the corresponding model and create a predictor. During classification, we load the images from the RDD and run inference on them.

This model was trained with the ImageNet dataset and stored in DJL ModelZoo.
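Note that DJL's `Model` and `Predictor` hold native resources and implement `AutoCloseable`. Because `mapPartitions` returns a lazy iterator, the job above never closes them explicitly. A hedged variant (a sketch, not the code from the original example) materializes each partition eagerly so the resources can be released before the iterator is returned:

```scala
val resultWithCleanup = partitions.mapPartitions( partition => {
  val criteria = Criteria.builder
    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
    .setTypes(classOf[BufferedImage], classOf[Classifications])
    .optFilter("dataset", "imagenet")
    .optFilter("layers", "50")
    .build
  val model = ModelZoo.loadModel(criteria)
  val predictor = model.newPredictor()
  // Materialize the results eagerly so the model can be closed
  // before the iterator leaves this function
  val classified = partition.map(streamData => {
    val img = ImageIO.read(streamData._2.open())
    predictor.predict(img).toString
  }).toList
  predictor.close()
  model.close()
  classified.iterator
})
```

The trade-off is that the whole partition's results are held in memory at once, which is acceptable for modest partition sizes.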

Step 5: Define the output location

After the map procedure finishes, the master node will collect, aggregate, and save the results to the file system.

result.collect().foreach(print)
result.saveAsTextFile("output")

Running this code prints the predicted classes to the console, and the output files are saved to the output folder, one file per partition. For the full code of this tutorial, please see the Scala example.

The expected output from the console:

[
class: "n02085936 Maltese dog, Maltese terrier, Maltese", probability: 0.81445
class: "n02096437 Dandie Dinmont, Dandie Dinmont terrier", probability: 0.08678
class: "n02098286 West Highland white terrier", probability: 0.03561
class: "n02113624 toy poodle", probability: 0.01261
class: "n02113712 miniature poodle", probability: 0.01200
][
class: "n02123045 tabby, tabby cat", probability: 0.52391
class: "n02123394 Persian cat", probability: 0.24143
class: "n02123159 tiger cat", probability: 0.05892
class: "n02124075 Egyptian cat", probability: 0.04563
class: "n03942813 ping-pong ball", probability: 0.01164
][
class: "n03770679 minivan", probability: 0.95839
class: "n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", probability: 0.01674
class: "n03769881 minibus", probability: 0.00610
class: "n03594945 jeep, landrover", probability: 0.00448
class: "n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", probability: 0.00278
]

In your production system

In this approach, we used an RDD to execute our jobs on images for demo purposes. With the growing adoption of DataFrames and the cache-memory savings they offer, production users should consider defining a schema for these images and storing them in the DataFrame format. Starting with Spark 3.0, Spark provides a binary file data source, making the conversion of images to a DataFrame even more convenient.
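A minimal sketch of the DataFrame-based approach, assuming Spark 3.0+ (the app name and glob pattern are illustrative assumptions):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("Image Classification with DataFrames")
  .master("local[*]")
  .getOrCreate()

// The "binaryFile" data source (Spark 3.0+) yields the columns:
// path, modificationTime, length, content
val imagesDf = spark.read.format("binaryFile")
  .option("pathGlobFilter", "*.jpg") // only pick up JPEG files
  .load("images/")
```

From here, the image bytes in the `content` column can be fed to a DJL predictor inside a mapping function, much like the RDD version above.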

Case study

Amazon Retail System (ARS) uses DJL to make millions of predictions with large streams of data routed through Spark. These predictions determine a customer’s propensity to take an action across multiple categories using thousands of customer attributes, then render relevant ads and banners to the customer.

ARS uses hundreds of thousands of features with hundreds of millions of customers — more than ten trillion data points. They needed a solution that could scale effectively. To solve this key issue, they originally created a Scala wrapper for their jobs, but the wrapper had memory issues and executed slowly. After adopting DJL, their solution worked perfectly with Spark, and the total inference time dropped from days to hours.

Stay tuned for our next blog post, where we will dive deeper into the challenges ARS faced and how they solved them by building a pipeline with DJL.

Learn more about DJL

After completing this tutorial, you may be wondering what DJL is. Deep Java Library (DJL) is an open source library to build and deploy deep learning in Java. This project launched in December 2019, and is used by engineering teams across Amazon. This effort was inspired by other DL frameworks, but was developed from the ground up to better suit Java development practices. DJL is framework agnostic, with support for Apache MXNet, PyTorch, TensorFlow 2.x (Experimental) and fastText (Experimental).

To learn more, check out our website, GitHub repository, and Slack channel.
