#678 Add metadata to the target schema when converting integral data types to decimals.
yruslan committed Jun 3, 2024
1 parent b068a90 commit aa95886
Showing 2 changed files with 62 additions and 2 deletions.
@@ -179,6 +179,60 @@ object SparkUtils extends Logging {
    df.select(fields.toSeq: _*)
  }

  /**
    * Copies metadata from one schema to another as long as names and data types are the same.
    *
    * @param schemaFrom Schema to copy metadata from.
    * @param schemaTo   Schema to copy metadata to.
    * @return Same schema as schemaTo with metadata from schemaFrom.
    */
  def copyMetadata(schemaFrom: StructType, schemaTo: StructType): StructType = {
    @tailrec
    def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
      ar.elementType match {
        case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
          val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
          val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
          ArrayType(newDataType, ar.containsNull)
        case at: ArrayType =>
          processArray(at, fieldFrom, fieldTo)
        case p =>
          ArrayType(p, ar.containsNull)
      }
    }

    val fieldsMap = schemaFrom.fields.map(f => (f.name, f)).toMap

    val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
      fieldsMap.get(fieldTo.name) match {
        case Some(fieldFrom) =>
          fieldTo.dataType match {
            case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
              val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
              fieldTo.copy(dataType = newDataType, metadata = fieldFrom.metadata)
            case at: ArrayType =>
              val newType = processArray(at, fieldFrom, fieldTo)
              fieldTo.copy(dataType = newType, metadata = fieldFrom.metadata)
            case _ =>
              fieldTo.copy(metadata = fieldFrom.metadata)
          }
        case None =>
          fieldTo
      }
    }

    StructType(newFields)
  }
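For context, a minimal usage sketch of copyMetadata (not part of this diff): someDf, the DecimalType(18, 0) target, and the assumption that someDf has only integral top-level columns are illustrative. It mirrors how mapPrimitives applies the method further down: a cast rebuilds the schema without metadata, so the metadata is copied back and a new DataFrame is created over the same rows.

// Hypothetical usage sketch (names and precision are illustrative assumptions):
// the cast drops field metadata, copyMetadata re-attaches it.
val convertedDf = someDf.select(someDf.schema.fields.map(f => col(f.name).cast(DecimalType(18, 0)).as(f.name)): _*)
val mergedSchema = SparkUtils.copyMetadata(someDf.schema, convertedDf.schema)
val restoredDf = someDf.sparkSession.createDataFrame(convertedDf.rdd, mergedSchema)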

  /**
    * Allows mapping every primitive field in a dataframe with a Spark expression.
    *
    * The metadata of the original schema is retained.
    *
    * @param df The dataframe to map.
    * @param f  The function to apply to each primitive field.
    * @return The new dataframe with the mapping applied.
    */
  def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
    def mapField(column: Column, field: StructField): Column = {
      field.dataType match {
@@ -207,7 +261,10 @@ object SparkUtils extends Logging {
    }

    val columns = df.schema.fields.map(f => mapField(col(f.name), f))
    df.select(columns: _*)
    val newDf = df.select(columns: _*)
    val newSchema = copyMetadata(df.schema, newDf.schema)

    df.sparkSession.createDataFrame(newDf.rdd, newSchema)
  }
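As a usage illustration (not part of this change), mapPrimitives could drive an integral-to-decimal conversion like the one covertIntegralToDecimal below performs; the precisions here are assumptions for the sketch, not values taken from this commit.

// Hypothetical sketch: widen integral primitives to decimals and leave other
// primitives untouched; field metadata is carried over via copyMetadata above.
val decimalDf = SparkUtils.mapPrimitives(df) { (field, c) =>
  field.dataType match {
    case IntegerType => c.cast(DecimalType(10, 0))
    case LongType    => c.cast(DecimalType(20, 0))
    case _           => c
  }
}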

  def covertIntegralToDecimal(df: DataFrame): DataFrame = {
@@ -325,7 +382,7 @@ object SparkUtils extends Logging {
    val fileSystem = FileSystem.get(conf)
    val hdfsBlockSize = HDFSUtils.getHDFSDefaultBlockSizeMB(fileSystem)
    hdfsBlockSize match {
      case None => logger.info(s"Unable to get HDFS default block size.")
      case None => logger.info(s"Unable to get HDFS default block size.")
      case Some(size) => logger.info(s"HDFS default block size = $size MB.")
    }

@@ -466,6 +466,9 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt

      assert(actualDf.schema.fields.head.metadata.json.nonEmpty)
      assert(actualDf.schema.fields(1).metadata.json.nonEmpty)
      assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
      assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).metadata.json.nonEmpty)
      assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)

      compareText(actualSchema, expectedSchema)
    }
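The chained asInstanceOf calls in the new assertions could also be wrapped in a small helper; a sketch (the helper is hypothetical, not part of the test suite), assuming the inspected fields are arrays of structs:

// Hypothetical helper: element struct of an array-of-struct field, to keep
// nested metadata assertions readable.
def arrayElementStruct(field: StructField): StructType =
  field.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]

assert(arrayElementStruct(actualDf.schema.fields(1)).fields.head.metadata.json.nonEmpty)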
