Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Jan 25, 2024
1 parent d3b6208 commit c62796d
Showing 1 changed file with 116 additions and 1 deletion.
117 changes: 116 additions & 1 deletion src/dask_awkward/lib/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,119 @@ def _enforce_concatenated_form(array: AwkwardArray, form: Form) -> AwkwardArray:
return ak.Array(result, behavior=array._behavior, attrs=array._attrs)


from awkward.typetracer import TypeTracerReport


class ParentReport(TypeTracerReport):
def __init__(self):
self._parent_to_child: dict[str, tuple[TypeTracerReport, str]] = {}

def add_child_key(
self, parent_key: str, child_key: str, child_report: TypeTracerReport
):
self._parent_to_child.setdefault(parent_key, []).append(
(child_report, child_key)
)

@property
def shape_touched(self):
raise NotImplementedError

@property
def data_touched(self):
raise NotImplementedError

def touch_shape(self, label: str):
if (child_infos := self._parent_to_child.get(label)) is not None:
for child_report, child_label in child_infos:
child_report.touch_shape(child_label)

def touch_data(self, label: str):
if (child_infos := self._parent_to_child.get(label)) is not None:
for child_report, child_label in child_infos:
child_report.touch_data(child_label)


def maybe_parent_report(parent, children, parent_report):
if parent_report is None:
parent_report = ParentReport()
if parent.report is not None:
parent_report.add_child_key(parent.form_key, parent.form_key, parent.report)
for child in children:
if child.report is not None:
parent_report.add_child_key(parent.form_key, child.form_key, child.report)
parent.report = parent_report
return parent_report


def merge_reports(first, *remainder):
parent_report = None

def impl(first, *remainder):
nonlocal parent_report
assert all(type(rem) is type(first) for rem in remainder)

if first.is_numpy:
parent_report = maybe_parent_report(
first.data, [c.data for c in remainder], parent_report
)

elif first.is_option and first.is_indexed:
parent_report = maybe_parent_report(
first.index.data, [c.index.data for c in remainder], parent_report
)
impl(first.content, *[c.content for c in remainder])

elif first.is_option:
parent_report = maybe_parent_report(
first.mask.data, [c.mask.data for c in remainder], parent_report
)
impl(first.content, *[c.content for c in remainder])

elif first.is_list and isinstance(first, ak.contents.ListOffsetArray):
parent_report = maybe_parent_report(
first.offsets.data, [c.offsets.data for c in remainder], parent_report
)
impl(first.content, *[c.content for c in remainder])

elif first.is_list and isinstance(first, ak.contents.ListArray):
parent_report = maybe_parent_report(
first.starts.data, [c.starts.data for c in remainder], parent_report
)
parent_report = maybe_parent_report(
first.stops.data, [c.stops.data for c in remainder], parent_report
)
impl(first.content, *[c.content for c in remainder])

elif first.is_list and isinstance(first, ak.contents.RegularArray):
impl(first.content, *[c.content for c in remainder])

elif first.is_indexed:
parent_report = maybe_parent_report(
first.index.data, [c.index.data for c in remainder], parent_report
)
impl(first.content, *[c.content for c in remainder])

elif first.is_record:
for this, *that in zip(first.contents, *[c.contents for c in remainder]):
impl(this, *that)

elif first.is_empty:
return

elif first.is_union:
raise NotImplementedError

else:
raise AssertionError

impl(first, *remainder)


def _concatenate_axis_0_meta(*arrays: AwkwardArray) -> AwkwardArray:
# At this stage, the metas have all been enforced to the same type
layouts = [arr.layout for arr in arrays]
merge_reports(layouts[0], *layouts)
return arrays[0]


Expand Down Expand Up @@ -119,7 +230,11 @@ def concatenate(
)
}

aml = AwkwardMaterializedLayer(g, previous_layer_names=[arrays[0].name])
aml = AwkwardMaterializedLayer(
g,
previous_layer_names=[a.name for a in arrays],
fn=_concatenate_axis_0_meta,
)

hlg = HighLevelGraph.from_collections(name, aml, dependencies=arrays)
return new_array_object(
Expand Down

0 comments on commit c62796d

Please sign in to comment.