Querying Spark SQL DataFrame with complex types
It depends on the type of the column. Let's start with some dummy data:
import org.apache.spark.sql.functions.{udf, lit}
import scala.util.Try
case class SubRecord(x: Int)
case class ArrayElement(foo: String, bar: Int, vals: Array[Double])
case class Record(
  an_array: Array[Int], a_map: Map[String, String],
  a_struct: SubRecord, an_array_of_structs: Array[ArrayElement])

val df = sc.parallelize(Seq(
  Record(Array(1, 2, 3), Map("foo" -> "bar"), SubRecord(1),
    Array(
      ArrayElement("foo", 1, Array(1.0, 2.0, 2.0)),
      ArrayElement("bar", 2, Array(3.0, 4.0, 5.0)))),
  Record(Array(4, 5, 6), Map("foz" -> "baz"), SubRecord(2),
    Array(ArrayElement("foz", 3, Array(5.0, 6.0)),
      ArrayElement("baz", 4, Array(7.0, 8.0))))
)).toDF
df.registerTempTable("df")
df.printSchema
// root
// |-- an_array: array (nullable = true)
// | |-- element: integer (containsNull = false)
// |-- a_map: map (nullable = true)
// | |-- key: string
// | |-- value: string (valueContainsNull = true)
// |-- a_struct: struct (nullable = true)
// | |-- x: integer (nullable = false)
// |-- an_array_of_structs: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- foo: string (nullable = true)
// | | |-- bar: integer (nullable = false)
// | | |-- vals: array (nullable = true)
// | | | |-- element: double (containsNull = false)
array (ArrayType) columns

Column.getItem method:

df.select($"an_array".getItem(1)).show
// +-----------+
// |an_array[1]|
// +-----------+
// | 2|
// | 5|
// +-----------+

Hive brackets syntax:
sqlContext.sql("SELECT an_array[1] FROM df").show
// +---+
// |_c0|
// +---+
// | 2|
// | 5|
// +---+

a UDF:
val get_ith = udf((xs: Seq[Int], i: Int) => Try(xs(i)).toOption)
df.select(get_ith($"an_array", lit(1))).show
// +---------------+
// |UDF(an_array,1)|
// +---------------+
// | 2|
// | 5|
// +---------------+

In addition to the methods listed above, Spark supports a growing list of built-in functions operating on complex types. Notable examples include higher-order functions like
transform (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

df.selectExpr("transform(an_array, x -> x + 1) an_array_inc").show
// +------------+
// |an_array_inc|
// +------------+
// | [2, 3, 4]|
// | [5, 6, 7]|
// +------------+
import org.apache.spark.sql.functions.transform
df.select(transform($"an_array", x => x + 1) as "an_array_inc").show
// +------------+
// |an_array_inc|
// +------------+
// | [2, 3, 4]|
// | [5, 6, 7]|
// +------------+

filter (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

df.selectExpr("filter(an_array, x -> x % 2 == 0) an_array_even").show
// +-------------+
// |an_array_even|
// +-------------+
// | [2]|
// | [4, 6]|
// +-------------+
import org.apache.spark.sql.functions.filter
df.select(filter($"an_array", x => x % 2 === 0) as "an_array_even").show
// +-------------+
// |an_array_even|
// +-------------+
// | [2]|
// | [4, 6]|
// +-------------+

aggregate (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

df.selectExpr("aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc) an_array_sum").show
// +------------+
// |an_array_sum|
// +------------+
// | 6|
// | 15|
// +------------+
import org.apache.spark.sql.functions.aggregate
df.select(aggregate($"an_array", lit(0), (x, y) => x + y) as "an_array_sum").show
// +------------+
// |an_array_sum|
// +------------+
// | 6|
// | 15|
// +------------+

array processing functions (array_*) like array_distinct (2.4+):

import org.apache.spark.sql.functions.array_distinct
df.select(array_distinct($"an_array_of_structs.vals"(0))).show
// +-------------------------------------------+
// |array_distinct(an_array_of_structs.vals[0])|
// +-------------------------------------------+
// | [1.0, 2.0]|
// | [5.0, 6.0]|
// +-------------------------------------------+

array_max (array_min, 2.4+):

import org.apache.spark.sql.functions.array_max
df.select(array_max($"an_array")).show
// +-------------------+
// |array_max(an_array)|
// +-------------------+
// | 3|
// | 6|
// +-------------------+

flatten (2.4+):

import org.apache.spark.sql.functions.flatten
df.select(flatten($"an_array_of_structs.vals")).show
// +---------------------------------+
// |flatten(an_array_of_structs.vals)|
// +---------------------------------+
// | [1.0, 2.0, 2.0, 3...|
// | [5.0, 6.0, 7.0, 8.0]|
// +---------------------------------+

arrays_zip (2.4+):

import org.apache.spark.sql.functions.arrays_zip
df.select(arrays_zip($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show(false)
// +--------------------------------------------------------------------+
// |arrays_zip(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
// +--------------------------------------------------------------------+
// |[[1.0, 3.0], [2.0, 4.0], [2.0, 5.0]] |
// |[[5.0, 7.0], [6.0, 8.0]] |
// +--------------------------------------------------------------------+

array_union (2.4+):

import org.apache.spark.sql.functions.array_union
df.select(array_union($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show
// +---------------------------------------------------------------------+
// |array_union(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
// +---------------------------------------------------------------------+
// | [1.0, 2.0, 3.0, 4...|
// | [5.0, 6.0, 7.0, 8.0]|
// +---------------------------------------------------------------------+

slice (2.4+):

import org.apache.spark.sql.functions.slice
df.select(slice($"an_array", 2, 2)).show
// +---------------------+
// |slice(an_array, 2, 2)|
// +---------------------+
// | [2, 3]|
// | [5, 6]|
// +---------------------+
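Since 2.4 there is also element_at, worth knowing as an alternative to getItem; a minimal sketch against the same df (note the 1-based indexing, and that an out-of-range index yields null unless ANSI mode is enabled):

import org.apache.spark.sql.functions.element_at

// 1-based: element_at($"an_array", 2) picks the second element
df.select(element_at($"an_array", 2)).show
// +-----------------------+
// |element_at(an_array, 2)|
// +-----------------------+
// |                      2|
// |                      5|
// +-----------------------+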
map (MapType) columns

using Column.getField method:

df.select($"a_map".getField("foo")).show
// +----------+
// |a_map[foo]|
// +----------+
// | bar|
// | null|
// +----------+

using Hive brackets syntax:
sqlContext.sql("SELECT a_map['foz'] FROM df").show
// +----+
// | _c0|
// +----+
// |null|
// | baz|
// +----+

using a full path with dot syntax:
df.select($"a_map.foo").show
// +----+
// | foo|
// +----+
// | bar|
// |null|
// +----+

using a UDF:
val get_field = udf((kvs: Map[String, String], k: String) => kvs.get(k))
df.select(get_field($"a_map", lit("foo"))).show
// +--------------+
// |UDF(a_map,foo)|
// +--------------+
// | bar|
// | null|
// +--------------+

A growing number of map_* functions like map_keys (2.3+):

import org.apache.spark.sql.functions.map_keys
df.select(map_keys($"a_map")).show
// +---------------+
// |map_keys(a_map)|
// +---------------+
// | [foo]|
// | [foz]|
// +---------------+

or map_values (2.3+):

import org.apache.spark.sql.functions.map_values
df.select(map_values($"a_map")).show
// +-----------------+
// |map_values(a_map)|
// +-----------------+
// | [bar]|
// | [baz]|
// +-----------------+
Please check SPARK-23899 for a detailed list.
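For instance, the map counterparts of the higher-order functions shown earlier are available too; a minimal sketch using transform_values (SQL 2.4+, Scala API 3.0+):

// upper-case every value while keeping the keys
df.selectExpr("transform_values(a_map, (k, v) -> upper(v)) a_map_upper").show
// +------------+
// | a_map_upper|
// +------------+
// |[foo -> BAR]|
// |[foz -> BAZ]|
// +------------+
// (exact map rendering in show varies slightly across Spark versions)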
struct (StructType) columns using full path with dot syntax:

with DataFrame API:
df.select($"a_struct.x").show
// +---+
// | x|
// +---+
// | 1|
// | 2|
// +---+

with raw SQL:
sqlContext.sql("SELECT a_struct.x FROM df").show
// +---+
// | x|
// +---+
// | 1|
// | 2|
// +---+
fields inside arrays of structs can be accessed using dot syntax, names and standard Column methods:

df.select($"an_array_of_structs.foo").show
// +----------+
// | foo|
// +----------+
// |[foo, bar]|
// |[foz, baz]|
// +----------+
sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show
// +---+
// |_c0|
// +---+
// |foo|
// |foz|
// +---+
df.select($"an_array_of_structs.vals".getItem(1).getItem(1)).show
// +------------------------------+
// |an_array_of_structs.vals[1][1]|
// +------------------------------+
// | 4.0|
// | 8.0|
// +------------------------------+

fields of user-defined types (UDTs) can be accessed using UDFs. See Spark SQL referencing attributes of UDT for details.
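A minimal sketch of that pattern, assuming an ml Vector column (VectorUDT) rather than the df used above:

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.{udf, lit}

// hypothetical helper: extract the i-th element of a VectorUDT column
val vector_at = udf((v: Vector, i: Int) => v(i))
// usage: vectors.select(vector_at($"features", lit(0)))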
Notes:
- Depending on the Spark version, some of these methods can be available only with HiveContext. UDFs should work independently of version with both the standard SQLContext and HiveContext.
- Generally speaking, nested values are second-class citizens. Not all typical operations are supported on nested fields. Depending on the context, it can be better to flatten the schema and / or explode collections:
import org.apache.spark.sql.functions.explode

df.select(explode($"an_array_of_structs")).show
// +--------------------+
// | col|
// +--------------------+
// |[foo,1,WrappedArr...|
// |[bar,2,WrappedArr...|
// |[foz,3,WrappedArr...|
// |[baz,4,WrappedArr...|
// +--------------------+

- Dot syntax can be combined with a wildcard character (*) to select (possibly multiple) fields without specifying names explicitly:

df.select($"a_struct.*").show
// +---+
// | x|
// +---+
// | 1|
// | 2|
// +---+

- JSON columns can be queried using the get_json_object and from_json functions. See How to query JSON data column using Spark DataFrames? for details.
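A minimal sketch of both, on a standalone string column (not part of the df above), assuming the implicits are in scope as before:

import org.apache.spark.sql.functions.{from_json, get_json_object}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val json_df = Seq("""{"k": 1}""", """{"k": 2}""").toDF("payload")

// get_json_object extracts a value by JSONPath and returns it as a string
json_df.select(get_json_object($"payload", "$.k")).show

// from_json parses the payload into a struct given an explicit schema (2.1+)
val schema = StructType(Seq(StructField("k", IntegerType)))
json_df.select(from_json($"payload", schema).getField("k")).show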
Scala compare dataframe complex array type field
Coming at it in reverse, build the data first and let Spark infer the schema:
val s = Seq(Array(1024, 100001D), Array(1, -1D)).toDS().toDF("myList")
println(s.schema)
s.printSchema
s.show
Your schema is as below. DoubleType is inferred because 100001D and -1D are doubles:

StructType(StructField(myList,ArrayType(DoubleType,false),true))
The output you need:
root
|-- myList: array (nullable = true)
| |-- element: double (containsNull = false)
+------------------+
| myList|
+------------------+
|[1024.0, 100001.0]|
| [1.0, -1.0]|
+------------------+
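If the goal is to compare that complex field's type programmatically, a minimal sketch against the s defined above:

import org.apache.spark.sql.types.{ArrayType, DoubleType}

// DataType instances compare structurally, so plain equality works
s.schema("myList").dataType == ArrayType(DoubleType, containsNull = false)
// res: Boolean = true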
Alternatively, you can do it this way:
case class MyObject(a: Int, b: Double)

val s = Seq(MyObject(1024, 100001D), MyObject(1, -1D)).toDS()
  .select(struct($"a", $"b").as[MyObject] as "myList")
println(s.schema)
s.printSchema
s.show
Result:
// schema:
StructType(StructField(myList,StructType(StructField(a,IntegerType,false), StructField(b,DoubleType,false)),false))
root
|-- myList: struct (nullable = false)
| |-- a: integer (nullable = false)
| |-- b: double (nullable = false)
+----------------+
| myList|
+----------------+
|[1024, 100001.0]|
| [1, -1.0]|
+----------------+
Spark / Scala how to write a complex query that iterates through a dataframe and adds a column
You can use a udf function as below to generate the array of string dates between the START_DAY_ID and END_DAY_ID columns:
import java.text.SimpleDateFormat
import java.util.Calendar
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.functions._

def days_in_range = udf((start_day: String, diff: Int) => {
  val format = new SimpleDateFormat("yyyyMMdd")
  val calStart = Calendar.getInstance
  calStart.setTime(format.parse(start_day))
  val listBuffer = new ListBuffer[String]
  for (day <- 1 until diff) {
    calStart.add(Calendar.DATE, 1)
    listBuffer.append(format.format(calStart.getTime))
  }
  listBuffer
})
The diff integer is derived using the built-in datediff function while calling the udf function:
df1.select(col("COLLECTION"), days_in_range(col("START_DAY_ID"), datediff(to_date(col("END_DAY_ID"), "yyyyMMdd"), to_date(col("START_DAY_ID"), "yyyyMMdd"))).as("days"))
.show()
which should give you
+----------+--------------------+
|COLLECTION| days|
+----------+--------------------+
| HIVER_19|[20190903, 201909...|
| ETE_19|[20181203, 201812...|
+----------+--------------------+
I hope the answer is helpful.
Order Spark SQL Dataframe with nested values / complex data types
I'd reverse the order of the struct and aggregate with max:
val result = df
.groupBy(col("Id"))
.agg(
collect_list(struct(col("Date"), col("Paid"))) as "UserPayments",
max(struct(col("Paid"), col("Date"))) as "MaxPayment"
)
result.show
// +---+--------------------+---------------+
// | Id| UserPayments| MaxPayment|
// +---+--------------------+---------------+
// | yc|[[07:00 AM,16.6],...|[16.6,07:00 AM]|
// | mk|[[10:00 AM,8.6], ...|[12.6,06:00 AM]|
// +---+--------------------+---------------+
You can later flatten the struct:
result.select($"id", $"UserPayments", $"MaxPayment.*").show
// +---+--------------------+----+--------+
// | id| UserPayments|Paid| Date|
// +---+--------------------+----+--------+
// | yc|[[07:00 AM,16.6],...|16.6|07:00 AM|
// | mk|[[10:00 AM,8.6], ...|12.6|06:00 AM|
// +---+--------------------+----+--------+
In the same way, you can sort_array over the reordered structs:
df
.groupBy(col("Id"))
.agg(
sort_array(collect_list(struct(col("Paid"), col("Date")))) as "UserPayments"
)
.show(false)
// +---+-------------------------------------------------+
// |Id |UserPayments |
// +---+-------------------------------------------------+
// |yc |[[2.6,09:00 AM], [16.6,07:00 AM]] |
// |mk |[[5.6,11:00 AM], [8.6,10:00 AM], [12.6,06:00 AM]]|
// +---+-------------------------------------------------+
Finally:
This is a naive and straightforward approach, but I have concerns in terms of correctness. Will the list really be ordered globally or only within a partition?

Data will be ordered globally, but the order will be destroyed by groupBy, so this is not a solution and can work only accidentally.
repartition (by Id) and sortWithinPartitions (by Id and Paid) should be a reliable replacement.
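A minimal sketch of that replacement, with the column names used above:

import org.apache.spark.sql.functions.{col, collect_list, struct}

df
  .repartition(col("Id"))              // co-locate each Id in one partition
  .sortWithinPartitions(col("Id"), col("Paid"))
  .groupBy(col("Id"))                  // no extra shuffle: partitioning already matches
  .agg(collect_list(struct(col("Paid"), col("Date"))) as "UserPayments")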