How to find median and quantiles using Spark
Ongoing work
SPARK-30569 - Add DSL functions invoking percentile_approx
Spark 2.0+:
You can use the approxQuantile method, which implements the Greenwald-Khanna algorithm:
Python:
df.approxQuantile("x", [0.5], 0.25)
Scala:
df.stat.approxQuantile("x", Array(0.5), 0.25)
where the last parameter is the relative error. The lower the value, the more accurate the result and the more expensive the computation.
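To make the relative error concrete: the Spark documentation states that for N rows, a requested probability p and an error err, the returned value x satisfies floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). Here is a minimal pure-Python sketch of that bound. The helper name rank_bounds is hypothetical, it does not call Spark, and clamping the window to [0, N] is my own addition:

```python
from math import ceil, floor

def rank_bounds(n, p, err):
    """Rank window allowed by approxQuantile's documented guarantee.

    n: number of rows, p: requested probability, err: relative error.
    Hypothetical helper for illustration only; it does not call Spark.
    """
    lo = max(0, floor((p - err) * n))  # clamped to valid ranks (assumption)
    hi = min(n, ceil((p + err) * n))
    return lo, hi

# With 1000 rows, asking for the median with a relative error of 0.25
# may return any value ranked between the 250th and the 750th element.
print(rank_bounds(1000, 0.5, 0.25))   # (250, 750)
print(rank_bounds(1000, 0.5, 0.125))  # (375, 625)
```

So 0.25, as used in the calls above, is a very loose tolerance; a smaller value narrows the window at the cost of more work.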
Since Spark 2.2 (SPARK-14352) approxQuantile supports estimation on multiple columns:
df.approxQuantile(["x", "y", "z"], [0.5], 0.25)
and
df.stat.approxQuantile(Array("x", "y", "z"), Array(0.5), 0.25)
The underlying method can also be used in SQL aggregation (both global and grouped) using the approx_percentile function:
> SELECT approx_percentile(10.0, array(0.5, 0.4, 0.1), 100);
[10.0,10.0,10.0]
> SELECT approx_percentile(10.0, 0.5, 100);
10.0
Spark < 2.0
Python
As I've mentioned in the comments, it is most likely not worth all the fuss. If the data is relatively small, like in your case, then simply collect and compute the median locally:
import numpy as np
np.random.seed(323)
rdd = sc.parallelize(np.random.randint(1000000, size=700000))
%time np.median(rdd.collect())
np.array(rdd.collect()).nbytes
It takes around 0.01 seconds on my few-years-old computer and around 5.5 MB of memory.
If the data is much larger, sorting will be the limiting factor, so instead of getting an exact value it is probably better to sample, collect, and compute locally. But if you really want to use Spark, something like this should do the trick (if I didn't mess up anything):
import time
from math import floor

def quantile(rdd, p, sample=None, seed=None):
    """Compute a quantile of order p ∈ [0, 1]

    :rdd a numeric rdd
    :p quantile (between 0 and 1)
    :sample fraction of an rdd to use. If not provided we use the whole dataset
    :seed random number generator seed to be used with sample
    """
    assert 0 <= p <= 1
    assert sample is None or 0 < sample <= 1

    # Use an integer seed; fall back to the current time if none is given
    seed = seed if seed is not None else int(time.time())
    rdd = rdd if sample is None else rdd.sample(False, sample, seed)

    # Sort, then key every value by its rank so it can be looked up by index
    rddSortedWithIndex = (rdd
        .sortBy(lambda x: x)
        .zipWithIndex()
        .map(lambda xi: (xi[1], xi[0]))  # Python 3 lambdas cannot unpack tuples
        .cache())

    n = rddSortedWithIndex.count()
    h = (n - 1) * p
    lo = int(floor(h))
    hi = min(lo + 1, n - 1)  # do not look past the last element when p == 1

    # Linear interpolation between the two neighbouring values
    rddX, rddXPlusOne = (rddSortedWithIndex.lookup(i)[0] for i in (lo, hi))
    return rddX + (h - lo) * (rddXPlusOne - rddX)
And some tests:
np.median(rdd.collect()), quantile(rdd, 0.5)
## (500184.5, 500184.5)
np.percentile(rdd.collect(), 25), quantile(rdd, 0.25)
## (250506.75, 250506.75)
np.percentile(rdd.collect(), 75), quantile(rdd, 0.75)
## (750069.25, 750069.25)
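The interpolation rule inside quantile (h = (n - 1) * p, then a linear blend of the two neighbouring sorted values) can also be checked locally without Spark. Here is a minimal sketch, assuming the data fits in memory. The name local_quantile is hypothetical and not part of any library:

```python
from math import floor
from statistics import median

def local_quantile(values, p):
    """Linear-interpolation quantile of order p for an in-memory list."""
    assert 0 <= p <= 1
    xs = sorted(values)
    n = len(xs)
    assert n > 0
    h = (n - 1) * p          # fractional rank of the requested quantile
    lo = floor(h)
    hi = min(lo + 1, n - 1)  # stay in range when p == 1
    return xs[lo] + (h - lo) * (xs[hi] - xs[lo])

data = [7, 1, 3, 10, 4, 8]
print(local_quantile(data, 0.5), median(data))  # 5.5 5.5
print(local_quantile(data, 0.25))               # 3.25
```

This is the same rule that makes the Spark results above agree with np.median and np.percentile.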
Finally let's define median:
from functools import partial
median = partial(quantile, p=0.5)
So far so good, but it takes 4.66 s in local mode without any network communication. There is probably a way to improve this, but why even bother?
Language independent (Hive UDAF):
If you use HiveContext you can also use Hive UDAFs. With integral values:
rdd.map(lambda x: (float(x), )).toDF(["x"]).registerTempTable("df")
sqlContext.sql("SELECT percentile_approx(x, 0.5) FROM df")
With continuous values:
sqlContext.sql("SELECT percentile(x, 0.5) FROM df")
In percentile_approx you can pass an additional argument which controls the approximation accuracy at the cost of memory (higher values give better approximations).
Median / quantiles within PySpark groupBy
I guess you don't need it anymore. But will leave it here for future generations (i.e. me next week when I forget).
from pyspark.sql import Window
import pyspark.sql.functions as F
grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')
df.withColumn('med_val', magic_percentile.over(grp_window))
Or to address exactly your question, this also works:
df.groupBy('grp').agg(magic_percentile.alias('med_val'))
And as a bonus, you can pass an array of percentiles:
quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))')
And you'll get a list in return.
Compute median of column in pyspark
You need to add a column with withColumn because approxQuantile returns a list of floats, not a Spark column.
import pyspark.sql.functions as F
df2 = df.withColumn('count_media', F.lit(df.approxQuantile('count',[0.5],0.1)[0]))
df2.show()
+-----------+-----+-----------+
|parsed_date|count|count_media|
+-----------+-----+-----------+
| 2017-12-16| 2| 2.0|
| 2017-12-16| 2| 2.0|
| 2017-12-17| 2| 2.0|
| 2017-12-17| 2| 2.0|
| 2017-12-18| 1| 2.0|
| 2017-12-19| 4| 2.0|
| 2017-12-19| 4| 2.0|
| 2017-12-19| 4| 2.0|
| 2017-12-19| 4| 2.0|
| 2017-12-20| 1| 2.0|
+-----------+-----+-----------+
You can also use the approx_percentile / percentile_approx function in Spark SQL:
import pyspark.sql.functions as F
df2 = df.withColumn('count_media', F.expr("approx_percentile(count, 0.5, 10) over ()"))
df2.show()
+-----------+-----+-----------+
|parsed_date|count|count_media|
+-----------+-----+-----------+
| 2017-12-16| 2| 2|
| 2017-12-16| 2| 2|
| 2017-12-17| 2| 2|
| 2017-12-17| 2| 2|
| 2017-12-18| 1| 2|
| 2017-12-19| 4| 2|
| 2017-12-19| 4| 2|
| 2017-12-19| 4| 2|
| 2017-12-19| 4| 2|
| 2017-12-20| 1| 2|
+-----------+-----+-----------+
Median and quantile values in Pyspark
The first improvement would be to do all the quantile calculations at the same time:
quantiles = df.approxQuantile("age", [0.25, 0.5, 0.75], 0)
Also, note that you use the exact calculation of the quantiles. From the documentation we can see that (emphasis added by me):
relativeError – The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.
Since you have a very large dataframe, I expect that some error is acceptable in these calculations, but it will be a trade-off between speed and precision (although any value above 0 could give a significant speed improvement).
How to find the median in Apache Spark with Python Dataframe API?
Here is an example implementation with the DataFrame API in Python (Spark 1.6+).
import pyspark.sql.functions as F
import numpy as np
from pyspark.sql.types import FloatType
Let's assume we have monthly salaries for customers in a Spark dataframe named "salaries", such as:
month | customer_id | salary
and we would like to find the median salary per customer across all the months.
Step 1: Write a user-defined function to calculate the median:
def find_median(values_list):
    try:
        median = np.median(values_list)  # get the median of values in a list in each row
        return round(float(median), 2)
    except Exception:
        return None  # if there is anything wrong with the given values

median_finder = F.udf(find_median, FloatType())
Step 2: Aggregate on the salary column by collecting them into a list of salaries in each row:
salaries_list = salaries.groupBy("customer_id").agg(F.collect_list("salary").alias("salaries"))
Step 3: Call the median_finder udf on the salaries column and add the median values as a new column:
salaries_list = salaries_list.withColumn("median",median_finder("salaries"))
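The logic of these three steps can be tried out locally before running it on a cluster. Here is a minimal sketch that does not use Spark: a dict stands in for groupBy plus collect_list, and statistics.median from the standard library is swapped in for np.median. The sample rows and the helper name median_per_customer are made up for illustration:

```python
from collections import defaultdict
from statistics import median

def median_per_customer(rows):
    """rows: iterable of (month, customer_id, salary) tuples."""
    salaries = defaultdict(list)  # plays the role of collect_list
    for _month, customer_id, salary in rows:
        salaries[customer_id].append(salary)
    # Same rounding as find_median: two decimal places, as a float
    return {cid: round(float(median(vals)), 2) for cid, vals in salaries.items()}

rows = [
    (1, "c1", 1000.0), (2, "c1", 1200.0), (3, "c1", 1100.0),
    (1, "c2", 2000.0), (2, "c2", 2500.0),
]
print(median_per_customer(rows))  # {'c1': 1100.0, 'c2': 2250.0}
```

As in the Spark version, every group's values end up together in one list, so this approach assumes each group is small enough to fit in memory.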
How to find exact median for grouped data in Spark
Simplest Approach (requires Spark 2.0.1+ and not exact median)
As noted in the comments in reference to the first question, Find median in Spark SQL for double datatype columns, we can use percentile_approx to calculate the median for Spark 2.0.1+. To apply this to grouped data in Apache Spark, the query would look like:
val df = Seq(("A", 0.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
df.createOrReplaceTempView("df")
spark.sql("select id, percentile_approx(num, 0.5) as median from df group by id order by id").show()
with the output being:
+---+------+
| id|median|
+---+------+
| A| 1.0|
| B| 1.0|
+---+------+
That said, this is an approximate value (as opposed to the exact median the question asks for).
Calculate exact median for grouped data
There are multiple approaches, so I'm sure others on SO can provide better or more efficient examples. But here's a code snippet that calculates the median for grouped data in Spark (verified in Spark 1.6 and Spark 2.1):
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

val rdd: RDD[(String, Double)] = sc.parallelize(Seq(("A", 1.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)))

// Scala median function (expects an already sorted list)
def median(inputList: List[Double]): Double = {
  val count = inputList.size
  if (count % 2 == 0) {
    // Even count: average the two middle values
    val l = count / 2 - 1
    val r = l + 1
    (inputList(l) + inputList(r)).toDouble / 2
  } else
    inputList(count / 2).toDouble
}

// Sort the values
val setRDD = rdd.groupByKey()
val sortedListRDD = setRDD.mapValues(_.toList.sorted)

// Output DataFrame of id and median
sortedListRDD.map(m => {
  (m._1, median(m._2))
}).toDF("id", "median_of_num").show()
with the output being:
+---+-------------+
| id|median_of_num|
+---+-------------+
| A| 1.0|
| B| 1.0|
+---+-------------+
There are some caveats that I should call out as this likely isn't the most efficient implementation:
- It's currently using a groupByKey, which is not very performant. You may want to change this into a reduceByKey instead (more information at Avoid GroupByKey).
- Using a Scala function to calculate the median.
This approach should work okay for smaller amounts of data, but if you have millions of rows for each key, I would advise using Spark 2.0.1+ and the percentile_approx approach.