REF: lreshape, wide_to_long (#55976)
* Refactor lreshape

* Refactor wide_to_long validation

* Refactor wide_to_long

* Annotation
mroeschke authored Nov 20, 2023
1 parent 484ec01 commit d99c448
Showing 1 changed file with 23 additions and 35 deletions.
pandas/core/reshape/melt.py: 58 changes (23 additions & 35 deletions)
@@ -12,7 +12,6 @@
 from pandas.core.dtypes.missing import notna
 
 import pandas.core.algorithms as algos
-from pandas.core.arrays import Categorical
 from pandas.core.indexes.api import MultiIndex
 from pandas.core.reshape.concat import concat
 from pandas.core.reshape.util import tile_compat
@@ -139,7 +138,7 @@ def melt(
     return result
 
 
-def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
+def lreshape(data: DataFrame, groups: dict, dropna: bool = True) -> DataFrame:
     """
     Reshape wide-format data to long. Generalized inverse of DataFrame.pivot.
 
@@ -192,30 +191,20 @@ def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
     2  Red Sox  2008  545
     3  Yankees  2008  526
     """
-    if isinstance(groups, dict):
-        keys = list(groups.keys())
-        values = list(groups.values())
-    else:
-        keys, values = zip(*groups)
-
-    all_cols = list(set.union(*(set(x) for x in values)))
-    id_cols = list(data.columns.difference(all_cols))
-
-    K = len(values[0])
-
-    for seq in values:
-        if len(seq) != K:
-            raise ValueError("All column lists must be same length")
-
     mdata = {}
     pivot_cols = []
-
-    for target, names in zip(keys, values):
+    all_cols: set[Hashable] = set()
+    K = len(next(iter(groups.values())))
+    for target, names in groups.items():
+        if len(names) != K:
+            raise ValueError("All column lists must be same length")
         to_concat = [data[col]._values for col in names]
 
         mdata[target] = concat_compat(to_concat)
         pivot_cols.append(target)
+        all_cols = all_cols.union(names)
 
+    id_cols = list(data.columns.difference(all_cols))
     for col in id_cols:
         mdata[col] = np.tile(data[col]._values, K)
 
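For reference, a minimal sketch of the call path this hunk refactors, reusing the docstring example visible above (the first two data rows are assumed from the standard lreshape example); note that groups is now expected to be a dict:

    import pandas as pd

    data = pd.DataFrame(
        {
            "hr1": [514, 573],
            "hr2": [545, 526],
            "team": ["Red Sox", "Yankees"],
            "year1": [2007, 2007],
            "year2": [2008, 2008],
        }
    )
    # Each dict key becomes a long-format column built by concatenating the listed wide columns
    pd.lreshape(data, {"year": ["year1", "year2"], "hr": ["hr1", "hr2"]})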
@@ -467,10 +456,10 @@ def wide_to_long(
                 two  2.9
     """
 
-    def get_var_names(df, stub: str, sep: str, suffix: str) -> list[str]:
+    def get_var_names(df, stub: str, sep: str, suffix: str):
         regex = rf"^{re.escape(stub)}{re.escape(sep)}{suffix}$"
         pattern = re.compile(regex)
-        return [col for col in df.columns if pattern.match(col)]
+        return df.columns[df.columns.str.match(pattern)]
 
     def melt_stub(df, stub: str, i, j, value_vars, sep: str):
         newdf = melt(
@@ -480,7 +469,6 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
             value_name=stub.rstrip(sep),
             var_name=j,
         )
-        newdf[j] = Categorical(newdf[j])
         newdf[j] = newdf[j].str.replace(re.escape(stub + sep), "", regex=True)
 
         # GH17627 Cast numerics suffixes to int/float
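For context on the two hunks above: get_var_names now selects matching columns directly off the column Index, and melt_stub strips the stub/sep prefix so only the suffix remains in the j column. A small sketch of both steps, using assumed example columns and wide_to_long's default sep="" and suffix=r"\d+":

    import re
    import pandas as pd

    columns = pd.Index(["A1970", "A1980", "B1970", "B1980", "X", "id"])  # assumed example
    stub, sep, suffix = "A", "", r"\d+"

    # Column selection, as in get_var_names above
    pattern = re.compile(rf"^{re.escape(stub)}{re.escape(sep)}{suffix}$")
    value_var = columns[columns.str.match(pattern)]  # Index(['A1970', 'A1980'])

    # Suffix extraction, as in melt_stub above
    j_col = pd.Series(value_var).str.replace(re.escape(stub + sep), "", regex=True)
    print(list(j_col))  # ['1970', '1980']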
@@ -497,7 +485,7 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
     else:
         stubnames = list(stubnames)
 
-    if any(col in stubnames for col in df.columns):
+    if df.columns.isin(stubnames).any():
         raise ValueError("stubname can't be identical to a column name")
 
     if not is_list_like(i):
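The validation hunk above is behaviour-preserving; a quick illustration of the vectorised check, on a hypothetical frame that is not part of the diff:

    import pandas as pd

    df = pd.DataFrame(columns=["A1970", "A1980", "B"])
    stubnames = ["A", "B"]
    # Equivalent to the old any(col in stubnames for col in df.columns)
    print(df.columns.isin(stubnames).any())  # True, so wide_to_long would raise ValueError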
@@ -508,18 +496,18 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
     if df[i].duplicated().any():
         raise ValueError("the id variables need to uniquely identify each row")
 
-    value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]
-
-    value_vars_flattened = [e for sublist in value_vars for e in sublist]
-    id_vars = list(set(df.columns.tolist()).difference(value_vars_flattened))
+    _melted = []
+    value_vars_flattened = []
+    for stub in stubnames:
+        value_var = get_var_names(df, stub, sep, suffix)
+        value_vars_flattened.extend(value_var)
+        _melted.append(melt_stub(df, stub, i, j, value_var, sep))
 
-    _melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]
-    melted = _melted[0].join(_melted[1:], how="outer")
+    melted = concat(_melted, axis=1)
+    id_vars = df.columns.difference(value_vars_flattened)
+    new = df[id_vars]
 
     if len(i) == 1:
-        new = df[id_vars].set_index(i).join(melted)
-        return new
-
-    new = df[id_vars].merge(melted.reset_index(), on=i).set_index(i + [j])
-
-    return new
+        return new.set_index(i).join(melted)
+    else:
+        return new.merge(melted.reset_index(), on=i).set_index(i + [j])
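A hedged end-to-end sketch of the public behaviour the reshaping hunk above preserves; the frame follows the layout of the wide_to_long docstring example, with assumed values:

    import pandas as pd

    df = pd.DataFrame(
        {
            "A1970": ["a", "b", "c"],
            "A1980": ["d", "e", "f"],
            "B1970": [2.5, 1.2, 0.7],
            "B1980": [3.2, 1.3, 0.1],
            "X": [1.1, 2.2, 3.3],  # assumed values for the id-level column
        }
    )
    df["id"] = df.index

    # Stub columns A*/B* are melted into long format, indexed by (id, year)
    long_df = pd.wide_to_long(df, stubnames=["A", "B"], i="id", j="year")
    print(long_df)

With a single id column, this exercises the set_index(i).join(melted) branch in the new code above.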
