Flatten dataframe #47

Open
bjornjorgensen opened this issue Aug 21, 2022 · 2 comments

Labels: good first issue (Good for newcomers), help wanted (Extra attention is needed)

@bjornjorgensen

Hi, can we add a function to flatten a nested dataframe?

from pyspark.sql.functions import col, explode_outer, map_keys
from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType


def flatten_test(df, sep="_"):
    """Returns a flattened dataframe.
    .. versionadded:: x.X.X

    Parameters
    ----------
    df : DataFrame
        The nested DataFrame to flatten.
    sep : str
        Delimiter for flattened columns. Defaults to `_`.

    Notes
    -----
    Don't use `.` as `sep`: Spark parses an unquoted dot as struct
    field access, so flattening breaks on data frames nested more than
    one level deep, and every flattened column would have to be
    referenced as a back-quoted `column.name` (a short sketch after
    the function illustrates this).

    Flattening a MapType column requires collecting every distinct key
    in that column first, which can be slow.

    Examples
    --------

    data_mixed = [
        {
            "state": "Florida",
            "shortname": "FL",
            "info": {"governor": "Rick Scott"},
            "counties": [
                {"name": "Dade", "population": 12345},
                {"name": "Broward", "population": 40000},
                {"name": "Palm Beach", "population": 60000},
            ],
        },
        {
            "state": "Ohio",
            "shortname": "OH",
            "info": {"governor": "John Kasich"},
            "counties": [
                {"name": "Summit", "population": 1234},
                {"name": "Cuyahoga", "population": 1337},
            ],
        },
    ]

    data_mixed = spark.createDataFrame(data=data_mixed)

    data_mixed.printSchema()

    root
    |-- counties: array (nullable = true)
    |    |-- element: map (containsNull = true)
    |    |    |-- key: string
    |    |    |-- value: string (valueContainsNull = true)
    |-- info: map (nullable = true)
    |    |-- key: string
    |    |-- value: string (valueContainsNull = true)
    |-- shortname: string (nullable = true)
    |-- state: string (nullable = true)


    data_mixed_flat = flatten_test(df, sep=":")
    data_mixed_flat.printSchema()
    root
    |-- shortname: string (nullable = true)
    |-- state: string (nullable = true)
    |-- counties:name: string (nullable = true)
    |-- counties:population: string (nullable = true)
    |-- info:governor: string (nullable = true)




    data = [
        {
            "id": 1,
            "name": "Cole Volk",
            "fitness": {"height": 130, "weight": 60},
        },
        {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
        {
            "id": 2,
            "name": "Faye Raker",
            "fitness": {"height": 130, "weight": 60},
        },
    ]


    df = spark.createDataFrame(data=data)

    df.printSchema()

    root
    |-- fitness: map (nullable = true)
    |    |-- key: string
    |    |-- value: long (valueContainsNull = true)
    |-- id: long (nullable = true)
    |-- name: string (nullable = true)

    df_flat = flatten_test(df, sep=":")

    df_flat.printSchema()

    root
    |-- id: long (nullable = true)
    |-- name: string (nullable = true)
    |-- fitness:height: long (nullable = true)
    |-- fitness:weight: long (nullable = true)

    data_struct = [
        (("James", None, "Smith"), "OH", "M"),
        (("Anna", "Rose", ""), "NY", "F"),
        (("Julia", "", "Williams"), "OH", "F"),
        (("Maria", "Anne", "Jones"), "NY", "M"),
        (("Jen", "Mary", "Brown"), "NY", "M"),
        (("Mike", "Mary", "Williams"), "OH", "M"),
    ]


    schema = StructType([
        StructField('name', StructType([
            StructField('firstname', StringType(), True),
            StructField('middlename', StringType(), True),
            StructField('lastname', StringType(), True)
            ])),
        StructField('state', StringType(), True),
        StructField('gender', StringType(), True)
        ])

    df_struct = spark.createDataFrame(data=data_struct, schema=schema)

    df_struct.printSchema()

    root
    |-- name: struct (nullable = true)
    |    |-- firstname: string (nullable = true)
    |    |-- middlename: string (nullable = true)
    |    |-- lastname: string (nullable = true)
    |-- state: string (nullable = true)
    |-- gender: string (nullable = true)

    df_struct_flat = flatten_test(df_struct, sep=":")

    df_struct_flat.printSchema()

    root
    |-- state: string (nullable = true)
    |-- gender: string (nullable = true)
    |-- name:firstname: string (nullable = true)
    |-- name:middlename: string (nullable = true)
    |-- name:lastname: string (nullable = true)
    """
    # Collect the complex fields (arrays, structs and map types) in the schema.
    complex_fields = {
        field.name: field.dataType
        for field in df.schema.fields
        if isinstance(field.dataType, (ArrayType, StructType, MapType))
    }

    while complex_fields:
        col_name = next(iter(complex_fields))

        # If StructType, promote every subfield to a top-level column,
        # i.e. flatten the struct, then drop the original column.
        if isinstance(complex_fields[col_name], StructType):
            expanded = [
                col(col_name + "." + k).alias(col_name + sep + k)
                for k in [n.name for n in complex_fields[col_name]]
            ]
            df = df.select("*", *expanded).drop(col_name)

        # If ArrayType, turn each array element into its own row,
        # i.e. explode the array.
        elif isinstance(complex_fields[col_name], ArrayType):
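            # explode_outer (rather than explode) keeps rows whose array is
            # null or empty, emitting a null element instead of dropping them.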
            df = df.withColumn(col_name, explode_outer(col_name))

        # If MapType, look up every distinct key in the column and
        # promote each key to its own top-level column.
        elif isinstance(complex_fields[col_name], MapType):
            keys_df = df.select(explode_outer(map_keys(col(col_name)))).distinct()
            keys = [row[0] for row in keys_df.collect()]
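            # The collect() above runs one Spark job per map column just to
            # discover the keys -- this is the slow path flagged in the Notes.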
            key_cols = [
                col(col_name).getItem(k).alias(f"{col_name}{sep}{k}") for k in keys
            ]
            df = df.select(
                [c for c in df.columns if c != col_name] + key_cols
            )

        # Recompute the remaining complex fields in the schema.
        complex_fields = {
            field.name: field.dataType
            for field in df.schema.fields
            if isinstance(field.dataType, (ArrayType, StructType, MapType))
        }

    return df
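
As a side note on the separator caveat in the docstring, here is a minimal sketch (assuming a live SparkSession named spark) of why `.` is a poor choice for `sep`: Spark parses an unquoted dot as struct field access, so a flattened column whose name contains a dot can only be referenced with backticks.

df = spark.createDataFrame([(1,)], "`a.b` long")
df.select("a.b")    # AnalysisException: Spark looks for field `b` in a struct `a`
df.select("`a.b`")  # works: backticks make the dot part of the column name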





@MrPowers added the help wanted and good first issue labels on Oct 17, 2022
@orcascope (Contributor)

@bjornjorgensen Certainly I see a very good piece of reusable code here. But with respect to the test helper methods provided by chispa, can you help me understand how this flatten function would be used?
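
For illustration, a minimal sketch of how the flattened output could plug into chispa's assert_df_equality (this assumes the flatten_test function above and a live SparkSession named spark; the expected frame is written by hand):

from chispa import assert_df_equality

nested = spark.createDataFrame(
    [(("James", "Smith"), "OH")],
    "name struct<firstname: string, lastname: string>, state string",
)
expected = spark.createDataFrame(
    [("OH", "James", "Smith")],
    "state string, `name:firstname` string, `name:lastname` string",
)
assert_df_equality(flatten_test(nested, sep=":"), expected, ignore_nullable=True)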

@MrPowers (Owner) commented Oct 23, 2022

Thanks for opening this issue and thanks for commenting @orcascope.

spark-daria has a flattenSchema method that does something similar in Scala.

The sister PySpark library for spark-daria is quinn. Seems like quinn would be the best place for this particular functionality (because it's useful both in the testing context and for general data engineering).

Would opening a quinn PR with this functionality be alright?
