diff --git a/python/__init__.py b/python/__init__.py index 3763cd64fb..39c0178f23 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -420,10 +420,12 @@ def get_kernel_bin(self): return "spvbin" def get_architecture_descriptor(self, **kwargs): - dev_props = self.driver.utils.get_device_properties(torch.xpu.device(torch.xpu.current_device()).sycl_device) # noqa: E501 - max_work_group_size = dev_props['max_work_group_size'] - max_num_sub_groups = dev_props['max_num_sub_groups'] - sub_group_sizes = dev_props['sub_group_sizes'] + arch = kwargs.get("arch", None) + if arch is None: + arch = self.get_device_properties(self.get_current_device()) + max_work_group_size = arch['max_work_group_size'] + max_num_sub_groups = arch['max_num_sub_groups'] + sub_group_sizes = arch['sub_group_sizes'] # TODO: chose a reasonable subgroup size threads_per_warp = 32 assert threads_per_warp in sub_group_sizes, "Current platform does not support threads_per_warp to be 32" # noqa: E501