From 6f516b8d4c154f79c2f86fbd8f702dd7584df2d3 Mon Sep 17 00:00:00 2001
From: Casper
Date: Thu, 16 Nov 2023 14:17:13 +0100
Subject: [PATCH] Fixed multi-GPU quantization (#196)

---
 awq/quantize/quantizer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index 36eda277..dfa4ef6b 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -69,8 +69,15 @@ def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: to
 
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
+            # Move module and inputs to correct device
+            common_device = next(self.modules[i].parameters()).device
+            if common_device is None or str(common_device) == "cpu":
+                self.modules[i] = self.modules[i].cuda()
+                common_device = next(self.modules[i].parameters()).device
+
+            self.inps = self.inps.to(common_device)
+
             # [STEP 1]: Get layer, extract linear modules, extract input features
-            self.modules[i] = self.modules[i].cuda()
             named_linears = get_named_linears(self.modules[i])
             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()
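
For context: the loop now derives the target device from the layer's own parameters instead of unconditionally calling .cuda(), so a model whose layers are already spread across several GPUs (e.g. by accelerate's device_map) keeps its placement, and only the activations (self.inps) follow each layer. Below is a minimal standalone sketch of that pattern; resolve_device and the two-GPU layer list are illustrative, not part of the repo, and running it as-is assumes two CUDA devices.

    import torch
    import torch.nn as nn

    def resolve_device(module: nn.Module) -> torch.device:
        # Read the device from the module's own parameters so inputs
        # follow the layer wherever it was placed.
        device = next(module.parameters()).device
        if device.type == "cpu":
            # CPU-offloaded layers are pulled onto the default GPU first,
            # mirroring the patch's .cuda() fallback.
            module.cuda()
            device = next(module.parameters()).device
        return device

    # Layers sharded across two GPUs; activations move layer by layer.
    layers = [nn.Linear(8, 8).to("cuda:0"), nn.Linear(8, 8).to("cuda:1")]
    inps = torch.randn(2, 8)
    for layer in layers:
        inps = inps.to(resolve_device(layer))
        inps = layer(inps)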