Skip to content

Commit

Permalink
Fixed multi-GPU quantization (#196)
Browse files: browse the repository at this point in the history.
  • Loading branch information
casper-hansen authored Nov 16, 2023
1 parent 74d0fe4 commit 6f516b8
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion awq/quantize/quantizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -69,8 +69,15 @@ def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: to

def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
common_device = next(self.modules[i].parameters()).device
if common_device is None or str(common_device) == "cpu":
self.modules[i] = self.modules[i].cuda()
common_device = next(self.modules[i].parameters()).device

self.inps = self.inps.to(common_device)

# [STEP 1]: Get layer, extract linear modules, extract input features
self.modules[i] = self.modules[i].cuda()
named_linears = get_named_linears(self.modules[i])
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
Expand Down

0 comments on commit 6f516b8

Please sign in to comment.