How to Find Median and Quantiles Using Spark

Ongoing work

SPARK-30569 - Add DSL functions invoking percentile_approx

Spark 2.0+:

You can use the approxQuantile method, which implements the Greenwald-Khanna algorithm:

Python:

df.approxQuantile("x", [0.5], 0.25)

Scala:

df.stat.approxQuantile("x", Array(0.5), 0.25)

where the last parameter is the relative error. The lower the number, the more accurate the results and the more expensive the computation.
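
For example (a minimal sketch, reusing the hypothetical DataFrame df with a numeric column "x" from above), a relative error of 0 requests exact quantiles at a higher cost:

exact_median = df.approxQuantile("x", [0.5], 0.0)[0]   # exact, potentially expensive
approx_median = df.approxQuantile("x", [0.5], 0.1)[0]  # approximate, cheaper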

Since Spark 2.2 (SPARK-14352) it supports estimation on multiple columns:

Python:

df.approxQuantile(["x", "y", "z"], [0.5], 0.25)

Scala:

df.stat.approxQuantile(Array("x", "y", "z"), Array(0.5), 0.25)

The underlying method can also be used in SQL aggregation (both global and grouped) using the approx_percentile function:

> SELECT approx_percentile(10.0, array(0.5, 0.4, 0.1), 100);
[10.0,10.0,10.0]
> SELECT approx_percentile(10.0, 0.5, 100);
10.0
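
The grouped form works the same way; here is a minimal PySpark sketch, assuming a hypothetical view df with columns grp and val:

spark.sql("""
    SELECT grp, approx_percentile(val, 0.5, 100) AS median
    FROM df
    GROUP BY grp
""").show()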

Spark < 2.0

Python

As I've mentioned in the comments, it is most likely not worth all the fuss. If the data is relatively small, as in your case, then simply collect and compute the median locally:

import numpy as np

np.random.seed(323)
rdd = sc.parallelize(np.random.randint(1000000, size=700000))

%time np.median(rdd.collect())
np.array(rdd.collect()).nbytes

It takes around 0.01 seconds on my few-years-old computer and uses around 5.5 MB of memory.

If the data is much larger, sorting will be a limiting factor, so instead of getting an exact value it is probably better to sample, collect, and compute locally. But if you really want to use Spark, something like this should do the trick (if I didn't mess up anything):

from numpy import floor
import numpy as np
import time

def quantile(rdd, p, sample=None, seed=None):
    """Compute a quantile of order p ∈ [0, 1].

    :param rdd: a numeric RDD
    :param p: quantile (between 0 and 1)
    :param sample: fraction of the RDD to use; if not provided the whole dataset is used
    :param seed: random number generator seed to be used with sample
    """
    assert 0 <= p <= 1
    assert sample is None or 0 < sample <= 1

    seed = seed if seed is not None else time.time()
    rdd = rdd if sample is None else rdd.sample(False, sample, seed)

    rddSortedWithIndex = (rdd
        .sortBy(lambda x: x)
        .zipWithIndex()
        .map(lambda xi: (xi[1], xi[0]))  # (index, value) pairs for lookup
        .cache())

    n = rddSortedWithIndex.count()
    h = (n - 1) * p

    # values at the two indices surrounding the (possibly fractional) rank h
    rddX, rddXPlusOne = (
        rddSortedWithIndex.lookup(x)[0]
        for x in int(floor(h)) + np.array([0, 1]))

    # linear interpolation between the two neighbouring order statistics
    return rddX + (h - floor(h)) * (rddXPlusOne - rddX)

And some tests:

np.median(rdd.collect()), quantile(rdd, 0.5)
## (500184.5, 500184.5)
np.percentile(rdd.collect(), 25), quantile(rdd, 0.25)
## (250506.75, 250506.75)
np.percentile(rdd.collect(), 75), quantile(rdd, 0.75)
## (750069.25, 750069.25)

Finally, let's define the median:

from functools import partial
median = partial(quantile, p=0.5)
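
Usage is then the same as quantile(rdd, 0.5) from the tests above:

median(rdd)
## 500184.5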

So far so good, but it takes 4.66 s in local mode without any network communication. There is probably a way to improve this, but why bother?

Language-independent (Hive UDAF):

If you use HiveContext you can also use Hive UDAFs. With continuous (floating-point) values, use percentile_approx:

rdd.map(lambda x: (float(x), )).toDF(["x"]).registerTempTable("df")

sqlContext.sql("SELECT percentile_approx(x, 0.5) FROM df")

With integral values, the exact percentile UDAF can be used instead:

sqlContext.sql("SELECT percentile(x, 0.5) FROM df")

In percentile_approx you can pass an additional argument which controls the accuracy of the approximation (higher values are more accurate but use more memory).
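
For example, a sketch reusing the temporary table registered above:

sqlContext.sql("SELECT percentile_approx(x, 0.5, 10000) FROM df")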

Median / quantiles within PySpark groupBy

I guess you don't need it anymore, but I will leave it here for future generations (i.e. me next week when I forget).

from pyspark.sql import Window
import pyspark.sql.functions as F

grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')

df.withColumn('med_val', magic_percentile.over(grp_window))

Or to address exactly your question, this also works:

df.groupBy('grp').agg(magic_percentile.alias('med_val'))

And as a bonus, you can pass an array of percentiles:

quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))')

And you'll get a list in return.
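
For example, attaching all three quartiles per group in a single aggregation (a sketch, reusing the same grp and val columns):

df.groupBy('grp').agg(quantiles.alias('quartiles'))
# each row now holds an array of [p25, p50, p75] for its group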

Compute median of column in pyspark

You need to add a column with withColumn because approxQuantile returns a list of floats, not a Spark column.

import pyspark.sql.functions as F

df2 = df.withColumn('count_media', F.lit(df.approxQuantile('count',[0.5],0.1)[0]))

df2.show()
+-----------+-----+-----------+
|parsed_date|count|count_media|
+-----------+-----+-----------+
| 2017-12-16|    2|        2.0|
| 2017-12-16|    2|        2.0|
| 2017-12-17|    2|        2.0|
| 2017-12-17|    2|        2.0|
| 2017-12-18|    1|        2.0|
| 2017-12-19|    4|        2.0|
| 2017-12-19|    4|        2.0|
| 2017-12-19|    4|        2.0|
| 2017-12-19|    4|        2.0|
| 2017-12-20|    1|        2.0|
+-----------+-----+-----------+

You can also use the approx_percentile / percentile_approx function in Spark SQL:

import pyspark.sql.functions as F

df2 = df.withColumn('count_media', F.expr("approx_percentile(count, 0.5, 10) over ()"))

df2.show()
+-----------+-----+-----------+
|parsed_date|count|count_media|
+-----------+-----+-----------+
| 2017-12-16|    2|          2|
| 2017-12-16|    2|          2|
| 2017-12-17|    2|          2|
| 2017-12-17|    2|          2|
| 2017-12-18|    1|          2|
| 2017-12-19|    4|          2|
| 2017-12-19|    4|          2|
| 2017-12-19|    4|          2|
| 2017-12-19|    4|          2|
| 2017-12-20|    1|          2|
+-----------+-----+-----------+
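
If you want a per-group median instead of a global one, the same expression should also accept a partitioned window (a sketch, assuming you want the median per parsed_date):

df3 = df.withColumn(
    'count_median_per_date',
    F.expr("approx_percentile(count, 0.5, 10) over (partition by parsed_date)"))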

Median and quantile values in Pyspark

The first improvement would be to do all the quantile calculations at the same time:

quantiles = df.approxQuantile("age", [0.25, 0.5, 0.75], 0)

Also, note that you are using the exact calculation of the quantiles. From the documentation we can see that (emphasis added by me):

relativeError – The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.

Since you have a very large dataframe, I expect that some error is acceptable in these calculations; it is a trade-off between speed and precision (anything greater than 0 can give a significant speed improvement).
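
For example, a sketch allowing 1% relative error on the same "age" column:

quantiles = df.approxQuantile("age", [0.25, 0.5, 0.75], 0.01)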

How to find the median in Apache Spark with Python Dataframe API?

Here is an example implementation with the DataFrame API in Python (Spark 1.6+).

import pyspark.sql.functions as F
import numpy as np
from pyspark.sql.types import FloatType

Let's assume we have monthly salaries for customers in a "salaries" Spark DataFrame such as:

month | customer_id | salary

and we would like to find the median salary per customer across all the months.

Step 1: Write a user-defined function to calculate the median

def find_median(values_list):
    try:
        # get the median of the values in the list in each row
        median = np.median(values_list)
        return round(float(median), 2)
    except Exception:
        # return None if there is anything wrong with the given values
        return None

median_finder = F.udf(find_median, FloatType())

Step 2: Aggregate on the salary column by collecting the salaries into a list per customer:

salaries_list = salaries.groupBy("customer_id").agg(F.collect_list("salary").alias("salaries"))

Step 3: Call the median_finder UDF on the salaries column and add the median values as a new column:

salaries_list = salaries_list.withColumn("median",median_finder("salaries")) 
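
As a quick end-to-end sanity check, a sketch on a tiny made-up DataFrame (the values below are illustrative assumptions, not from the original question, and assume a Spark 2.x SparkSession named spark):

salaries = spark.createDataFrame(
    [(1, 101, 1000.0), (2, 101, 2000.0), (3, 101, 3000.0), (1, 102, 1500.0)],
    ["month", "customer_id", "salary"])
salaries_list = salaries.groupBy("customer_id").agg(F.collect_list("salary").alias("salaries"))
salaries_list.withColumn("median", median_finder("salaries")).show()
# expected medians: 2000.0 for customer 101 and 1500.0 for customer 102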

How to find exact median for grouped data in Spark

Simplest approach (requires Spark 2.0.1+; not an exact median)

As noted in the comments in reference to the first question, Find median in Spark SQL for double datatype columns, we can use percentile_approx to calculate the median in Spark 2.0.1+. To apply this to grouped data in Apache Spark, the query would look like:

val df = Seq(("A", 0.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
df.createOrReplaceTempView("df")
spark.sql("select id, percentile_approx(num, 0.5) as median from df group by id order by id").show()

with the output being:

+---+------+
| id|median|
+---+------+
|  A|   1.0|
|  B|   1.0|
+---+------+

That said, this is an approximate value (as opposed to the exact median the question asks for).

Calculate exact median for grouped data

There are multiple approaches, so I'm sure others on SO can provide better or more efficient examples. But here's a code snippet that calculates the median for grouped data in Spark (verified in Spark 1.6 and Spark 2.1):

import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

val rdd: RDD[(String, Double)] = sc.parallelize(Seq(
  ("A", 1.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0),
  ("B", 0.0), ("B", 1.0), ("B", 1.0)))

// Scala median function (expects a sorted list)
def median(inputList: List[Double]): Double = {
  val count = inputList.size
  if (count % 2 == 0) {
    val l = count / 2 - 1
    val r = l + 1
    (inputList(l) + inputList(r)) / 2
  } else
    inputList(count / 2)
}

// Group by key and sort the values within each group
val setRDD = rdd.groupByKey()
val sortedListRDD = setRDD.mapValues(_.toList.sorted)

// Output DataFrame of id and median
// (in spark-shell; otherwise import spark.implicits._ for toDF)
sortedListRDD.map(m => {
  (m._1, median(m._2))
}).toDF("id", "median_of_num").show()

with the output being:

+---+-------------+
| id|median_of_num|
+---+-------------+
|  A|          1.0|
|  B|          1.0|
+---+-------------+

There are some caveats that I should call out as this likely isn't the most efficient implementation:

  • It's currently using a groupByKey, which is not very performant. You may want to change this to a reduceByKey instead (more information at Avoid GroupByKey)
  • Using a Scala function to calculate the median.

This approach should work okay for smaller amounts of data, but if you have millions of rows for each key, I would advise moving to Spark 2.0.1+ and using the percentile_approx approach.


