Spark Iteration Time Increasing Exponentially When Using Join

Summary:

Generally speaking, iterative algorithms, especially ones with a self-join or self-union, require control over:

  • Length of the lineage (see, for example, "Stackoverflow due to long RDD Lineage" and "unionAll resulting in StackOverflow").
  • Number of partitions.

The problem described here is the result of a lack of the latter: with each iteration, the self-join increases the number of partitions, leading to an exponential pattern. To address that, you have to either control the number of partitions in each iteration (see below) or use global tools like spark.default.parallelism (see the answer provided by Travis). In general, the first approach provides much more control and doesn't affect other parts of the code.

Original answer:

As far as I can tell, there are two interleaved problems here: a growing number of partitions and shuffling overhead during joins. Both can be handled easily, so let's go step by step.

First, let's create a helper to collect the statistics:

import datetime

def get_stats(i, init, init2, init3, init4,
              start, end, desc, cache, part, hashp):
    # Note: the keys are shifted relative to the RDD names:
    # "init1" holds init2's partition count and "init2" holds init3's.
    return {
        "i": i,
        "init": init.getNumPartitions(),
        "init1": init2.getNumPartitions(),
        "init2": init3.getNumPartitions(),
        "init4": init4.getNumPartitions(),
        "time": str(end - start),
        "timen": (end - start).total_seconds(),
        "desc": desc,
        "cache": cache,
        "part": part,
        "hashp": hashp
    }

Another helper to handle caching/partitioning:

def procRDD(rdd, cache=True, part=False, hashp=False, npart=16):
    # Optionally repartition, hash-partition (pair RDDs only) and cache
    rdd = rdd if not part else rdd.repartition(npart)
    rdd = rdd if not hashp else rdd.partitionBy(npart)
    return rdd if not cache else rdd.cache()

Extract the pipeline logic:

def run(init, description, cache=True, part=False, hashp=False,
        npart=16, n=6):
    times = []

    for i in range(n):
        start = datetime.datetime.now()

        init2 = procRDD(
            init.map(lambda x: (x, x * 3)),
            cache, part, hashp, npart)
        init3 = procRDD(
            init.map(lambda x: (x, x * 2)),
            cache, part, hashp, npart)

        # If part is set to True, limit the number of output partitions
        init4 = init2.join(init3, npart) if part else init2.join(init3)
        init = init4.map(lambda kv: kv[0])

        if cache:
            init4.cache()
            init.cache()

        init.count()  # Force computations to get time
        end = datetime.datetime.now()

        times.append(get_stats(
            i, init, init2, init3, init4,
            start, end, description,
            cache, part, hashp
        ))

    return times

And create the initial data:

ncores = 8
# sc and sqlContext are provided by the PySpark shell
init = sc.parallelize(range(10000), ncores * 2).cache()

If the numPartitions argument is not provided, join by itself sets the number of output partitions based on the number of partitions of the input RDDs. This means the number of partitions grows with each iteration, and if it becomes too large, things get ugly. You can deal with this by providing the numPartitions argument to join or by repartitioning the RDDs in each iteration.
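To see why the counts grow, here is a plain-Python model (not Spark code) of the partition counts produced by the loop in run. It assumes the rule stated at the end of this page: without numPartitions, the join output gets the total number of partitions of its two parents, unless spark.default.parallelism is set, in which case (in PySpark) that value is used instead. The function simulate_partitions is hypothetical and only tracks counts.

```python
def simulate_partitions(initial, iterations, num_partitions=None,
                        default_parallelism=None):
    """Toy model of partition counts in the self-join loop (not Spark code).

    Assumption: without an explicit num_partitions, the join output gets
    default_parallelism if it is set, otherwise the sum of the partitions
    of its two parents.
    Returns a list of (i, init2, init3, init4) partition counts.
    """
    rows = []
    current = initial
    for i in range(iterations):
        # init2 and init3 are map()s over init (optionally repartitioned),
        # so they inherit its partition count unless num_partitions is given
        left = right = num_partitions if num_partitions is not None else current
        if num_partitions is not None:
            joined = num_partitions
        elif default_parallelism is not None:
            joined = default_parallelism
        else:
            joined = left + right
        rows.append((i, left, right, joined))
        # init = init4.map(...) keeps init4's partition count
        current = joined
    return rows


for row in simulate_partitions(16, 6):
    print(row)
```

Starting from 16 partitions, the default case reproduces the doubling init1/init2/init4 columns of the cache-only run; passing num_partitions=16 keeps every count at 16, and default_parallelism=2 mimics the spark-submit check at the end of this page.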

timesCachePart = sqlContext.createDataFrame(
run(init, "cache + partition", True, True, False, ncores * 2))
timesCachePart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+-----------------+
|i|init1|init2|init4|          time|             desc|
+-+-----+-----+-----+--------------+-----------------+
|0|   16|   16|   16|0:00:01.145625|cache + partition|
|1|   16|   16|   16|0:00:01.090468|cache + partition|
|2|   16|   16|   16|0:00:01.059316|cache + partition|
|3|   16|   16|   16|0:00:01.029544|cache + partition|
|4|   16|   16|   16|0:00:01.033493|cache + partition|
|5|   16|   16|   16|0:00:01.007598|cache + partition|
+-+-----+-----+-----+--------------+-----------------+

As you can see, when we repartition, the execution time is more or less constant.
The second problem is that the data above is partitioned randomly. To improve join performance, we would like rows with the same key to be on the same partition. To achieve that, we can use a hash partitioner:

timesCacheHashPart = sqlContext.createDataFrame(
run(init, "cache + hashpart", True, True, True, ncores * 2))
timesCacheHashPart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+----------------+
|i|init1|init2|init4|          time|            desc|
+-+-----+-----+-----+--------------+----------------+
|0|   16|   16|   16|0:00:00.946379|cache + hashpart|
|1|   16|   16|   16|0:00:00.966519|cache + hashpart|
|2|   16|   16|   16|0:00:00.945501|cache + hashpart|
|3|   16|   16|   16|0:00:00.986777|cache + hashpart|
|4|   16|   16|   16|0:00:00.960989|cache + hashpart|
|5|   16|   16|   16|0:00:01.026648|cache + hashpart|
+-+-----+-----+-----+--------------+----------------+

Execution time is constant as before, and there is a small improvement over basic partitioning.
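Why hash partitioning helps can be sketched in plain Python (a model, not Spark's implementation): when both sides are partitioned with the same hash function and the same number of partitions, every key ends up at the same partition index on both sides, so each pair of partitions can be joined locally without moving more data. The helpers hash_partition, local_join and copartitioned_join are hypothetical.

```python
from collections import defaultdict


def hash_partition(pairs, num_partitions):
    """Place each (key, value) pair in partition hash(key) % num_partitions."""
    partitions = [[] for _ in range(num_partitions)]
    for key, value in pairs:
        partitions[hash(key) % num_partitions].append((key, value))
    return partitions


def local_join(left_part, right_part):
    """Inner join of two partitions that hold the same set of keys."""
    index = defaultdict(list)
    for key, value in right_part:
        index[key].append(value)
    return [(key, (value, other))
            for key, value in left_part
            for other in index.get(key, [])]


def copartitioned_join(left, right, num_partitions):
    """Join partition i of left only with partition i of right."""
    left_parts = hash_partition(left, num_partitions)
    right_parts = hash_partition(right, num_partitions)
    result = []
    for lp, rp in zip(left_parts, right_parts):
        result.extend(local_join(lp, rp))
    return result


# Mirrors init2 = (n, n*3) and init3 = (n, n*2) from run()
data = range(100)
init2 = [(n, n * 3) for n in data]
init3 = [(n, n * 2) for n in data]
joined = copartitioned_join(init2, init3, 16)
print(len(joined), sorted(joined)[:2])
```

The partition-by-partition result is the same as a full join; the model leaves out the cost of the initial shuffle that puts rows into their partitions.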

Now let's use caching only, as a reference:

timesCacheOnly = sqlContext.createDataFrame(
run(init, "cache-only", True, False, False, ncores * 2))
timesCacheOnly.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+----------+
|i|init1|init2|init4|          time|      desc|
+-+-----+-----+-----+--------------+----------+
|0|   16|   16|   32|0:00:00.992865|cache-only|
|1|   32|   32|   64|0:00:01.766940|cache-only|
|2|   64|   64|  128|0:00:03.675924|cache-only|
|3|  128|  128|  256|0:00:06.477492|cache-only|
|4|  256|  256|  512|0:00:11.929242|cache-only|
|5|  512|  512| 1024|0:00:23.284508|cache-only|
+-+-----+-----+-----+--------------+----------+

As you can see, in the cache-only version the number of partitions (the init1, init2 and init4 columns, which hold the init2, init3 and init4 RDDs) doubles with each iteration, and the execution time grows roughly in proportion to the number of partitions.
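A quick back-of-the-envelope check of that claim, using only the init4 partition counts and timings (in seconds) from the cache-only table above: the time ratio between consecutive iterations and the time spent per output partition.

```python
# init4 partition counts and timings (seconds) copied from the cache-only table
partitions = [32, 64, 128, 256, 512, 1024]
seconds = [0.992865, 1.766940, 3.675924, 6.477492, 11.929242, 23.284508]

# Ratio of each iteration's time to the previous iteration's time
growth = [later / earlier for earlier, later in zip(seconds, seconds[1:])]
# Milliseconds spent per output partition
ms_per_partition = [1000 * s / p for s, p in zip(seconds, partitions)]

print([round(g, 2) for g in growth])
print([round(m, 1) for m in ms_per_partition])
```

Each iteration takes between about 1.8 and 2.1 times as long as the previous one, while the cost per partition stays between roughly 23 and 31 ms, which is what "roughly proportional" means here.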

Finally, we can check whether a hash partitioner improves performance with a large number of partitions:

timesCacheHashPart512 = sqlContext.createDataFrame(
run(init, "cache + hashpart 512", True, True, True, 512))
timesCacheHashPart512.select(
"i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+--------------------+
|i|init1|init2|init4|          time|                desc|
+-+-----+-----+-----+--------------+--------------------+
|0|  512|  512|  512|0:00:14.492690|cache + hashpart 512|
|1|  512|  512|  512|0:00:20.215408|cache + hashpart 512|
|2|  512|  512|  512|0:00:20.408070|cache + hashpart 512|
|3|  512|  512|  512|0:00:20.390267|cache + hashpart 512|
|4|  512|  512|  512|0:00:20.362354|cache + hashpart 512|
|5|  512|  512|  512|0:00:19.878525|cache + hashpart 512|
+-+-----+-----+-----+--------------+--------------------+

The improvement is not very impressive, but if you have a small cluster and a lot of data, it is still worth trying.

I guess the take-away message here is that partitioning matters. There are contexts where it is handled for you (MLlib, SQL), but if you use low-level operations, it is your responsibility.

Spark join exponentially slow

What's wrong is probably that Spark isn't noticing that you have an easy case of the join problem. When one of the two RDDs you're joining is small enough, you're better off with it not being an RDD at all. Then you can roll your own implementation of a hash join, which is actually a lot simpler than it sounds. Basically, you need to:

  • Pull your category list out of the RDD using collect(); the resulting communication will easily pay for itself (or, if possible, don't make it an RDD in the first place)
  • Turn it into a hash table with one entry containing all the values for one key (assuming your keys are not unique)
  • For each pair in your large RDD, look the key up in the hash table and produce one pair for each value in the list (if the key is not found, that pair doesn't produce any results)
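The three steps above can be sketched in plain Python. Here small_pairs stands for the collected category list and large_pairs for the contents of the large RDD; both names and the sample data are hypothetical. The dict plays the role of the hash table, and join_one is the per-pair lookup you would apply with flatMap.

```python
from collections import defaultdict


def build_hash_table(small_pairs):
    """Step 2: one entry per key, holding all values for that key."""
    table = defaultdict(list)
    for key, value in small_pairs:
        table[key].append(value)
    return dict(table)


def join_one(pair, table):
    """Step 3: emit one output pair per matching value (none if key missing)."""
    key, value = pair
    return [(key, (value, other)) for other in table.get(key, [])]


# Step 1 stand-in: in Spark this list would come from small_rdd.collect()
small_pairs = [("fruit", "apple"), ("fruit", "pear"), ("veg", "leek")]
large_pairs = [("fruit", 1), ("veg", 2), ("meat", 3)]

table = build_hash_table(small_pairs)
joined = [out for pair in large_pairs for out in join_one(pair, table)]
print(joined)
```

In PySpark, the same lookup would typically run as large_rdd.flatMap(lambda p: join_one(p, table)), ideally with the table shipped once per executor via sc.broadcast(table) and read through the broadcast's .value attribute.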

I have an implementation in Scala; feel free to ask questions about translating it, but it should be pretty easy.

Another interesting possibility is to try using Spark SQL. I'm pretty sure the project's long-term ambitions include doing this for you automatically, but I don't know if they've achieved it yet.

How do I decrease iteration time when making data transformations?

Refactoring

For experimentation / fast iteration, it's often a good idea to refactor your code into several smaller steps instead of a single large step.

This way, you compute the upstream cells first, write the data back to Foundry, and use this pre-computed data in later steps. If you keep re-computing these early steps without changing their logic, you are doing nothing but extra work again and again.

Concretely:


from pyspark.sql import functions as F

# Each function below is a separate transform; the comments above each
# one give its output and input dataset paths.

# output = /my/function/output_max
# input_df = /my/function/input
def my_compute_function(input_df):
    """Compute the max by key

    Keyword arguments:
    input_df (pyspark.sql.DataFrame) : input DataFrame

    Returns:
    pyspark.sql.DataFrame
    """

    max_df = input_df \
        .groupBy("key") \
        .agg(F.max(F.col("val")).alias("max_val"))

    return max_df

# output = /my/function/output_joined
# input_df = /my/function/input
# max_df = /my/function/output_max
def my_compute_function(max_df, input_df):
    """Compute the joined output of max and input

    Keyword arguments:
    max_df (pyspark.sql.DataFrame) : input DataFrame
    input_df (pyspark.sql.DataFrame) : input DataFrame

    Returns:
    pyspark.sql.DataFrame
    """

    joined_df = input_df \
        .join(max_df, "key")

    return joined_df

# output = /my/function/output_diff
# joined_df = /my/function/output_joined
def my_compute_function(joined_df):
    """Compute difference from maximum of a value column by key

    Keyword arguments:
    joined_df (pyspark.sql.DataFrame) : input DataFrame

    Returns:
    pyspark.sql.DataFrame
    """

    diff_df = joined_df \
        .withColumn("diff", F.col("max_val") - F.col("val"))

    return diff_df

The work you perform would instead look like:

pipeline_2:
  transform_A:
    work_1: input -> max_df
    (takes 4 iterations to get right): 4 * max_df
  transform_B:
    work_2: max_df -> joined_df
    (takes 4 iterations to get right): 4 * joined_df
  transform_C:
    work_3: joined_df -> diff_df
    (takes 4 iterations to get right): 4 * diff_df
  total_work:
    transform_A + transform_B + transform_C
    = work_1 + work_2 + work_3
    = 4 * max_df + 4 * joined_df + 4 * diff_df

If you assume max_df, joined_df, and diff_df all cost the same amount to compute, pipeline_1.total_work = 24 * max_df, whereas pipeline_2.total_work = 12 * max_df, so you can expect something on the order of a 2x speed improvement per iteration.
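The arithmetic can be checked with a small model. The pipeline_1 layout is not shown here, so the sketch assumes, consistently with the 24 * max_df figure, that pipeline_1 computes all three steps in one transform, so each iteration on a step also recomputes every step upstream of it. Every step costs one unit, as in the assumption above; total_work is a hypothetical helper.

```python
def total_work(iterations_per_step, step_costs, recompute_upstream):
    """Total cost of developing a chain of steps, in units of one step.

    recompute_upstream=True models one big transform (pipeline_1, assumed):
    each iteration on step k re-runs steps 1..k.
    recompute_upstream=False models separate materialized transforms
    (pipeline_2): each iteration on step k only runs step k.
    """
    total = 0
    for k, cost in enumerate(step_costs):
        per_iteration = sum(step_costs[:k + 1]) if recompute_upstream else cost
        total += iterations_per_step * per_iteration
    return total


costs = [1, 1, 1]  # max_df, joined_df, diff_df assumed to cost the same
pipeline_1 = total_work(4, costs, recompute_upstream=True)
pipeline_2 = total_work(4, costs, recompute_upstream=False)
print(pipeline_1, pipeline_2, pipeline_1 / pipeline_2)
```

Under these assumptions, pipeline_1 costs 4 + 8 + 12 = 24 units and pipeline_2 costs 12, hence the 2x estimate.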

Caching

You should cache any 'small' datasets. This keeps the rows in memory for your pipeline instead of fetching them from the written-back dataset. What counts as 'small' is somewhat arbitrary, since many factors must be considered, but Spark does a good job of trying to cache the data regardless and warning you if it's too big.

In this case, you could cache the intermediate max_df and joined_df layers, depending on your setup and on which step you are developing.

Function Calls

You should stick to native PySpark methods as much as possible and never use Python methods directly on the data (e.g. looping over individual rows, executing a UDF). PySpark methods call the underlying Spark methods, which are written in Scala and run directly against the data rather than in the Python runtime. If you use Python only as the layer that drives this system, rather than as the system that touches the data, you get all the performance benefits of Spark itself.

In the above example, only native PySpark methods are called, so this computation will be quite fast.

Downsampling

If you can derive an accurate sample of a large input dataset, you can use it as mock input for your transformations until you have perfected your logic and want to test it against the full set.

In the above case, we could downsample input_df to a single key before executing any of the steps.

I personally downsample and cache datasets above 1M rows before writing a single line of PySpark code; that way my turnaround times are very fast, and I never wait on a large dataset just to discover a syntax bug.
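As a rough illustration of developing against a downsampled input, the sketch below runs a plain-Python version of the three transforms (max by key, join back, difference from the max) on a sample restricted to a single key. The rows, the key values and the helper names are made up for illustration; this is not Foundry or PySpark code.

```python
# Hypothetical input rows: (key, val)
input_rows = [("a", 3), ("a", 7), ("b", 5), ("a", 4), ("b", 9)]

# Downsample to a single key before running any step
sample = [row for row in input_rows if row[0] == "a"]


def max_by_key(rows):
    """Step 1: max of val per key (like groupBy("key").agg(max))."""
    result = {}
    for key, val in rows:
        result[key] = max(val, result.get(key, val))
    return result


def join_with_max(rows, max_by):
    """Step 2: attach max_val to each row (inner join on key)."""
    return [(key, val, max_by[key]) for key, val in rows if key in max_by]


def diff_from_max(joined):
    """Step 3: diff = max_val - val."""
    return [(key, val, max_val, max_val - val) for key, val, max_val in joined]


result = diff_from_max(join_with_max(sample, max_by_key(sample)))
print(result)
```

In PySpark, the downsampling step itself would be a filter such as input_df.filter(F.col("key") == some_key), where some_key is a placeholder for a key you pick.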

All Together

A good development pipeline looks like:

  • Discrete chunks of code that do particular materializations you expect to re-use later but don't need to be recomputed over and over again
  • Downsampled to 'small' sizes
  • Cached 'small' datasets for very fast fetching
  • PySpark native code only that exploits the fast underlying Spark libraries

Evaluating Spark DataFrame in loop slows down with every iteration, all work done by controller

A simple reproduction scenario:

import time
from pyspark import SparkContext

sc = SparkContext()

def push_and_pop(rdd):
    # two transformations: moves the head element to the tail
    first = rdd.first()
    return rdd.filter(
        lambda obj: obj != first
    ).union(
        sc.parallelize([first])
    )

def serialize_and_deserialize(rdd):
    # perform a collect() action to evaluate the rdd and create a new instance
    return sc.parallelize(rdd.collect())

def do_test(serialize=False):
    rdd = sc.parallelize(range(1000))
    for i in range(25):
        t0 = time.time()
        rdd = push_and_pop(rdd)
        if serialize:
            rdd = serialize_and_deserialize(rdd)
        print("%.3f" % (time.time() - t0))

do_test()

This shows a major slowdown over the 25 iterations:

0.597
0.117
0.186
0.234
0.288
0.309
0.386
0.439
0.507
0.529
0.553
0.586
0.710
0.728
0.779
0.896
0.866
0.881
0.956
1.049
1.069
1.061
1.149
1.189
1.201

(The first iteration is relatively slow because of initialization effects, the second iteration is quick, and every subsequent iteration is slower.)

The cause seems to be the growing chain of lazy transformations (the lineage). We can test this hypothesis by collapsing the RDD with an action:

do_test(True)

0.897
0.256
0.233
0.229
0.220
0.238
0.234
0.252
0.240
0.267
0.260
0.250
0.244
0.266
0.295
0.464
0.292
0.348
0.320
0.258
0.250
0.201
0.197
0.243
0.230

The collect()/parallelize() round trip adds about 0.1 seconds to each iteration but completely eliminates the incremental slowdown.
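The effect can be reproduced without Spark using a toy model of lazy evaluation (this is not how Spark is implemented internally): each transformation only appends a step to the lineage, and every action replays the whole lineage from the source data. ToyRDD and union_list are hypothetical names; the lineage length stands in for the work each first() call has to do.

```python
class ToyRDD:
    """Toy lazy dataset: transformations only record steps, actions replay them."""

    def __init__(self, data, lineage=()):
        self.data = list(data)         # materialized source data
        self.lineage = tuple(lineage)  # recorded but not yet applied steps

    def filter(self, pred):
        return ToyRDD(self.data, self.lineage + (("filter", pred),))

    def union_list(self, items):
        # Stands in for rdd.union(sc.parallelize(items))
        return ToyRDD(self.data, self.lineage + (("union", list(items)),))

    def collect(self):
        # Action: replay the whole lineage from the source data every time
        result = list(self.data)
        for kind, arg in self.lineage:
            if kind == "filter":
                result = [x for x in result if arg(x)]
            else:
                result = result + arg
        return result

    def first(self):
        return self.collect()[0]


def push_and_pop(rdd):
    # Same shape as the original: one action, then two new transformations
    first = rdd.first()
    return rdd.filter(lambda obj: obj != first).union_list([first])


def do_test(serialize=False, iterations=25):
    """Return the final dataset and the lineage length after each iteration."""
    rdd = ToyRDD(range(1000))
    lineage_lengths = []
    for _ in range(iterations):
        rdd = push_and_pop(rdd)
        if serialize:
            # collect() + parallelize(): materialize and drop the lineage
            rdd = ToyRDD(rdd.collect())
        lineage_lengths.append(len(rdd.lineage))
    return rdd, lineage_lengths


final, lengths = do_test()
final_s, lengths_s = do_test(serialize=True)
print(lengths[:3], lengths[-1])      # [2, 4, 6] 50
print(lengths_s[:3], lengths_s[-1])  # [0, 0, 0] 0
```

Without materialization, each first() has to replay a lineage that grows by two steps per iteration, which fits the roughly linear growth in per-iteration time above; materializing resets it at the price of one full pass over the data.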

Spark join - (edges and vertices)

I will use Java, but I guess it is straightforward to convert it to Scala.

Assuming edgeRDD has type JavaPairRDD<String,String> and vertexRDD has type JavaPairRDD<String,Long>:

  1. edgeRDD.join(vertexRDD) will yield JavaPairRDD<String,Tuple2<String,Long>> with the following content (let's call it join1):

    (V1, Tuple2(V2,1L)) 
    (V2, Tuple2(V3,2L))
    (V1, Tuple2(V4,1L))
  2. Then you convert join1 into another JavaPairRDD<String,Tuple2<String,Long>> by restructuring the keys and values using map (let's call it join2):

    (V2, Tuple2(V1,1L)) 
    (V3, Tuple2(V2,2L))
    (V4, Tuple2(V1,1L))
  3. Finally perform vertexRDD.join(join2) to get JavaPairRDD<String,Tuple2<Long,Tuple2<String,Long>>> with contents:

    (V2, Tuple2(2L, Tuple2(V1,1L)))
    (V3, Tuple2(3L, Tuple2(V2,2L)))
    (V4, Tuple2(4L, Tuple2(V1,1L)))

You can then pass this result through map to create a JavaRDD<String> (or a new JavaPairRDD) by combining keys and values appropriately. I will leave the mapping phases up to you.
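For readers following along in Python, the three steps can be modelled with plain lists of pairs standing in for the JavaPairRDDs. pair_join is a hypothetical helper that mimics an inner join on the key, and the edge and vertex lists are reconstructed from the sample output above.

```python
from collections import defaultdict


def pair_join(left, right):
    """Inner join of two lists of (key, value) pairs, like RDD.join."""
    index = defaultdict(list)
    for key, value in right:
        index[key].append(value)
    return [(key, (lv, rv)) for key, lv in left for rv in index.get(key, [])]


edges = [("V1", "V2"), ("V2", "V3"), ("V1", "V4")]       # (src, dst)
vertices = [("V1", 1), ("V2", 2), ("V3", 3), ("V4", 4)]  # (name, id)

# 1. Attach the source vertex id to each edge: (src, (dst, src_id))
join1 = pair_join(edges, vertices)
# 2. Re-key by destination: (dst, (src, src_id))
join2 = [(dst, (src, src_id)) for src, (dst, src_id) in join1]
# 3. Attach the destination vertex id: (dst, (dst_id, (src, src_id)))
join3 = pair_join(vertices, join2)

for row in join3:
    print(row)
```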

Spark example program runs very slow

There can be many reasons why this code doesn't perform particularly well on your machine, but most likely this is just another variant of the problem described in "Spark iteration time increasing exponentially when using join". The simplest way to check whether this is indeed the case is to provide the spark.default.parallelism parameter on submit:

bin/spark-submit --conf spark.default.parallelism=2 \
examples/src/main/python/transitive_closure.py

If not limited otherwise, SparkContext.union, RDD.join and RDD.union set the number of partitions of the child to the total number of partitions of the parents. Usually this is the desired behavior, but it can become extremely inefficient when applied iteratively.


