Pythian Blog: Technical Track

Spark Scala UDF primitive type bug

I was working on an instrumentation framework for Scala UDFs in Spark when I noticed a subtle difference in the execution plan depending on whether or not I wrapped the function. It looked as if a null-check predicate was added in one case but not in the other:

val f = (x: Long) => x
val udf0 = udf(f)
...
.withColumn("udf0", udf0(...))
...
// explain() shows: if (isnull(...)) null else UDF(...) AS udf0#111L

 vs

def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf1 = udf(identity(f))
...
.withColumn("udf1", udf1(...))
...
// explain() shows: UDF(...) AS udf1#115L

A quick look at the documentation sheds light on the special case of UDFs built from functions with primitive input arguments:

Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.
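To make the documented behavior concrete, here is a minimal plain-Scala sketch (no Spark involved) of what the generated `if (isnull(x)) null else UDF(x)` amounts to. `propagateNull` is a hypothetical helper name, and a boxed `java.lang.Long` is assumed to model a nullable column value.

```scala
// Hypothetical helper: mimics the null check Spark puts in front of a UDF
// whose parameter is a primitive Long. A boxed java.lang.Long models a
// nullable column value (an assumption of this sketch).
def propagateNull(f: Long => Long): java.lang.Long => java.lang.Long =
  (x: java.lang.Long) =>
    if (x == null) null                         // isnull(x) => null
    else java.lang.Long.valueOf(f(x.longValue)) // otherwise call the UDF body

val plusOne = (x: Long) => x + 1
val guarded = propagateNull(plusOne)

println(guarded(java.lang.Long.valueOf(41L))) // 42
println(guarded(null))                        // null
```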

In my case I had not actually changed any types; I had only wrapped the function in a higher-order function, something like this:

val f = (x: Long) => x

def identity[T, U](f: T => U): T => U = (t: T) => f(t)

val udf0 = udf(f)

val udf1 = udf(identity(f))

 

At first sight, udf0 and udf1 look exactly the same:

scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
identity: [T, U](f: T => U)T => U

scala> val udf0 = udf(f)
udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> val udf1 = udf(identity(f))
udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

 

During execution, however, they behave differently for null input:

scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))

scala> spark.range(5).toDF()
  .withColumn("udf0", udf0(getNull()))
  .withColumn("udf1", udf1(getNull()))
  .show()
+---+----+----+
| id|udf0|udf1|
+---+----+----+
|  0|null|   0|
|  1|null|   0|
|  2|null|   0|
|  3|null|   0|
|  4|null|   0|
+---+----+----+

scala> spark.range(5).toDF()
  .withColumn("udf0", udf0(getNull()))
  .withColumn("udf1", udf1(getNull()))
  .explain()
== Physical Plan ==
*Project [id#106L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#111L, UDF(UDF()) AS udf1#115L]
+- *Range (0, 5, step=1, splits=2)
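Why does udf1 return 0 rather than null? A minimal plain-Scala sketch, assuming Spark ends up invoking the function through its erased `Function1` signature: the lambda returned by identity takes an `Object`, so the null passes straight through to f, and Scala unboxes a null reference to a primitive `Long` as `0L`. The `Any => Any` casts below are only a stand-in for that generic call path, not Spark code.

```scala
// f and identity are copied from the post.
val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)

// Assumption: casting to Any => Any stands in for how generic code calls the function.
val direct: Any => Any  = f.asInstanceOf[Any => Any]
val wrapped: Any => Any = identity(f).asInstanceOf[Any => Any]

// Nothing rejects the null, and unboxing null to a primitive Long yields 0L.
val r1 = direct(null)
val r2 = wrapped(null)
println(s"$r1 $r2") // 0 0
```

Without the `if (isnull(...))` guard in the plan, this silent conversion is exactly what surfaces as the 0 values in the udf1 column.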

 

I traced why this happens through the Spark sources:

  • udf
    def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
      val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption
      UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
    }
  • UserDefinedFunction
    case class UserDefinedFunction protected[sql] (
        f: AnyRef,
        dataType: DataType,
        inputTypes: Option[Seq[DataType]]) {
      // ... other members elided ...
      def apply(exprs: Column*): Column = {
        Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
      }
    }
  • ScalaUDF
    case class ScalaUDF(
        function: AnyRef,
        dataType: DataType,
        children: Seq[Expression],
        inputTypes: Seq[DataType] = Nil,
        udfName: Option[String] = None)
      extends Expression with ImplicitCastInputTypes with NonSQLExpression {
      // ... implementation elided ...
    }
  • HandleNullInputsForUDF from the Catalyst Analyzer (the TODO in this piece explains the mess with nullability; the rule simply doesn't kick in where I would expect it to):
    /**
     * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
     * null check. When user defines a UDF with primitive parameters, there is no way to tell if the
     * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
     * and we should return null if the input is null.
     */
    object HandleNullInputsForUDF extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
        case p if !p.resolved => p // Skip unresolved nodes.
        case p => p.transformExpressionsUp {
          case udf @ ScalaUDF(func, _, inputs, _, _) =>
            val parameterTypes = ScalaReflection.getParameterTypes(func)
            assert(parameterTypes.length == inputs.length)
            val inputsNullCheck = parameterTypes.zip(inputs)
              // TODO: skip null handling for not-nullable primitive inputs after we can completely
              // trust the `nullable` information.
              // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
              .filter { case (cls, _) => cls.isPrimitive }
              .map { case (_, expr) => IsNull(expr) }
              .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
            inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
        }
      }
    }
  • And the final piece, ScalaReflection.getParameterTypes:
    def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
      val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
      assert(methods.length == 1)
      methods.head.getParameterTypes
    }
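To see what the analyzer rule does in isolation, here is a toy re-implementation over a hypothetical mini expression tree. `Expr`, `Col`, `IsNull`, `Or`, `If`, `NullLit` and `Udf` are made-up stand-ins for Catalyst classes, not Spark's API. The logic follows the rule quoted above: keep only the primitive parameter classes, turn each matching input into an IsNull check, OR the checks together, and wrap the UDF in an If.

```scala
// Hypothetical mini expression tree standing in for Catalyst expressions.
sealed trait Expr
final case class Col(name: String) extends Expr
final case class IsNull(child: Expr) extends Expr
final case class Or(left: Expr, right: Expr) extends Expr
final case class If(cond: Expr, whenTrue: Expr, whenFalse: Expr) extends Expr
case object NullLit extends Expr
final case class Udf(paramClasses: Seq[Class[_]], inputs: Seq[Expr]) extends Expr

// Same shape as HandleNullInputsForUDF: only primitive parameter classes get a null check.
def handleNullInputs(udf: Udf): Expr = {
  assert(udf.paramClasses.length == udf.inputs.length)
  val nullCheck = udf.paramClasses.zip(udf.inputs)
    .filter { case (cls, _) => cls.isPrimitive }
    .map { case (_, e) => IsNull(e) }
    .reduceLeftOption[Expr]((a, b) => Or(a, b))
  nullCheck.map(If(_, NullLit, udf)).getOrElse(udf)
}

// classOf[Long] is the JVM primitive class `long`; classOf[Object] is what erasure leaves.
val fromPlainLambda = Udf(Seq(classOf[Long]), Seq(Col("x")))
val fromWrapper     = Udf(Seq(classOf[Object]), Seq(Col("x")))

println(handleNullInputs(fromPlainLambda)) // an If(...) around the Udf
println(handleNullInputs(fromWrapper))     // the Udf itself, no null check
```

The decision hinges entirely on `Class.isPrimitive` of the runtime parameter class, which is why a wrapper that changes that class also changes the plan.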

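As you can see, it relies on Java runtime class information, so it's no surprise that isPrimitive doesn't work the way we would expect: because of type erasure, the lambda returned by identity is compiled once for every T, so its apply method takes an Object. In this case: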

scala> ScalaReflection.getParameterTypes(f)
res1: Seq[Class[_]] = WrappedArray(long)

scala> ScalaReflection.getParameterTypes(identity(f))
res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
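The same effect can be reproduced with plain Java reflection on ordinary classes, which avoids depending on how a particular Scala version compiles lambdas. `paramTypes` is a hypothetical helper modeled on getParameterTypes; `LongFn` and `GenericFn` are made-up classes, one with a concrete Long parameter and one generic in T.

```scala
// Modeled on ScalaReflection.getParameterTypes: find the single non-bridge
// `apply` method via Java reflection and return its parameter classes.
def paramTypes(obj: AnyRef): Seq[Class[_]] = {
  val methods = obj.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
  assert(methods.length == 1)
  methods.head.getParameterTypes.toSeq
}

class LongFn { def apply(x: Long): Long = x } // concrete primitive parameter
class GenericFn[T] { def apply(t: T): T = t } // T is erased to Object

println(paramTypes(new LongFn))          // the primitive class `long`
println(paramTypes(new GenericFn[Long])) // java.lang.Object: the type argument is gone
```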

Instead, it should use the TypeTag that is already available in the udf declaration, like this:

scala> def myGetParameterTypes[T: TypeTag, U](func: T => U) = {
  typeTag[T].tpe.typeSymbol.asClass
}
myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol

scala> myGetParameterTypes(f)
res3: reflect.runtime.universe.ClassSymbol = class Long

scala> myGetParameterTypes(f).isPrimitive
res4: Boolean = true
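A variant of the same idea that also compiles on Scala 3 (where TypeTag no longer exists) uses a ClassTag resolved at the call site. This is a swapped-in technique, not what the post or Spark uses, and `staticParamClass` is a hypothetical name. Because the tag comes from the static type Long => Long, the wrapper returned by identity keeps the primitive class.

```scala
import scala.reflect.ClassTag

val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)

// The ClassTag for T is supplied by the compiler from the *static* type of func,
// so it does not depend on how the function object was compiled.
def staticParamClass[T: ClassTag, U](func: T => U): Class[_] =
  implicitly[ClassTag[T]].runtimeClass

println(staticParamClass(f))           // long
println(staticParamClass(identity(f))) // long: the wrapper does not lose the type
```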

The workaround, though quite ugly, is to use specialization:

scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
identity2: [T, U](f: T => U)T => U

scala> val udf2 = udf(identity2(f))
udf2: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> ScalaReflection.getParameterTypes(identity2(f))
res10: Seq[Class[_]] = WrappedArray(long)

As a result, I submitted the Spark Jira issue SPARK-23833. Be careful when using UDFs that operate on primitive types if nullable data can be passed to them: there are many scenarios in which the behavior may differ. It should be a rule: if nullable data can be passed, use boxed types or Option.
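Following that rule, a function body written against boxed `java.lang.Long` sees the null and can decide what to return, instead of silently receiving 0. A plain-Scala sketch, again assuming an erased `Any => Any` call as a stand-in for how generic code such as Spark invokes the function; `primitiveBody` and `boxedBody` are illustrative names, and the Spark `udf(...)` wiring is not shown.

```scala
// Primitive parameter: a null argument is silently unboxed to 0L.
val primitiveBody = (x: Long) => x + 1

// Boxed parameter: the null reaches the body, which handles it explicitly.
val boxedBody = (x: java.lang.Long) =>
  if (x == null) null else java.lang.Long.valueOf(x.longValue + 1)

// Assumption: an erased view stands in for the generic call path.
val erasedPrimitive: Any => Any = primitiveBody.asInstanceOf[Any => Any]
val erasedBoxed: Any => Any     = boxedBody.asInstanceOf[Any => Any]

val viaPrimitive = erasedPrimitive(null)
val viaBoxed     = erasedBoxed(null)

println(viaPrimitive) // 1: the null became 0, then 0 + 1
println(viaBoxed)     // null
```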
