Spark – How to apply a function to multiple columns on DataFrame?

Let’s say you have a Spark DataFrame and you want to apply a function to multiple columns. One way is to call withColumn multiple times. That works when you have only a few columns and you know the column names in advance. Otherwise, it’s tedious and error-prone.

So let’s see how to do that.

val df=List(("$100", "$90", "$10")).toDF("selling_price", "market_price", "profit")
+-------------+------------+------+
|selling_price|market_price|profit|
+-------------+------------+------+
|         $100|         $90|   $10|
+-------------+------------+------+

Suppose you have a Spark DataFrame like the one above, but with more than 50 such columns, and you want to remove the $ character and convert the datatype to Decimal. Rather than writing 50 lines of code, you can do that with a fold in fewer than 5 lines.
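For comparison, here is roughly what the repetitive version looks like for just the three columns in the example. This is only a sketch: the local SparkSession, the app name, and the `manual` and `decimalType` names are my assumptions for illustration (spark-shell already gives you `spark`).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, regexp_replace}
import org.apache.spark.sql.types.DataTypes.createDecimalType

// Assumption: a local session for illustration; in spark-shell, getOrCreate reuses the existing one.
val spark = SparkSession.builder().master("local[*]").appName("manual-withColumn").getOrCreate()
import spark.implicits._

val df = List(("$100", "$90", "$10")).toDF("selling_price", "market_price", "profit")
val decimalType = createDecimalType(10, 4)

// One withColumn call per column: fine for three columns, unmanageable for fifty.
// "\\$" is the regex for a literal dollar sign.
val manual = df
  .withColumn("selling_price_new", regexp_replace(col("selling_price"), "\\$", "").cast(decimalType))
  .withColumn("market_price_new", regexp_replace(col("market_price"), "\\$", "").cast(decimalType))
  .withColumn("profit_new", regexp_replace(col("profit"), "\\$", "").cast(decimalType))

manual.printSchema()
```

Every new column means another hand-written line with the column name typed three times, which is where the copy-paste mistakes creep in. If you try this in spark-shell, use :paste mode so the chained calls are read as one expression.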

First, create a list that pairs each new column name (yes, you need a new column name) with the function you want to apply. I just appended _new to the existing column name so it’s easier to rename later.
Next, use the foldLeft method to apply each function from the list to the DataFrame, one after another.
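If foldLeft is new to you, it may help to see the same idea on a plain Scala collection first, with no Spark involved. In this sketch a Map stands in for one row of the DataFrame. The `row`, `operations` and `result` names are made up for illustration.

```scala
// Plain Scala, no Spark needed. A Map stands in for one row of the DataFrame.
val row: Map[String, String] =
  Map("selling_price" -> "$100", "market_price" -> "$90", "profit" -> "$10")

// Each operation pairs a new key with a function that computes its value from the current map.
val operations: List[(String, Map[String, String] => String)] =
  row.keys.toList.map { name =>
    (s"${name}_new", (m: Map[String, String]) => m(name).replace("$", ""))
  }

// foldLeft starts from `row` and feeds the result of each step into the next one.
val result = operations.foldLeft(row) { case (acc, (newName, compute)) =>
  acc + (newName -> compute(acc))
}

println(result("selling_price_new"))
```

The accumulator (`acc`) plays the role that the temporary DataFrame plays in the Spark code: each step takes the previous result, adds one entry, and hands the new result to the next step.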

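import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, regexp_replace}
import org.apache.spark.sql.types.DataTypes.createDecimalType
import spark.implicits._ // needed for toDF; already in scope in spark-shell

val df = List(("$100", "$90", "$10")).toDF("selling_price", "market_price", "profit")
df.show

// Each operation pairs a new column name with the expression that computes it.
val operations = ListBuffer[(String, Column)]()
val colNames = df.columns
val decimalType = createDecimalType(10, 4)
colNames.foreach { colName =>
  // "\\$" is the regex for a literal dollar sign.
  val operation = (s"${colName}_new", regexp_replace(col(colName), "\\$", "").cast(decimalType))
  operations += operation
}

// foldLeft starts from df and adds one column per operation.
val dfWithNewColumns = operations.foldLeft(df) { (tempDF, listValue) =>
  tempDF.withColumn(listValue._1, listValue._2)
}

dfWithNewColumns.show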

Let’s see if that worked.

 
scala> dfWithNewColumns.printSchema
root
 |-- selling_price: string (nullable = true)
 |-- market_price: string (nullable = true)
 |-- profit: string (nullable = true)
 |-- selling_price_new: decimal(10,4) (nullable = true)
 |-- market_price_new: decimal(10,4) (nullable = true)
 |-- profit_new: decimal(10,4) (nullable = true)


scala> dfWithNewColumns.show
+-------------+------------+------+-----------------+----------------+----------+
|selling_price|market_price|profit|selling_price_new|market_price_new|profit_new|
+-------------+------------+------+-----------------+----------------+----------+
|         $100|         $90|   $10|         100.0000|         90.0000|   10.0000|
+-------------+------------+------+-----------------+----------------+----------+
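The _new suffix was chosen to make renaming easy, so here is one possible way to finish the job: drop each original string column and give its decimal replacement the original name. Treat it as a sketch, not the only way to do it. The local SparkSession and the `cleanedDF` name are my assumptions, and to keep the example short it folds directly over df.columns instead of building the ListBuffer first.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, regexp_replace}
import org.apache.spark.sql.types.DataTypes.createDecimalType

// Assumption: a local session for illustration; in spark-shell, getOrCreate reuses the existing one.
val spark = SparkSession.builder().master("local[*]").appName("rename-new-columns").getOrCreate()
import spark.implicits._

val df = List(("$100", "$90", "$10")).toDF("selling_price", "market_price", "profit")
val decimalType = createDecimalType(10, 4)

// Rebuild the intermediate DataFrame with the *_new columns.
val dfWithNewColumns = df.columns.foldLeft(df) { (tempDF, colName) =>
  tempDF.withColumn(s"${colName}_new", regexp_replace(col(colName), "\\$", "").cast(decimalType))
}

// For each original column: drop the string version, then give the decimal version its name.
val cleanedDF = df.columns.foldLeft(dfWithNewColumns) { (tempDF, colName) =>
  tempDF.drop(colName).withColumnRenamed(s"${colName}_new", colName)
}

cleanedDF.printSchema()
```

The schema printed at the end should show only selling_price, market_price and profit, each as decimal(10,4). The drop has to come before the rename within each step, otherwise the DataFrame would briefly hold two columns with the same name.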
