diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/parameters/CobolParametersParser.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/parameters/CobolParametersParser.scala index 1338aab4..03d6d104 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/parameters/CobolParametersParser.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/parameters/CobolParametersParser.scala @@ -428,6 +428,7 @@ object CobolParametersParser extends Logging { val recordLengthFieldOpt = params.get(PARAM_RECORD_LENGTH_FIELD) val isRecordSequence = Seq(FixedBlock, VariableLength, VariableBlock).contains(recordFormat) val isRecordIdGenerationEnabled = params.getOrElse(PARAM_GENERATE_RECORD_ID, "false").toBoolean + val isSegmentIdGenerationEnabled = params.contains(s"${PARAM_SEGMENT_ID_LEVEL_PREFIX}0") val fileStartOffset = params.getOrElse(PARAM_FILE_START_OFFSET, "0").toInt val fileEndOffset = params.getOrElse(PARAM_FILE_END_OFFSET, "0").toInt val varLenOccursEnabled = params.getOrElse(PARAM_VARIABLE_SIZE_OCCURS, "false").toBoolean @@ -448,6 +449,7 @@ object CobolParametersParser extends Logging { if (recordLengthFieldOpt.isDefined || isRecordSequence || isRecordIdGenerationEnabled || + isSegmentIdGenerationEnabled || fileStartOffset > 0 || fileEndOffset > 0 || hasRecordExtractor || diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala index f059ecf8..160c91c1 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala @@ -179,6 +179,67 @@ object SparkUtils extends Logging { df.select(fields.toSeq: _*) } + /** + * Removes all struct nesting when possible for a given schema. + */ + def unstructSchema(schema: StructType, useShortFieldNames: Boolean = false): StructType = { + def mapFieldShort(field: StructField): Array[StructField] = { + field.dataType match { + case st: StructType => + st.fields flatMap mapFieldShort + case _ => + Array(field) + } + } + + def mapFieldLong(field: StructField, path: String): Array[StructField] = { + field.dataType match { + case st: StructType => + st.fields.flatMap(f => mapFieldLong(f, s"$path${field.name}_")) + case _ => + Array(field.copy(name = s"$path${field.name}")) + } + } + + val fields = if (useShortFieldNames) + schema.fields flatMap mapFieldShort + else + schema.fields.flatMap(f => mapFieldLong(f, "")) + + StructType(fields) + } + + /** + * Removes all struct nesting when possible for a given dataframe. + * + * Similar to `flattenSchema()`, but does not flatten arrays. 
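+   *
+   * Example (illustrative only; the `ID`/`ADDRESS` field names below are hypothetical and
+   * not part of this change):
+   * {{{
+   *   // df schema:
+   *   //   root
+   *   //    |-- ID: integer
+   *   //    |-- ADDRESS: struct
+   *   //    |    |-- CITY: string
+   *   //    |    |-- ZIP: string
+   *   val flatDf = SparkUtils.unstructDataFrame(df)
+   *   // flatDf schema:
+   *   //   root
+   *   //    |-- ID: integer
+   *   //    |-- ADDRESS_CITY: string
+   *   //    |-- ADDRESS_ZIP: string
+   * }}}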
+ */ + def unstructDataFrame(df: DataFrame, useShortFieldNames: Boolean = false): DataFrame = { + def mapFieldShort(column: Column, field: StructField): Array[Column] = { + field.dataType match { + case st: StructType => + st.fields.flatMap(f => mapFieldShort(column.getField(f.name), f)) + case _ => + Array(column.as(field.name, field.metadata)) + } + } + + def mapFieldLong(column: Column, field: StructField, path: String): Array[Column] = { + field.dataType match { + case st: StructType => + st.fields.flatMap(f => mapFieldLong(column.getField(f.name), f, s"$path${field.name}_")) + case _ => + Array(column.as(s"$path${field.name}", field.metadata)) + } + } + + val columns = if (useShortFieldNames) + df.schema.fields.flatMap(f => mapFieldShort(col(f.name), f)) + else + df.schema.fields.flatMap(f => mapFieldLong(col(f.name), f, "")) + df.select(columns: _*) + } + /** * Copies metadata from one schema to another as long as names and data types are the same. * @@ -237,7 +298,7 @@ object SparkUtils extends Logging { def mapField(column: Column, field: StructField): Column = { field.dataType match { case st: StructType => - val columns = st.fields.map(f => mapField(column.getField(field.name), f)) + val columns = st.fields.map(f => mapField(column.getField(f.name), f)) struct(columns: _*).as(field.name, field.metadata) case ar: ArrayType => mapArray(ar, column, field.name).as(field.name, field.metadata) diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala index 3f1ee467..3da8d76f 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala @@ -429,6 +429,180 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt } } + test("unstructDataFrame() and unstructSchema() should flatten a schema and the dataframe with short names") { + val copyBook: String = + """ 01 RECORD. + | 02 COUNT PIC 9(1). + | 02 GROUP1. + | 03 INNER-COUNT PIC S9(1). + | 03 INNER-GROUP OCCURS 3 TIMES. + | 04 FIELD PIC 9. + | 02 GROUP2. + | 03 INNER-COUNT PIC S9(1). + | 03 INNER-NUM PIC 9 OCCURS 3 TIMES. 
+ |""".stripMargin + + val expectedSchema = + """|root + | |-- COUNT: integer (nullable = true) + | |-- INNER_COUNT: integer (nullable = true) + | |-- INNER_GROUP: array (nullable = true) + | | |-- element: struct (containsNull = true) + | | | |-- FIELD: integer (nullable = true) + | |-- INNER_COUNT: integer (nullable = true) + | |-- INNER_NUM: array (nullable = true) + | | |-- element: integer (containsNull = true) + |""".stripMargin + + val expectedData = + """[ { + | "COUNT" : 2, + | "INNER_COUNT" : 1, + | "INNER_GROUP" : [ { + | "FIELD" : 4 + | }, { + | "FIELD" : 5 + | }, { + | "FIELD" : 6 + | } ], + | "INNER_NUM" : [ 7, 8, 9 ] + |}, { + | "COUNT" : 3, + | "INNER_COUNT" : 2, + | "INNER_GROUP" : [ { + | "FIELD" : 7 + | }, { + | "FIELD" : 8 + | }, { + | "FIELD" : 9 + | } ], + | "INNER_NUM" : [ 4, 5, 6 ] + |} ] + |""".stripMargin + + withTempTextFile("flatten", "test", StandardCharsets.UTF_8, "224561789\n347892456\n") { filePath => + val df = spark.read + .format("cobol") + .option("copybook_contents", copyBook) + .option("pedantic", "true") + .option("record_format", "D") + .option("metadata", "extended") + .load(filePath) + + val actualDf = SparkUtils.unstructDataFrame(df, useShortFieldNames = true) + val actualSchema = actualDf.schema.treeString + val actualSchemaOnly = SparkUtils.unstructSchema(df.schema, useShortFieldNames = true) + val actualSchema2 = actualSchemaOnly.treeString + + compareText(actualSchema, expectedSchema) + compareText(actualSchema2, expectedSchema) + + val actualData = SparkUtils.prettyJSON(actualDf.orderBy("COUNT").toJSON.collect().mkString("[", ", ", "]")) + + compareText(actualData, expectedData) + + assert(actualDf.schema.fields.head.metadata.json.nonEmpty) + assert(actualDf.schema.fields(1).metadata.json.nonEmpty) + assert(actualDf.schema.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty) + assert(actualDf.schema.fields(3).metadata.json.nonEmpty) + assert(actualDf.schema.fields(4).metadata.json.nonEmpty) + + assert(actualSchemaOnly.fields.head.metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(1).metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(3).metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(4).metadata.json.nonEmpty) + } + } + + test("unstructDataFrame() and unstructSchema() should flatten a schema and the dataframe with long names") { + val copyBook: String = + """ 01 RECORD. + | 02 COUNT PIC 9(1). + | 02 GROUP1. + | 03 INNER-COUNT PIC S9(1). + | 03 INNER-GROUP OCCURS 3 TIMES. + | 04 FIELD PIC 9. + | 02 GROUP2. + | 03 INNER-COUNT PIC S9(1). + | 03 INNER-NUM PIC 9 OCCURS 3 TIMES. 
+ |""".stripMargin + + val expectedSchema = + """|root + | |-- COUNT: integer (nullable = true) + | |-- GROUP1_INNER_COUNT: integer (nullable = true) + | |-- GROUP1_INNER_GROUP: array (nullable = true) + | | |-- element: struct (containsNull = true) + | | | |-- FIELD: integer (nullable = true) + | |-- GROUP2_INNER_COUNT: integer (nullable = true) + | |-- GROUP2_INNER_NUM: array (nullable = true) + | | |-- element: integer (containsNull = true) + |""".stripMargin + + val expectedData = + """[ { + | "COUNT" : 2, + | "GROUP1_INNER_COUNT" : 2, + | "GROUP1_INNER_GROUP" : [ { + | "FIELD" : 4 + | }, { + | "FIELD" : 5 + | }, { + | "FIELD" : 6 + | } ], + | "GROUP2_INNER_COUNT" : 1, + | "GROUP2_INNER_NUM" : [ 7, 8, 9 ] + |}, { + | "COUNT" : 3, + | "GROUP1_INNER_COUNT" : 4, + | "GROUP1_INNER_GROUP" : [ { + | "FIELD" : 7 + | }, { + | "FIELD" : 8 + | }, { + | "FIELD" : 9 + | } ], + | "GROUP2_INNER_COUNT" : 2, + | "GROUP2_INNER_NUM" : [ 4, 5, 6 ] + |} ] + |""".stripMargin + + withTempTextFile("flatten", "test", StandardCharsets.UTF_8, "224561789\n347892456\n") { filePath => + val df = spark.read + .format("cobol") + .option("copybook_contents", copyBook) + .option("pedantic", "true") + .option("record_format", "D") + .option("metadata", "extended") + .load(filePath) + + val actualDf = SparkUtils.unstructDataFrame(df) + val actualSchema = actualDf.schema.treeString + val actualSchemaOnly = SparkUtils.unstructSchema(df.schema) + val actualSchema2 = actualSchemaOnly.treeString + + compareText(actualSchema, expectedSchema) + compareText(actualSchema2, expectedSchema) + + val actualData = SparkUtils.prettyJSON(actualDf.orderBy("COUNT").toJSON.collect().mkString("[", ", ", "]")) + + compareText(actualData, expectedData) + + assert(actualDf.schema.fields.head.metadata.json.nonEmpty) + assert(actualDf.schema.fields(1).metadata.json.nonEmpty) + assert(actualDf.schema.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty) + assert(actualDf.schema.fields(3).metadata.json.nonEmpty) + assert(actualDf.schema.fields(4).metadata.json.nonEmpty) + + assert(actualSchemaOnly.fields.head.metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(1).metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(3).metadata.json.nonEmpty) + assert(actualSchemaOnly.fields(4).metadata.json.nonEmpty) + } + } + test("Integral to decimal conversion for complex schema") { val expectedSchema = """|root