get TopN of all groups after group by using Spark DataFrame

You can use the rank window function as follows:

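import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}
import spark.implicits._  // for the $"..." syntax; assumes a SparkSession named spark

val n: Int = ???

// Window definition: one partition per user, highest rating first
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Filter: keep the rows ranked n or better within each user
df.withColumn("rank", rank().over(w)).where($"rank" <= n)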

If you don't care about ties, you can replace rank with row_number. It numbers the rows of each group consecutively, so it never returns more than n rows per group.
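To see how ties affect the result, here is a plain-Python model (no Spark involved) of how rank, dense_rank and row_number number the rows of a single window partition that is already sorted by score in descending order. The helper number_rows and the sample scores are hypothetical and only meant for illustration.

```python
def number_rows(sorted_scores):
    """Return (rank, dense_rank, row_number) for each score.

    sorted_scores must already follow the window's orderBy (here: descending),
    mirroring a single window partition.
    """
    result = []
    rank = dense = 0
    previous = object()  # sentinel that never equals a real score
    for position, score in enumerate(sorted_scores, start=1):
        if score != previous:
            rank = position  # rank skips past tied rows
            dense += 1       # dense_rank leaves no gaps
            previous = score
        result.append((rank, dense, position))  # row_number is just the position
    return result


scores = [9, 7, 7, 5]  # hypothetical scores for one user, sorted descending
numbered = number_rows(scores)
for score, (rank, dense, row) in zip(scores, numbered):
    print(score, rank, dense, row)
# With n = 2: rank <= 2 keeps 3 rows (the tie at 7 survives),
# dense_rank <= 2 also keeps 3, and row_number <= 2 keeps exactly 2.
```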

Retrieve top n in each group of a DataFrame in pyspark

I believe you need to use window functions to obtain the rank of each row based on user_id and score, and then filter your results to keep only the first two values.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

In general, the official programming guide is a good place to start learning Spark.


The sample DataFrame used above:

rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

Spark sql top n per group

You can use the window function feature that was added in Spark 1.4. Suppose that we have a productRevenue table with the columns product, category and revenue. The answer to "What are the best-selling and the second best-selling products in every category?" is as follows:

SELECT product,category,revenue FROM 
(SELECT product,category,revenue,dense_rank()
OVER (PARTITION BY category ORDER BY revenue DESC) as rank
FROM productRevenue) tmp
WHERE rank <= 2
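You can try the query without a cluster. The sketch below runs the same statement against Python's built-in sqlite3 module, on a small hypothetical productRevenue table, and assumes the bundled SQLite is 3.25 or newer (the release that added window functions). The alias is renamed to rnk only to stay clear of the rank() function name. Because dense_rank gives tied rows the same number, both 6000-revenue phones share first place and the 5000 phone still counts as second.

```python
import sqlite3

# Hypothetical sample data; the real productRevenue table is not shown above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE productRevenue (product TEXT, category TEXT, revenue INTEGER)")
conn.executemany(
    "INSERT INTO productRevenue VALUES (?, ?, ?)",
    [
        ("Thin", "Cell phone", 6000),
        ("Very thin", "Cell phone", 6000),
        ("Ultra thin", "Cell phone", 5000),
        ("Bendable", "Cell phone", 3000),
        ("Pro2", "Tablet", 6500),
        ("Mini", "Tablet", 5500),
        ("Pro", "Tablet", 4500),
    ],
)

# Same shape as the Spark SQL query above.
query = """
SELECT product, category, revenue FROM
  (SELECT product, category, revenue,
          dense_rank() OVER (PARTITION BY category ORDER BY revenue DESC) AS rnk
   FROM productRevenue) tmp
WHERE rnk <= 2
"""
rows = conn.execute(query).fetchall()
for row in sorted(rows):
    print(row)
```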

Spark topN values by group

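I have settled on this solution. First aggregate the data to find the last reported time for each user and event, then join the result with the original DataFrame to eliminate all unwanted rows, and finally run a rank on the resulting data.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{dense_rank, max}
import spark.implicits._

// Latest event time per (userid, eventid). max is deterministic, unlike last,
// and assumes eventtime sorts chronologically (e.g. yyyyMMddHHmmss)
val df2 = df.groupBy($"userid", $"eventid").agg(max($"eventtime") as "eventtime")
// Keep only the rows that match the latest event time
val lasteventdf = df.join(df2, Seq("eventtime", "userid", "eventid"))
val w = Window.partitionBy($"userid", $"event_title", $"eventid").orderBy($"eventtime".desc)
val contentByRank = lasteventdf.withColumn("rank", dense_rank().over(w)).filter($"rank" <= 5)
contentByRank.show(false)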

+--------------+-------+-------+-----------+----------------------------+----+
|eventtime     |userid |eventid|event_title|eventdata                   |rank|
+--------------+-------+-------+-----------+----------------------------+----+
|20180515114049|user004|e002   |cross-limit|some data related to event  |1   |
|20180715114049|user005|e004   |cross-over |some data relat7ed to event |1   |
|20180815114049|user006|e001   |cross-over |some data re22lated to event|1   |
|20180715114049|user005|e003   |no-cross   |some data relate6d to event |1   |
|20180715114049|user005|e003   |no-cross   |some data rel9ated to event |1   |
|20180715114049|user005|e005   |dl-over    |some data relat8ed to event |1   |
|20180715114049|user003|e004   |cross-cl   |some data related2 to event |1   |
|20180715114049|user005|e001   |cross-over |some data related4 to event |1   |
|20180105114049|user001|e006   |straight   |some data relat4ed to event |1   |
|20180715114049|user005|e002   |cross-limit|some data related5 to event |1   |
|20180915114049|user001|e001   |cross-over |some data rel3ated to event |1   |
+--------------+-------+-------+-----------+----------------------------+----+

Spark, Scala - How to get Top 3 value from each group of two column in dataframe

The solution is similar to the PySpark answer to "Retrieve top n in each group of a DataFrame in pyspark".

In Scala it looks like this:

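import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import spark.implicits._

df.withColumn("rank", rank().over(Window.partitionBy("Dept_id").orderBy($"salary".desc)))
  .filter($"rank" <= 3)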

I hope the answer is helpful

How to select the first row of each group?

Window functions:

Something like this should do the trick:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window
import spark.implicits._  // needed for toDF and $"..."; assumes a SparkSession named spark

val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number().over(w)).where($"rn" === 1).drop("rn")
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+

This method will be inefficient in case of significant data skew, because all rows of a group are shuffled to the same partition and processed by a single task. This problem is tracked by SPARK-34775 and might be resolved in the future (SPARK-37099).

Plain SQL aggregation followed by join:

Alternatively, you can join with an aggregated DataFrame:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+

It will keep duplicate values (if there is more than one category per hour with the same total value). If you need exactly one row per hour, remove these in a separate step, for example by keeping one arbitrary row per hour.
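A plain-Python model of the aggregate-then-join approach shows why the duplicates survive the join. The rows are hypothetical, with a deliberate tie in hour 0. Keeping the first matching row in the last step is an arbitrary choice, just as it would be in Spark.

```python
# Hypothetical rows: (hour, category, total_value); hour 0 has a tie at 30.9.
rows = [
    (0, "cat26", 30.9), (0, "cat13", 30.9), (0, "cat95", 19.6),
    (1, "cat67", 28.5), (1, "cat4", 26.8),
]

# Step 1: aggregate, the equivalent of groupBy(hour).agg(max(TotalValue)).
max_by_hour = {}
for hour, _, value in rows:
    max_by_hour[hour] = max(value, max_by_hour.get(hour, value))

# Step 2: "join" back, keeping every row whose value equals its hour's maximum.
joined = [r for r in rows if r[2] == max_by_hour[r[0]]]
# Both cat26 and cat13 survive for hour 0.

# Step 3: drop duplicates by keeping one (arbitrary) row per hour.
deduplicated = {}
for row in joined:
    deduplicated.setdefault(row[0], row)
print(sorted(deduplicated.values()))
```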


Using ordering over structs:

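A neat, although not very well tested, trick that requires neither joins nor window functions:

import org.apache.spark.sql.functions.{max, struct}

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"Hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")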
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
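The trick works because structs are compared field by field, so max over (TotalValue, Category) picks the largest TotalValue and, on a tie, the largest Category. Python tuples compare the same way, which makes for a quick model. The rows below are a subset of the example DataFrame.

```python
rows = [
    (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6),
    (1, "cat67", 28.5), (1, "cat4", 26.8),
    (3, "cat8", 35.6),
]

# Equivalent of groupBy(Hour).agg(max(struct(TotalValue, Category))):
# tuples are compared element by element, like Spark structs.
best = {}
for hour, category, value in rows:
    candidate = (value, category)
    if hour not in best or candidate > best[hour]:
        best[hour] = candidate

# Unpack the "struct" again, like select($"vs.Category", $"vs.TotalValue").
top = [(hour, cat, value) for hour, (value, cat) in sorted(best.items())]
print(top)  # [(0, 'cat26', 30.9), (1, 'cat67', 28.5), (3, 'cat8', 35.6)]
```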

With DataSet API (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)

// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 or later:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should perform better than window functions and joins. They can also be used with Structured Streaming in complete output mode.
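A plain-Python sketch shows why map-side combine helps. Each partition first reduces its own rows to one record per key, so only those partial results cross the shuffle, while a window function has to shuffle every row. The split of the example rows into two partitions is made up for illustration.

```python
def pick_max(x, y):
    # Same reducer as in the Dataset examples: keep the record with the larger TotalValue.
    return x if x[2] > y[2] else y

# Hypothetical layout of some example rows over two partitions.
partitions = [
    [(0, "cat26", 30.9), (0, "cat13", 22.1), (1, "cat67", 28.5), (1, "cat4", 26.8)],
    [(0, "cat95", 19.6), (1, "cat13", 12.6), (2, "cat56", 39.6), (2, "cat40", 29.7)],
]

def combine_partition(rows):
    """Map side: reduce each key within a single partition."""
    partial = {}
    for row in rows:
        key = row[0]
        partial[key] = pick_max(partial[key], row) if key in partial else row
    return partial

partials = [combine_partition(p) for p in partitions]
shuffled = sum(len(p) for p in partials)  # records sent over the network
total = sum(len(p) for p in partitions)   # what a full shuffle would move

# Reduce side: merge the partial results per key.
final = {}
for partial in partials:
    for key, row in partial.items():
        final[key] = pick_max(final[key], row) if key in final else row

print(shuffled, "of", total, "records shuffled")  # 5 of 8 records shuffled
print(sorted(final.values()))
```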

Don't use:

df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

The same note applies to

df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.

Spark: Performant way to find top n values

I see two methods to improve the performance of your algorithm. The first is to use sort and limit to retrieve the top n rows. The second is to develop a custom Aggregator.

Sort and Limit method

You sort your DataFrame and then take the first n rows:

val n: Int = ???

import org.apache.spark.sql.functions.desc

df.orderBy(desc("count")).limit(n)

Spark optimizes this sequence of transformations: it first sorts each partition and takes its first n rows, then collects those rows into a single final partition, sorts them again, and takes the final first n rows. You can check this by calling explain() on the result, which prints the following execution plan:

== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]

You can also confirm it by looking at how the TakeOrderedAndProject step is executed in limit.scala in Spark's source code (case class TakeOrderedAndProjectExec, method doExecute).
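The two-phase strategy can be modelled in plain Python: select the local top n on every partition, then select the top n again from at most n·p candidates. This is only a sketch. heapq.nlargest stands in for Spark's own bounded ordering utility, and the (id, count) rows and partition layout are hypothetical.

```python
import heapq

def take_ordered(partitions, n, key):
    """Top n across partitions: local top n per partition, then a final top n."""
    # Step 1: on each partition keep at most n rows (bounded selection).
    local_tops = [heapq.nlargest(n, part, key=key) for part in partitions]
    # Step 2: bring the (at most n * p) candidates together and select again.
    candidates = [row for top in local_tops for row in top]
    return heapq.nlargest(n, candidates, key=key)

# Hypothetical (id, count) rows spread over three partitions.
partitions = [
    [("a", 5), ("b", 1), ("c", 9)],
    [("d", 7), ("e", 3)],
    [("f", 8), ("g", 2), ("h", 6)],
]
top3 = take_ordered(partitions, 3, key=lambda r: r[1])
print(top3)  # [('c', 9), ('f', 8), ('d', 7)]
```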

Custom Aggregator method

For the custom aggregator method, you create an Aggregator that populates and updates an ordered array of the top n rows.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {

  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  // Insert the record into a buffer kept sorted by count, descending, capped at limit
  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  // Merge two sorted buffers, keeping at most limit records
  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged.append(topRecords2.remove(0))
      } else if (topRecords2.isEmpty) {
        merged.append(topRecords1.remove(0))
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged.append(topRecords1.remove(0))
      } else {
        merged.append(topRecords2.remove(0))
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]

  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]
}

Then you apply this aggregator to your DataFrame and flatten the aggregation result:

val n: Int = ???

import sparkSession.implicits._

df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)
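The reduce and merge buffer logic of the aggregator is easy to get subtly wrong, so here is a direct plain-Python port that can be checked against a full sort. Lists stand in for ArrayBuffer and (id, count) tuples for Record. It mirrors the Scala code above rather than Spark's execution, and the two partitions of sample records are hypothetical.

```python
def reduce_record(top, record, limit):
    """Insert record into top, a list kept sorted by count descending."""
    # Index of the last element with a strictly greater count (-1 if none),
    # like lastIndexWhere in the Scala version.
    insert_index = max((i for i, r in enumerate(top) if r[1] > record[1]), default=-1)
    if len(top) < limit:
        top.insert(insert_index + 1, record)
    elif insert_index < limit - 1:
        top.insert(insert_index + 1, record)
        top.pop()  # drop the smallest so the buffer stays at limit elements
    return top

def merge(top1, top2, limit):
    """Merge two descending buffers, keeping at most limit records."""
    merged = []
    while len(merged) < limit and (top1 or top2):
        if not top1:
            merged.append(top2.pop(0))
        elif not top2:
            merged.append(top1.pop(0))
        elif top2[0][1] < top1[0][1]:
            merged.append(top1.pop(0))
        else:
            merged.append(top2.pop(0))
    return merged

# Two hypothetical partitions of (id, count) records, with n = 3.
n = 3
buffer1, buffer2 = [], []
for rec in [("a", 4), ("b", 9), ("c", 1), ("d", 7)]:
    reduce_record(buffer1, rec, n)
for rec in [("e", 8), ("f", 2), ("g", 6)]:
    reduce_record(buffer2, rec, n)
result = merge(buffer1, buffer2, n)
print(result)  # [('b', 9), ('e', 8), ('d', 7)]
```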

Method comparison

To compare these two methods, let's say we want to take the top n rows of a DataFrame that is distributed over p partitions, each holding around k records, so the DataFrame has size p·k. This gives the following complexity (subject to errors):

