How to use Case-When expressions in Spark

Problem

When working with Spark, you sometimes need to apply custom logic to each row of a DataFrame depending on the value of one of its columns.

Solution

In this situation you can use the "Case When" functionality that Spark provides. It lets you return a specific value whenever a given condition is met. Here are some examples:
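All of the examples below assume a DataFrame named df with a single integer column col1 holding the values 5, 10 and 30, matching the outputs shown. The article does not show how df is built, so the following is only a minimal sketch of how such a DataFrame could be created in a Scala spark-shell session:

import spark.implicits._

// Hypothetical sample data matching the example outputs below
val df = Seq(5, 10, 30).toDF("col1")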

Scala API - DataFrame Syntax

import org.apache.spark.sql.functions.{when, col, lit}

// Derive an "output" label from the value of col1
df.withColumn("output",
  when(col("col1") < 10, lit("Small"))
    .when(col("col1") === 10, lit("Medium"))
    .otherwise(lit("Big")))
  .show()

+----+------+
|col1|output|
+----+------+
|   5| Small|
|  10|Medium|
|  30|   Big|
+----+------+

Spark SQL Syntax

SELECT col1,
       CASE
         WHEN col1 < 10 THEN 'Small'
         WHEN col1 = 10 THEN 'Medium'
         ELSE 'Big'
       END AS output
FROM table

+----+------+
|col1|output|
+----+------+
|   5| Small|
|  10|Medium|
|  30|   Big|
+----+------+
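The query above assumes that the DataFrame is visible to Spark SQL under the name table. One way to achieve that, sketched here with the hypothetical df defined earlier, is to register it as a temporary view and run the statement through the SparkSession:

// Register df as a temporary view so the SQL query above can refer to it as "table"
df.createOrReplaceTempView("table")

// Execute the statement through the SparkSession and print the result
spark.sql("SELECT col1, CASE WHEN col1 < 10 THEN 'Small' WHEN col1 = 10 THEN 'Medium' ELSE 'Big' END AS output FROM table").show()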

PySpark Syntax

from pyspark.sql.functions import when

# Derive an "output" label from the value of col1, keeping col1 in the result
df.select(
    df["col1"],
    when(df["col1"] < 10, "Small")
    .when(df["col1"] == 10, "Medium")
    .otherwise("Big")
    .alias("output")
).show()


+----+------+
|col1|output|
+----+------+
|   5| Small|
|  10|Medium|
|  30|   Big|
+----+------+

Source

https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.when.html#pyspark.sql.functions.when

https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/functions$.html