How to use Case-When expressions in Spark
Problem
When working with Spark, you sometimes need to apply custom logic that depends on the value of a column in each row of a DataFrame.
Solution
For this situation, you can use the “case when” functionality that Spark provides. It lets you return a specific value whenever a certain condition is met. Here you can find some examples:
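All of the examples below assume a DataFrame df with a single integer column col1. As a minimal, hypothetical setup (the values simply match the sample outputs shown below), you could create it like this in spark-shell:

import spark.implicits._  // spark is the active SparkSession (e.g. in spark-shell)

// Hypothetical sample data: one integer column "col1" with the values 5, 10 and 30
val df = Seq(5, 10, 30).toDF("col1")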
Scala API - DataFrame Syntax
import org.apache.spark.sql.functions.{when, col, lit}

df.withColumn("output",
  when(col("col1") < 10, lit("Small"))
    .when(col("col1") === 10, lit("Medium"))
    .otherwise(lit("Big")))
  .show()
+----+------+
|col1|output|
+----+------+
| 5| Small|
| 10|Medium|
| 30| Big|
+----+------+
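Note that the final .otherwise is optional: if you omit it, rows that match none of the conditions get null in the output column. A minimal sketch, assuming the same df as above:

import org.apache.spark.sql.functions.{when, col, lit}

// Without .otherwise, rows that match no condition (here col1 = 30) end up as null
df.withColumn("output",
  when(col("col1") < 10, lit("Small"))
    .when(col("col1") === 10, lit("Medium")))
  .show()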
Spark SQL Syntax
SELECT col1,
       CASE
         WHEN col1 < 10 THEN 'Small'
         WHEN col1 = 10 THEN 'Medium'
         ELSE 'Big'
       END AS output
FROM table
---
+----+------+
|col1|output|
+----+------+
| 5| Small|
| 10|Medium|
| 30| Big|
+----+------+
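To run this query against the same DataFrame, you first need to expose it as a table or view. A minimal sketch, assuming the df defined earlier (the view name simply matches the FROM clause in the query above):

// Register the DataFrame under the name used in the FROM clause above
df.createOrReplaceTempView("table")

spark.sql("""
  SELECT col1,
         CASE
           WHEN col1 < 10 THEN 'Small'
           WHEN col1 = 10 THEN 'Medium'
           ELSE 'Big'
         END AS output
  FROM table
""").show()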
PySpark Syntax
from pyspark.sql.functions import when

df.select(
    "col1",
    when(df["col1"] < 10, "Small")
    .when(df["col1"] == 10, "Medium")
    .otherwise("Big")
    .alias("output")
).show()
+----+------+
|col1|output|
+----+------+
| 5| Small|
| 10|Medium|
| 30| Big|
+----+------+
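If you prefer the SQL spelling but want to stay in the DataFrame API, the same CASE WHEN expression can also be passed as a string via expr. A minimal sketch in Scala, assuming the same df as above:

import org.apache.spark.sql.functions.expr

// Same CASE WHEN logic, written as a SQL expression inside the DataFrame API
df.withColumn("output",
  expr("CASE WHEN col1 < 10 THEN 'Small' WHEN col1 = 10 THEN 'Medium' ELSE 'Big' END"))
  .show()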
Source
https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/functions$.html