From 705eb8c72e2a448788bf7dd91fd2c4a8a4fc1ef2 Mon Sep 17 00:00:00 2001 From: Dalia Shaaban <144673861+dshaaban01@users.noreply.github.com> Date: Mon, 18 Dec 2023 12:51:56 +0000 Subject: [PATCH] interactive: add diff to operation count table (#1870) ![image](https://github.com/xdslproject/xdsl/assets/144673861/a49b5f5c-0e9b-4b78-bd3c-f23b1c672e8f) Added the diff as per @PapyChacal 's request :) --- tests/interactive/test_pass_metrics.py | 59 +++++++++++++++++++- xdsl/interactive/app.py | 77 ++++++++++++++++---------- xdsl/interactive/app.tcss | 2 +- xdsl/interactive/pass_metrics.py | 31 +++++++++++ 4 files changed, 135 insertions(+), 34 deletions(-) diff --git a/tests/interactive/test_pass_metrics.py b/tests/interactive/test_pass_metrics.py index 31bc97a6ed..60b3c09cf2 100644 --- a/tests/interactive/test_pass_metrics.py +++ b/tests/interactive/test_pass_metrics.py @@ -5,7 +5,10 @@ IntegerAttr, ModuleOp, ) -from xdsl.interactive.pass_metrics import count_number_of_operations +from xdsl.interactive.pass_metrics import ( + count_number_of_operations, + get_diff_operation_count, +) from xdsl.ir import Block, MLContext, Region from xdsl.parser import Parser from xdsl.tools.command_line_tool import get_all_dialects @@ -54,12 +57,62 @@ def test_operation_counter_with_parsing_text(): module = parser.parse_module() expected_res = { - "func.func": 1, "arith.constant": 1, "arith.muli": 1, - "func.return": 1, "builtin.module": 1, + "func.func": 1, + "func.return": 1, } res = count_number_of_operations(module) assert res == expected_res + + +def test_get_diff_operation_count(): + # get input module + input_text = """builtin.module { + func.func @hello(%n : index) -> index { + %two = arith.constant 2 : index + %res = arith.muli %n, %two : index + func.return %res : index + } +} +""" + + ctx = MLContext(True) + for dialect in get_all_dialects(): + ctx.load_dialect(dialect) + parser = Parser(ctx, input_text) + input_module = parser.parse_module() + + # get output module + output_text = """builtin.module { + func.func @hello(%n : index) -> index { + %two = riscv.li 2 : () -> !riscv.reg<> + %two_1 = builtin.unrealized_conversion_cast %two : !riscv.reg<> to index + %res = builtin.unrealized_conversion_cast %n : index to !riscv.reg<> + %res_1 = builtin.unrealized_conversion_cast %two_1 : index to !riscv.reg<> + %res_2 = riscv.mul %res, %res_1 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> + %res_3 = builtin.unrealized_conversion_cast %res_2 : !riscv.reg<> to index + func.return %res_3 : index + } +} +""" + parser = Parser(ctx.clone(), output_text) + output_module = parser.parse_module() + + expected_diff_res: tuple[tuple[str, int, str], ...] = ( + ("arith.constant", 0, "-1"), + ("arith.muli", 0, "-1"), + ("builtin.module", 1, "="), + ("builtin.unrealized_conversion_cast", 4, "+4"), + ("func.func", 1, "="), + ("func.return", 1, "="), + ("riscv.li", 1, "+1"), + ("riscv.mul", 1, "+1"), + ) + + assert expected_diff_res == get_diff_operation_count( + tuple(count_number_of_operations(input_module).items()), + tuple(count_number_of_operations(output_module).items()), + ) diff --git a/xdsl/interactive/app.py b/xdsl/interactive/app.py index 863e1d0a04..5496503aca 100644 --- a/xdsl/interactive/app.py +++ b/xdsl/interactive/app.py @@ -32,7 +32,10 @@ from xdsl.dialects.builtin import ModuleOp from xdsl.interactive.add_arguments_screen import AddArguments from xdsl.interactive.load_file_screen import LoadFile -from xdsl.interactive.pass_metrics import count_number_of_operations +from xdsl.interactive.pass_metrics import ( + count_number_of_operations, + get_diff_operation_count, +) from xdsl.ir import MLContext from xdsl.parser import Parser from xdsl.passes import ModulePass, PipelinePass, get_pass_argument_names_and_types @@ -136,14 +139,23 @@ class InputApp(App[None]): """ListView displaying the passes available to apply.""" input_operation_count_tuple = reactive(tuple[tuple[str, int], ...]) - """Saves the operation name and count of the input text area in a dictionary.""" - output_operation_count_tuple = reactive(tuple[tuple[str, int], ...]) - """Saves the operation name and count of the output text area in a dictionary.""" + """ + Saves the operation name and count of the input text area in a reactive tuple of + tuples. + """ + diff_operation_count_tuple = reactive(tuple[tuple[str, int, str], ...]) + """ + Saves the diff of the input_operation_count_tuple and the output_operation_count_tuple + in a reactive tuple of tuples. + """ input_operation_count_datatable: DataTable[str | int] """DataTable displaying the operation names and counts of the input text area.""" - output_operation_count_datatable: DataTable[str | int] - """DataTable displaying the operation names and counts of the output text area.""" + diff_operation_count_datatable: DataTable[str | int] + """ + DataTable displaying the diff of operation names and counts of the input and output + text areas. + """ def __init__(self): self.input_text_area = TextArea(id="input") @@ -153,8 +165,8 @@ def __init__(self): self.input_operation_count_datatable = DataTable( id="input_operation_count_datatable" ) - self.output_operation_count_datatable = DataTable( - id="output_operation_count_datatable" + self.diff_operation_count_datatable = DataTable( + id="diff_operation_count_datatable" ) super().__init__() @@ -196,7 +208,7 @@ def compose(self) -> ComposeResult: yield self.output_text_area yield Button("Copy Output", id="copy_output_button") with ScrollableContainer(id="output_ops_container"): - yield self.output_operation_count_datatable + yield self.diff_operation_count_datatable yield Footer() def on_mount(self) -> None: @@ -224,8 +236,8 @@ def on_mount(self) -> None: self.input_operation_count_datatable.add_columns("Operation", "Count") self.input_operation_count_datatable.zebra_stripes = True - self.output_operation_count_datatable.add_columns("Operation", "Count") - self.output_operation_count_datatable.zebra_stripes = True + self.diff_operation_count_datatable.add_columns("Operation", "Count", "Diff") + self.diff_operation_count_datatable.zebra_stripes = True def compute_available_pass_list(self) -> tuple[type[ModulePass], ...]: """ @@ -376,7 +388,7 @@ def watch_current_module(self): output_text = output_stream.getvalue() self.output_text_area.load_text(output_text) - self.update_output_operation_count_tuple() + self.update_operation_count_diff_tuple() def get_query_string(self) -> str: """ @@ -394,8 +406,9 @@ def update_input_operation_count_tuple(self, input_module: ModuleOp) -> None: Function that updates the input_operation_datatable to display the operation names and counts in the input text area. """ + # sort tuples alphabetically by operation name self.input_operation_count_tuple = tuple( - count_number_of_operations(input_module).items() + sorted(count_number_of_operations(input_module).items()) ) def watch_input_operation_count_tuple(self) -> None: @@ -403,36 +416,40 @@ def watch_input_operation_count_tuple(self) -> None: Function called when the reactive variable input_operation_count_tuple changes - updates the Input DataTable. """ + # clear datatable and add input_operation_count_tuple to DataTable self.input_operation_count_datatable.clear() - for k, v in self.input_operation_count_tuple: - self.input_operation_count_datatable.add_row(k, v) - - self.update_output_operation_count_tuple() + self.input_operation_count_datatable.add_rows(self.input_operation_count_tuple) + self.update_operation_count_diff_tuple() - def update_output_operation_count_tuple(self) -> None: + def update_operation_count_diff_tuple(self) -> None: """ - Function that updates the output_operation_datatable to display the operation - names and counts in the output text area. It also displays the diff of the input - and output datatable. + Function that updates the diff_operation_count_tuple to calculate the diff + of the input and output operation counts. """ match self.current_module: case None: - self.output_operation_count_tuple = () + output_operation_count_tuple = () case Exception(): - self.output_operation_count_tuple = () + output_operation_count_tuple = () case ModuleOp(): - self.output_operation_count_tuple = tuple( - count_number_of_operations(self.current_module).items() + # sort tuples alphabetically by operation name + output_operation_count_tuple = tuple( + (k, v) + for (k, v) in sorted( + count_number_of_operations(self.current_module).items() + ) ) + self.diff_operation_count_tuple = get_diff_operation_count( + self.input_operation_count_tuple, output_operation_count_tuple + ) - def watch_output_operation_count_tuple(self) -> None: + def watch_diff_operation_count_tuple(self) -> None: """ - Function called when the reactive variable output_operation_count_tuple changes + Function called when the reactive variable diff_operation_count_tuple changes - updates the Output DataTable. """ - self.output_operation_count_datatable.clear() - for k, v in self.output_operation_count_tuple: - self.output_operation_count_datatable.add_row(k, v) + self.diff_operation_count_datatable.clear() + self.diff_operation_count_datatable.add_rows(self.diff_operation_count_tuple) def action_toggle_dark(self) -> None: """An action to toggle dark mode.""" diff --git a/xdsl/interactive/app.tcss b/xdsl/interactive/app.tcss index 0ac4ca20e7..42299be634 100644 --- a/xdsl/interactive/app.tcss +++ b/xdsl/interactive/app.tcss @@ -57,7 +57,7 @@ } # DataTable -#output_operation_count_datatable{ +#diff_operation_count_datatable{ width: auto; height: auto; } diff --git a/xdsl/interactive/pass_metrics.py b/xdsl/interactive/pass_metrics.py index ab30a1beda..291215992d 100644 --- a/xdsl/interactive/pass_metrics.py +++ b/xdsl/interactive/pass_metrics.py @@ -9,3 +9,34 @@ def count_number_of_operations(module: ModuleOp) -> dict[str, int]: occurences of each Operation in the ModuleOp. """ return Counter(op.name for op in module.walk()) + + +def get_diff_operation_count( + input_operation_count_tuple: tuple[tuple[str, int], ...], + output_operation_count_tuple: tuple[tuple[str, int], ...], +) -> tuple[tuple[str, int, str], ...]: + """ + Function returning a tuple of tuples containing the diff of the input and output + operation name and count. + """ + input_op_count_dict = dict(input_operation_count_tuple) + output_op_count_dict = dict(output_operation_count_tuple) + all_keys = {*input_op_count_dict, *output_op_count_dict} + + res: dict[str, tuple[int, str]] = {} + for k in all_keys: + input_count = input_op_count_dict.get(k, 0) + output_count = output_op_count_dict.get(k, 0) + diff = output_count - input_count + + # convert diff to string + if diff == 0: + diff_str = "=" + elif diff > 0: + diff_str = f"+{diff}" + else: + diff_str = str(diff) + + res[k] = (output_count, diff_str) + + return tuple((k, v0, v1) for (k, (v0, v1)) in sorted(res.items()))