Retrieve top n in each group of a DataFrame in pyspark
You can use a window function to rank each row within its user_id group by score, and then filter the results to keep only the first two rows per user.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col
# rank rows within each user_id, highest score first
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
(df.select('*', rank().over(window).alias('rank'))
    .filter(col('rank') <= 2)
    .show())
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1| 3| 1|
#| user_1| object_2| 2| 2|
#| user_2| object_2| 6| 1|
#| user_2| object_1| 5| 2|
#+-------+---------+-----+----+
In general, the official programming guide is a good place to start learning Spark.
Data
rdd = sc.parallelize([("user_1", "object_1", 3),
("user_1", "object_2", 2),
("user_2", "object_1", 5),
("user_2", "object_2", 2),
("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
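To make the rank semantics concrete without a cluster, here is a plain-Python sketch (not Spark code) of what rank().over(window) followed by rank <= 2 computes on the data above. The function name top_n_per_group is hypothetical; tied scores share a rank, as with Spark's rank.

```python
from collections import defaultdict


def top_n_per_group(rows, n):
    """Hypothetical plain-Python model of rank() over (partition by user_id
    order by score desc), keeping rows whose rank is <= n."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)  # partition by user_id
    result = []
    for user_id in sorted(groups):
        ordered = sorted(groups[user_id], key=lambda r: r[2], reverse=True)
        for row in ordered:
            # SQL RANK(): 1 + number of rows in the partition with a higher score
            rank = 1 + sum(1 for other in ordered if other[2] > row[2])
            if rank <= n:
                result.append(row + (rank,))
    return result


data = [("user_1", "object_1", 3),
        ("user_1", "object_2", 2),
        ("user_2", "object_1", 5),
        ("user_2", "object_2", 2),
        ("user_2", "object_2", 6)]

for row in top_n_per_group(data, 2):
    print(row)
```

This yields the same four rows as the table above, with the rank appended as the last element.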
Sample random n rows from each group
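The approach below works if no particular sort order is required; it uses RDD transformations. Note that it keeps the first n rows encountered in each group rather than a true random sample.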
For a dataframe like the following
sdf.show()
# +-----------+-------+--------+----+
# |bvdidnumber|dt_year|dt_rfrnc|goal|
# +-----------+-------+--------+----+
# | 1| 2020| 202006| 0|
# | 1| 2020| 202012| 1|
# | 1| 2020| 202012| 0|
# | 1| 2021| 202103| 0|
# | 1| 2021| 202106| 0|
# | 1| 2021| 202112| 1|
# | 2| 2020| 202006| 0|
# | 2| 2020| 202012| 0|
# | 2| 2020| 202012| 1|
# | 2| 2021| 202103| 0|
# | 2| 2021| 202106| 0|
# | 2| 2021| 202112| 1|
# +-----------+-------+--------+----+
I created a function that can be shipped to all executors and then used with flatMapValues() in an RDD transformation.
# best to ship this function to all executors for optimum performance
def get_n_from_group(group, num_recs):
    """
    get `N` number of sample records
    """
    res = []
    i = 0
    for rec in group:
        res.append(rec)
        i = i + 1
        if i == num_recs:
            break
    return res
rdd = sdf.rdd. \
groupBy(lambda x: x.bvdidnumber). \
flatMapValues(lambda k: get_n_from_group(k, 2)) # 2 records only
top_n_sdf = spark.createDataFrame(rdd.values(), schema=sdf.schema)
top_n_sdf.show()
# +-----------+-------+--------+----+
# |bvdidnumber|dt_year|dt_rfrnc|goal|
# +-----------+-------+--------+----+
# | 1| 2020| 202006| 0|
# | 1| 2020| 202012| 1|
# | 2| 2020| 202006| 0|
# | 2| 2020| 202012| 0|
# +-----------+-------+--------+----+
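The same idea can be expressed with itertools.islice, which stops after num_recs items just like the loop above. The sketch below is plain Python, not Spark; first_n_per_key is a hypothetical local stand-in for groupBy(...).flatMapValues(...), and the tuples mirror the sdf rows shown earlier. In Spark the order of groups, and of rows within a group, is not guaranteed, so the selected rows may differ between runs.

```python
from itertools import islice


def get_n_from_group(group, num_recs):
    """Return at most num_recs records from an iterable, in encounter order."""
    return list(islice(group, num_recs))


def first_n_per_key(records, key, n):
    """Hypothetical local model of rdd.groupBy(key).flatMapValues(...)."""
    groups = {}
    for rec in records:
        groups.setdefault(key(rec), []).append(rec)  # dicts keep insertion order
    return [rec for k in groups for rec in get_n_from_group(groups[k], n)]


# (bvdidnumber, dt_year, dt_rfrnc, goal)
rows = [(1, 2020, 202006, 0), (1, 2020, 202012, 1), (1, 2020, 202012, 0),
        (1, 2021, 202103, 0), (1, 2021, 202106, 0), (1, 2021, 202112, 1),
        (2, 2020, 202006, 0), (2, 2020, 202012, 0), (2, 2020, 202012, 1),
        (2, 2021, 202103, 0), (2, 2021, 202106, 0), (2, 2021, 202112, 1)]

print(first_n_per_key(rows, key=lambda r: r[0], n=2))
```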
get TopN of all groups after group by using Spark DataFrame
You can use the rank window function as follows:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}
import spark.implicits._ // for $"col"; assumes a SparkSession named spark (spark-shell imports this automatically)
val n: Int = ???
// Window definition
val w = Window.partitionBy($"user").orderBy(desc("rating"))
// Filter
df.withColumn("rank", rank().over(w)).where($"rank" <= n)
If you don't care about ties, you can replace rank with row_number, which gives tied rows distinct numbers and therefore returns at most n rows per group.
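The difference only shows up when values are tied at the cut-off. A plain-Python model of the two functions over a single partition (the names rank_desc and row_number_desc are hypothetical; this is not the Spark API):

```python
def rank_desc(values):
    """SQL RANK() over values ordered descending: tied values share a rank."""
    ordered = sorted(values, reverse=True)
    return [(v, 1 + sum(1 for o in ordered if o > v)) for v in ordered]


def row_number_desc(values):
    """SQL ROW_NUMBER() over values ordered descending: numbers are always distinct."""
    ordered = sorted(values, reverse=True)
    return [(v, i) for i, v in enumerate(ordered, start=1)]


ratings = [4, 5, 3, 4]
n = 2
print([v for v, r in rank_desc(ratings) if r <= n])        # [5, 4, 4]
print([v for v, r in row_number_desc(ratings) if r <= n])  # [5, 4]
```

With n = 2, rank keeps both tied 4s (three rows), while row_number keeps exactly two rows.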
pyspark: how to group N records in a spark dataframe
You may try the following:
NB. I've used the following imports.
from pyspark.sql import functions as F
from pyspark.sql import Window
from datetime import datetime  # used in step 2
1. We need a column that can be used to split your data into 500-record batches
(Recommended) We can create a pseudo column to achieve this with row_number. Integer division (rather than modulo) keeps 500 consecutive rows in each batch:
df = df.withColumn("group_num", F.floor((F.row_number().over(Window.orderBy("row_id")) - 1) / 500))
Otherwise, if row_id starts at 1 and increases consecutively across the 5 million records, we may use it directly:
df = df.withColumn("group_num", F.floor((F.col("row_id") - 1) / 500))
Or, in the unlikely case that the column "last_updated":"09-09-2021T01:03:04.44Z" is unique to each batch of 500 records:
df = df.withColumn("group_num", F.col("last_updated"))
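Whether integer division or modulo is used decides what a "batch" is. A plain-Python sketch (hypothetical helper names, not Spark code) shows that (n - 1) // 500 puts 500 consecutive rows in each group, while (n - 1) % 500 produces 500 groups that each collect every 500th row (rows 1, 501, 1001, ...). In PySpark, the integer-division form can be written as F.floor((... - 1) / 500).

```python
from collections import Counter


def batch_number(row_number, batch_size=500):
    """0-based batch index for a 1-based row number: rows 1..500 -> 0, 501..1000 -> 1."""
    return (row_number - 1) // batch_size


def remainder_group(row_number, batch_size=500):
    """What (row_number - 1) % batch_size computes: rows 1, 501, 1001 share a group."""
    return (row_number - 1) % batch_size


rows = range(1, 1201)  # 1,200 rows, 1-based like row_number()
print(sorted(Counter(batch_number(r) for r in rows).values()))  # [200, 500, 500]
print(len(Counter(remainder_group(r) for r in rows)))           # 500
```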
2. We transform the dataset by grouping on group_num
df = (
    df.groupBy("group_num")
    .agg(
        F.collect_list(
            F.expr("struct(row_id,col1,col2)")
        ).alias("entries")
    )
    .withColumn("last_updated", F.lit(datetime.now()))
    .drop("group_num")
)
NB. If you would like to include all columns, you may use F.expr("struct(*)") instead of F.expr("struct(row_id,col1,col2)").
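To picture what each output row contains after step 2, here is a local plain-Python model (not Spark code) of groupBy + collect_list(struct(...)): one JSON document per batch, each holding at most 500 entries plus a last_updated stamp. The helper name batch_to_json_lines and the col1/col2 values are hypothetical. With maxRecordsPerFile set to 1, each such row would end up in its own file.

```python
import json
from collections import defaultdict
from datetime import datetime, timezone


def batch_to_json_lines(records, batch_size=500):
    """Hypothetical local model of groupBy(group_num).agg(collect_list(struct(...))):
    one JSON document per batch, each holding at most batch_size entries."""
    batches = defaultdict(list)
    for i, rec in enumerate(records):  # i is 0-based, so i // batch_size is the batch
        batches[i // batch_size].append(rec)
    stamp = datetime.now(timezone.utc).isoformat()
    return [json.dumps({"entries": batches[g], "last_updated": stamp})
            for g in sorted(batches)]


records = [{"row_id": i, "col1": f"a{i}", "col2": i * 2} for i in range(1, 1201)]
docs = batch_to_json_lines(records)
print(len(docs), [len(json.loads(d)["entries"]) for d in docs])  # 3 [500, 500, 200]
```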
3. Finally, you can write to your output destination with .option("maxRecordsPerFile", 1), since each row now stores at most 500 entries.
E.g.
df.write.format("json").option("maxRecordsPerFile", 1).save("<your intended path here>")
Let me know if this works for you
Getting Top N items per group in pySpark
If this is working, it is probably better to post this to Code Review.
Just as an exercise, I did this without the Counter, but largely you are just replicating the same functionality.
- Count each occurrence of (cat, term)
- Group by cat
- Sort the values by count and slice to the number of terms (2)
Code:
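from operator import add
from pyspark.sql import functions as sf  # assumed alias for the `sf` used below
(sample.select('cat', sf.explode('terms'))
    .rdd
    .map(lambda x: (x, 1))                # ((cat, term), 1)
    .reduceByKey(add)                     # ((cat, term), count)
    .groupBy(lambda x: x[0][0])           # group by cat
    .mapValues(lambda x: [r[1] for r, _ in sorted(x, key=lambda a: -a[1])[:2]])
    .collect())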
Output:
[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]
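For comparison, here is a local plain-Python sketch of the same three steps using collections.Counter, the functionality the code above replicates. The sample rows are hypothetical, chosen so the result has the same shape as the output above. Counter.most_common breaks ties by first occurrence, whereas the tie order in the Spark version is not guaranteed.

```python
from collections import Counter, defaultdict


def top_terms_per_cat(rows, k):
    """Count (cat, term) pairs, group by cat, keep the k most frequent terms."""
    counts = defaultdict(Counter)
    for cat, terms in rows:
        counts[cat].update(terms)  # plays the role of explode + reduceByKey(add)
    return [(cat, [term for term, _ in counts[cat].most_common(k)])
            for cat in sorted(counts)]


# Hypothetical sample: (cat, terms) rows
sample = [(1, ["orange", "potato", "orange"]),
          (1, ["potato", "apple", "orange"]),
          (2, ["vodka", "beer", "vodka"]),
          (2, ["beer", "wine", "vodka"])]
print(top_terms_per_cat(sample, 2))  # [(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]
```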
PySpark, top for DataFrame
First let's define a function to generate test data:
import numpy as np
def sample_df(num_records):
    def data():
        np.random.seed(42)
        while True:
            yield int(np.random.normal(100., 80.))
    data_iter = iter(data())
    df = sc.parallelize((
        (i, next(data_iter)) for i in range(int(num_records))
    )).toDF(('index', 'key_col'))
    return df
sample_df(1e3).show(n=5)
+-----+-------+
|index|key_col|
+-----+-------+
| 0| 139|
| 1| 88|
| 2| 151|
| 3| 221|
| 4| 81|
+-----+-------+
only showing top 5 rows
Now, let's propose three different ways to calculate TopK:
from pyspark.sql import Window
from pyspark.sql import functions
def top_df_0(df, key_col, K):
    """
    Using window functions. Handles ties OK.
    """
    window = Window.orderBy(functions.col(key_col).desc())
    return (df
            .withColumn("rank", functions.rank().over(window))
            .filter(functions.col('rank') <= K)
            .drop('rank'))
def top_df_1(df, key_col, K):
    """
    Using limit(K). Does NOT handle ties appropriately.
    """
    return df.orderBy(functions.col(key_col).desc()).limit(K)
def top_df_2(df, key_col, K):
    """
    Using limit(K) and then filtering. Handles ties OK.
    """
    num_records = df.count()
    value_at_k_rank = (df
                       .orderBy(functions.col(key_col).desc())
                       .limit(K)
                       .select(functions.min(key_col).alias('min'))
                       .first()['min'])
    return df.filter(df[key_col] >= value_at_k_rank)
The function top_df_1 is similar to the one you originally implemented. It gives non-deterministic results because it cannot handle ties: when several rows share the K-th value, which of them survive the limit is arbitrary. This may be acceptable if you have lots of data and only need an approximate answer for the sake of performance.
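The tie behaviour is easiest to see on a small list. Below is a plain-Python model of the three strategies (hypothetical helpers mirroring top_df_0, top_df_1 and top_df_2, not Spark code), with three values tied for third place:

```python
def top_rank(values, k):
    """Model of top_df_0: keep values whose RANK() (descending) is <= k."""
    return sorted((v for v in values if 1 + sum(o > v for o in values) <= k), reverse=True)


def top_limit(values, k):
    """Model of top_df_1: sort descending and take exactly k values."""
    return sorted(values, reverse=True)[:k]


def top_threshold(values, k):
    """Model of top_df_2: find the k-th largest value, then keep everything >= it."""
    kth = min(sorted(values, reverse=True)[:k])
    return sorted((v for v in values if v >= kth), reverse=True)


vals = [7, 10, 8, 9, 8, 8]
print(top_rank(vals, 3))       # [10, 9, 8, 8, 8]
print(top_limit(vals, 3))      # [10, 9, 8]
print(top_threshold(vals, 3))  # [10, 9, 8, 8, 8]
```

top_rank and top_threshold always agree, because a value has rank <= K exactly when it is at least the K-th largest value; only the limit-based version drops tied entries.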
Finally, let's benchmark
For benchmarking, use a Spark DataFrame with 4 million entries and define a convenience function:
NUM_RECORDS = 4e6
test_df = sample_df(NUM_RECORDS).cache()
def show(func, df, key_col, K):
    func(df, key_col, K).select(
        functions.max(key_col),
        functions.min(key_col),
        functions.count(key_col)
    ).show()
Let's see the verdict:
%timeit show(top_df_0, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 1.62 s per loop
%timeit show(top_df_1, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 100|
+------------+------------+--------------+
1 loops, best of 3: 252 ms per loop
%timeit show(top_df_2, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 725 ms per loop
(Note that top_df_0 and top_df_2 return 108 entries for the top 100. This is because several entries are tied for 100th place; the top_df_1 implementation ignores those tied entries.)
The bottom line
If you want an exact answer, go with top_df_2 (about 2x faster than top_df_0 in this benchmark). If you want another 2x in performance and are OK with an approximate answer, go with top_df_1.
PySpark - Selecting all rows within each group
There are probably many ways to achieve this, but one way is to use a Window. With a Window you can partition your data on one or more columns (in your case sale_date), and on top of that you can order the data within each partition by a specific column (in your case descending on sale, such that the latest sale comes first). So:
from pyspark.sql.window import Window
from pyspark.sql.functions import desc
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))
You can then apply this Window to your DataFrame together with one of the many window functions. One of them is row_number, which, within each partition, adds a row number to each row based on your orderBy. Like this:
from pyspark.sql.functions import row_number
df_out = df.withColumn("row_number",row_number().over(my_window))
As a result, the last sale for each date will have row_number = 1. If you then filter on row_number == 1, you will get the last sale for each group.
So, the full code:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, desc, col
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))
df_out = (
df
.withColumn("row_number",row_number().over(my_window))
.filter(col("row_number") == 1)
.drop("row_number")
)
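For intuition, here is a plain-Python sketch (not Spark code) of what the row_number() == 1 filter keeps: for each sale_date, the single row with the highest sale. The function name and the sample rows are hypothetical. Ties keep the first row seen here, whereas Spark's row_number gives no guarantee which tied row gets 1.

```python
def last_sale_per_date(rows):
    """Keep, for each sale_date, the row with the highest `sale` value
    (the row that row_number() == 1 selects when ordering by sale desc)."""
    best = {}
    for row in rows:
        date = row["sale_date"]
        if date not in best or row["sale"] > best[date]["sale"]:
            best[date] = row
    return [best[d] for d in sorted(best)]


sales = [{"sale_date": "2021-01-01", "sale": 3},
         {"sale_date": "2021-01-01", "sale": 7},
         {"sale_date": "2021-01-02", "sale": 5}]
print(last_sale_per_date(sales))
```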