diff --git a/tests/model_dec_scales.json b/tests/model_dec_scales.json new file mode 100644 index 0000000..841008a --- /dev/null +++ b/tests/model_dec_scales.json @@ -0,0 +1 @@ +[{"attn_input_scale": 0.031619094488188976, "q_output_scale": 0.1687992125984252, "k_output_scale": 0.1347194881889764, "v_output_scale": 0.02297613188976378, "out_input_scale": 0.01796259842519685, "fc1_input_scale": 0.031619094488188976, "fc2_input_scale": 0.007831723671259843}, {"attn_input_scale": 0.011095903051181102, "q_output_scale": 0.14013287401574803, "k_output_scale": 0.14160925196850394, "v_output_scale": 0.046475147637795276, "out_input_scale": 0.03595595472440945, "fc1_input_scale": 0.011095903051181102, "fc2_input_scale": 0.00655142716535433}, {"attn_input_scale": 0.00863527312992126, "q_output_scale": 0.18553149606299213, "k_output_scale": 0.156373031496063, "v_output_scale": 0.03021961122047244, "out_input_scale": 0.030096579724409447, "fc1_input_scale": 0.00863527312992126, "fc2_input_scale": 0.029789000984251968}, {"attn_input_scale": 0.019192913385826772, "q_output_scale": 0.2233021653543307, "k_output_scale": 0.15563484251968504, "v_output_scale": 0.03804749015748032, "out_input_scale": 0.03168061023622047, "fc1_input_scale": 0.019192913385826772, "fc2_input_scale": 0.03337229330708662}, {"attn_input_scale": 0.01287217027559055, "q_output_scale": 0.13041338582677164, "k_output_scale": 0.1392716535433071, "v_output_scale": 0.062100147637795276, "out_input_scale": 0.05361097440944882, "fc1_input_scale": 0.01287217027559055, "fc2_input_scale": 0.002772053395669291}, {"attn_input_scale": 0.016901451771653545, "q_output_scale": 0.17691929133858267, "k_output_scale": 0.17704232283464566, "v_output_scale": 0.025298351377952756, "out_input_scale": 0.024913877952755906, "fc1_input_scale": 0.016901451771653545, "fc2_input_scale": 0.00285279281496063}, {"attn_input_scale": 0.016378567913385825, "q_output_scale": 0.13188976377952755, "k_output_scale": 0.15243602362204725, "v_output_scale": 0.02449864665354331, "out_input_scale": 0.020100270669291338, "fc1_input_scale": 0.016378567913385825, "fc2_input_scale": 0.0020415538877952754}, {"attn_input_scale": 0.014563853346456693, "q_output_scale": 0.15526574803149606, "k_output_scale": 0.1625246062992126, "v_output_scale": 0.02995816929133858, "out_input_scale": 0.02109990157480315, "fc1_input_scale": 0.014563853346456693, "fc2_input_scale": 0.002793199434055118}, {"attn_input_scale": 0.016701525590551183, "q_output_scale": 0.15255905511811024, "k_output_scale": 0.18061023622047245, "v_output_scale": 0.021345964566929134, "out_input_scale": 0.01842396653543307, "fc1_input_scale": 0.016701525590551183, "fc2_input_scale": 0.00299312561515748}, {"attn_input_scale": 0.017685777559055118, "q_output_scale": 0.16289370078740156, "k_output_scale": 0.18393208661417323, "v_output_scale": 0.02875861220472441, "out_input_scale": 0.026113435039370077, "fc1_input_scale": 0.017685777559055118, "fc2_input_scale": 0.0021876537893700788}, {"attn_input_scale": 0.01819328248031496, "q_output_scale": 0.1875, "k_output_scale": 0.17285925196850394, "v_output_scale": 0.03186515748031496, "out_input_scale": 0.0296505905511811, "fc1_input_scale": 0.01819328248031496, "fc2_input_scale": 0.001685915969488189}, {"attn_input_scale": 0.014271653543307087, "q_output_scale": 0.14480807086614172, "k_output_scale": 0.16510826771653545, "v_output_scale": 0.023622047244094488, "out_input_scale": 0.01714751476377953, "fc1_input_scale": 0.014271653543307087, "fc2_input_scale": 
0.0016195943036417322}, {"attn_input_scale": 0.01624015748031496, "q_output_scale": 0.1733513779527559, "k_output_scale": 0.18713090551181102, "v_output_scale": 0.04856668307086614, "out_input_scale": 0.029389148622047244, "fc1_input_scale": 0.01624015748031496, "fc2_input_scale": 0.0015542338213582678}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1546505905511811, "k_output_scale": 0.18639271653543307, "v_output_scale": 0.03380290354330709, "out_input_scale": 0.03257258858267716, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.002921998031496063}, {"attn_input_scale": 0.014686884842519685, "q_output_scale": 0.16203248031496062, "k_output_scale": 0.1969734251968504, "v_output_scale": 0.03071173720472441, "out_input_scale": 0.02066929133858268, "fc1_input_scale": 0.014686884842519685, "fc2_input_scale": 0.0026105745570866143}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1592027559055118, "k_output_scale": 0.18011811023622049, "v_output_scale": 0.028420275590551183, "out_input_scale": 0.014148622047244094, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.005417230561023622}, {"attn_input_scale": 0.017854945866141732, "q_output_scale": 0.17568897637795275, "k_output_scale": 0.19672736220472442, "v_output_scale": 0.023452878937007874, "out_input_scale": 0.02251476377952756, "fc1_input_scale": 0.017854945866141732, "fc2_input_scale": 0.0013398898868110236}, {"attn_input_scale": 0.015286663385826772, "q_output_scale": 0.1671998031496063, "k_output_scale": 0.14271653543307086, "v_output_scale": 0.019239050196850394, "out_input_scale": 0.017593503937007874, "fc1_input_scale": 0.015286663385826772, "fc2_input_scale": 0.0022145669291338582}, {"attn_input_scale": 0.016070989173228346, "q_output_scale": 0.15514271653543307, "k_output_scale": 0.15231299212598426, "v_output_scale": 0.019408218503937008, "out_input_scale": 0.016424704724409447, "fc1_input_scale": 0.016070989173228346, "fc2_input_scale": 0.006243848425196851}, {"attn_input_scale": 0.017009104330708662, "q_output_scale": 0.1422244094488189, "k_output_scale": 0.16117125984251968, "v_output_scale": 0.025221456692913386, "out_input_scale": 0.019500492125984252, "fc1_input_scale": 0.017009104330708662, "fc2_input_scale": 0.01803949311023622}] \ No newline at end of file diff --git a/tests/test_gptj.py b/tests/test_gptj.py new file mode 100644 index 0000000..28b838c --- /dev/null +++ b/tests/test_gptj.py @@ -0,0 +1,97 @@ +import torch +from torch_int.models.gptj import Int8GPTJForCausalLM, Int8GPTJBlock, Int8GPTJMLP, Int8GPTJAttention, Int8GPTJModel +from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM +from transformers.models.gptj.modeling_gptj import GPTJModel, GPTJConfig, GPTJForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer +from icecream import ic +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +# from transformers import GPTJTok +from datasets import load_dataset +from tqdm import tqdm +import json +import copy + +class Evaluator: + def __init__(self, dataset, tokenizer, device): + self.dataset = dataset + self.tokenizer = tokenizer + self.device = device + + # tokenize the dataset + def tokenize_function(examples): + example = self.tokenizer(examples['text']) + return example + self.dataset = self.dataset.map(tokenize_function, batched=True) + self.dataset.set_format(type='torch', columns=['input_ids']) + + @torch.no_grad() + def evaluate2(self, model): + 
model.eval() + # The task is to predict the last token of the input. + total, hit = 0, 0 + idx = 0 + pbar = tqdm(self.dataset, desc='Evaluating') + for batch in pbar: + input_ids = batch['input_ids'].to(self.device).unsqueeze(0) + label = input_ids[:, -1] + outputs = model(input_ids.cuda()) + idx += 1 + last_token_logits = outputs.logits[:, -2, :] + pred = last_token_logits.argmax(dim=-1) + total += label.size(0) + hit += (pred == label).sum().item() + pbar.set_postfix({'acc': hit / total}) + acc = hit / total + return acc + + @torch.no_grad() + def evaluate(self, modelX, model): + model.eval() + # The task is to predict the last token of the input. + idx = 0 + total, hit = 0, 0 + hit2 = 0 + pbar = tqdm(self.dataset, desc='Evaluating') + for batch in pbar: + input_ids = batch['input_ids'].to(self.device).unsqueeze(0) + label = input_ids[:, -1] + outputs = model(input_ids.to('cuda')) + outputs2 = modelX(input_ids.to('cuda')) + model.transformer.d.clear() + modelX.transformer.d.clear() + idx += 1 + last_token_logits = outputs.logits[:, -2, :] + last_token_logits = outputs2.logits[:, -2, :] + pred = last_token_logits.argmax(dim=-1) + pred2 = last_token_logits.argmax(dim=-1) + total += label.size(0) + hit += (pred == label).sum().item() + hit2 += (pred == label).sum().item() + pbar.set_postfix({'acc': hit / total, 'accX': hit2 / total}) + acc = hit / total + return acc + +MP = "/home/iman/fgg/smoothquant/SF/codegen-350M-multiX.pt" +@torch.no_grad() +def test_opt(): + dataset = load_dataset('lambada', split='validation[:1000]') + dataset = dataset.shuffle(seed=42) + checkpoint = "moyix/codegen-350M-multi-gptj" + # checkpoint = "Salesforce/codegen-350M-multi" + config = GPTJConfig.from_pretrained('moyix/codegen-350M-multi-gptj') + model = GPTJForCausalLM.from_pretrained(checkpoint, device_map = 'auto', torch_dtype = 'auto').cuda() + tokenizer = AutoTokenizer.from_pretrained('Salesforce/codegen-350M-multi') + evaluator = Evaluator(dataset, tokenizer, 'cuda') + dlsj = "./tests/model_dec_scales.json" + decoder_layer_scales = [] + with open(dlsj, 'r') as fp: + decoder_layer_scales = json.load(fp) + # these layers will not be quantized + layers_to_keep = list(range(13)) + int8_model = Int8GPTJForCausalLM.from_float(model, decoder_layer_scales, k = layers_to_keep) + acc = evaluator.evaluate2(int8_model.to('cuda')) + ic(acc) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_gptj_attention.py b/tests/test_gptj_attention.py new file mode 100644 index 0000000..dcaaff2 --- /dev/null +++ b/tests/test_gptj_attention.py @@ -0,0 +1,57 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJConfig +from torch_int.models.gptj import Int8GPTJAttention +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial + +def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_attention(): + B, L, D, H = 1, 32, 128, 1 + x = torch.randn(B, L, D) + x_scale = x.abs().max() / 127 + config = GPTJConfig() + config.n_embd = D + config.n_head = H + config.rotary_dim = None + attn = GPTJAttention(config) + attn.eval() + act_dict = {} + for name, module in attn.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = 
attn(x) + y = y[0] + + q_output_scale = act_dict['q_proj'][1].abs().max() / 127 + k_output_scale = act_dict['k_proj'][1].abs().max() / 127 + v_output_scale = act_dict['v_proj'][1].abs().max() / 127 + out_input_scale = act_dict['out_proj'][0].abs().max() / 127 + int8_attn = Int8GPTJAttention.from_float( + attn, x_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale).cuda() + int8_attn.eval() + q_act_dict = {} + for name, module in int8_attn.named_modules(): + if isinstance(module, (W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU)): + module.register_forward_hook( + partial(store_act, act_dict=q_act_dict, name=name)) + q_x = (x / x_scale).round().to(torch.int8) + y_hat = int8_attn(q_x.cuda())[0].cpu() + + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_attention() diff --git a/tests/test_gptj_block.py b/tests/test_gptj_block.py new file mode 100644 index 0000000..00f8cf8 --- /dev/null +++ b/tests/test_gptj_block.py @@ -0,0 +1,81 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJBlock, GPTJConfig +from torch_int.models.gptj import Int8GPTJBlock +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial +import matplotlib.pyplot as plt + +def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_block(): + config : GPTJConfig = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono') + B, L, D, H = 1, 256, config.n_embd, config.n_head + x = torch.randn(B, L, D) + blk = GPTJBlock(config) + blk.eval() + act_dict = {} + for name, module in blk.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + if isinstance(module, torch.nn.LayerNorm): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + + y = blk(x) + y = y[0].cpu() + print(act_dict.keys()) + # exit(0) + ln1_input_scale = act_dict['ln_1'][1].abs().max() / 127 + attn_input_scale = act_dict['attn.q_proj'][0].abs().max() / 127 + q_output_scale = act_dict['attn.q_proj'][1].abs().max() / 127 + k_output_scale = act_dict['attn.k_proj'][1].abs().max() / 127 + v_output_scale = act_dict['attn.v_proj'][1].abs().max() / 127 + out_input_scale = act_dict['attn.out_proj'][0].abs().max() / 127 + fc1_input_scale = act_dict['mlp.fc_in'][0].abs().max() / 127 + fc2_input_scale = act_dict['mlp.fc_out'][0].abs().max() / 127 + int8_blk = Int8GPTJBlock.from_float( + blk, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale, fc1_input_scale, fc2_input_scale).cuda() + int8_blk.eval() + q_act_dict = {} + + y_hat = int8_blk(x.cuda())[0].cpu() + # rd = blk.dbgi + # md = int8_blk.dbgi + # RN = 256 + # ra = rd['atto'].cpu().flatten()[:RN] + # ma = md['attoX'].cpu().flatten()[:RN] + # rf = rd['ffn'].cpu().flatten()[:RN] + # mf = md['ffnX'].cpu().flatten()[:RN] + # rr = rd['resi'].cpu().flatten()[:RN] + # mr = md['resiX'].cpu().flatten()[:RN] + # + # plt.plot(ra.flatten()) + # print(f"MAX: a:{ra.abs().max()} f:{rf.abs().max()} r:{rr.abs().max()+0.0000001}") + # plt.plot(ma - ra, color='r') + # plt.savefig("Xa.jpg", dpi=300) + # plt.cla() + # # plt.plot(rf) + # plt.plot(mf - rf, color='r') + # plt.savefig("Xf.jpg", dpi=300) + # plt.cla() + # # 
plt.plot(rr.flatten()) + # plt.plot(mr - rr, color='r') + # plt.savefig("Xr.jpg", dpi=300) + + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_block() diff --git a/tests/test_gptj_mlp.py b/tests/test_gptj_mlp.py new file mode 100644 index 0000000..1e21a00 --- /dev/null +++ b/tests/test_gptj_mlp.py @@ -0,0 +1,54 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJMLP, GPTJConfig +from torch_int.models.gptj import Int8GPTJMLP +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial +from torch_int.nn.fused import LayerNormQ +from torch.nn import LayerNorm + +def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_mlp(): + B, L, D, H = 1, 16, 32, 1 + x = torch.randn(B, L, D)*40 + x = torch.clamp(x, -127, 127) + x_scale = x.abs().max() / 127 + config = GPTJConfig() + config.n_embd = D + config.n_head = H + intermediate_size = 4*D + config.rotary_dim = None + mlp = GPTJMLP(intermediate_size, config) + mlp.eval() + act_dict = {} + for name, module in mlp.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = mlp(x) + y = y[0] + + fc_in_scale = act_dict['fc_in'][0].abs().max() / 127 + fc_out_scale = act_dict['fc_out'][0].abs().max() / 127 + int8_mlp = Int8GPTJMLP.from_float( + mlp, fc_in_scale, fc_out_scale).cuda() + int8_mlp.eval() + q_x = x.round().to(torch.int8) + y_hat = int8_mlp(q_x.cuda()).cpu() + print(y_hat.shape) + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_mlp() diff --git a/tests/test_gptj_model.py b/tests/test_gptj_model.py new file mode 100644 index 0000000..6a27cb0 --- /dev/null +++ b/tests/test_gptj_model.py @@ -0,0 +1,61 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJModel, GPTJConfig +from torch_int.models.gptj import Int8GPTJModel +from icecream import ic +from functools import partial + + +def store_act(module, x, y, act_dict, name): + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_model_layer(): + config = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono') + + B, L, D, H = 1, 256, config.n_embd, config.n_head + + x = torch.randint(0, config.vocab_size, (B, L)) + model = GPTJModel(config) + model.eval() + act_dict = {} + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = model(x)[0].cuda() + decoder_layer_scales = [] + for idx in range(config.n_layer): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"h.{idx}.attn.q_proj"][0].abs( + ).max() / 127 + scale_dict["q_output_scale"] = act_dict[f"h.{idx}.attn.q_proj"][1].abs( + ).max() / 127 + scale_dict["k_output_scale"] = act_dict[f"h.{idx}.attn.k_proj"][1].abs( + ).max() / 127 + scale_dict["v_output_scale"] = act_dict[f"h.{idx}.attn.v_proj"][1].abs( + ).max() / 127 + scale_dict["out_input_scale"] = act_dict[f"h.{idx}.attn.out_proj"][0].abs( + ).max() / 127 + scale_dict["fc1_input_scale"] = act_dict[f"h.{idx}.mlp.fc_in"][0].abs( + ).max() / 127 + scale_dict["fc2_input_scale"] = 
act_dict[f"h.{idx}.mlp.fc_out"][0].abs( + ).max() / 127 + decoder_layer_scales.append(scale_dict) + + int8_model = Int8GPTJModel.from_float(model, decoder_layer_scales).cuda() + int8_model.eval() + + y_hat = int8_model(x.cuda())[0] + + # # ic(y_hat) + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_model_layer() diff --git a/tests/test_linear_kernels.py b/tests/test_linear_kernels.py index b134579..c649d05 100644 --- a/tests/test_linear_kernels.py +++ b/tests/test_linear_kernels.py @@ -1,5 +1,5 @@ import torch -from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32 +from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32, linear_gelu_a8_w8_b8_o8 from icecream import ic @@ -85,6 +85,23 @@ def test_quant_linear_relu_a8_w8_b8_o8(): ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1)) +@torch.no_grad() +def test_quant_linear_gelu_a8_w8_b8_o8(): + B, M, N = 128, 512, 1024 + weight = torch.randint(-128, 127, (N, M), dtype=torch.int8) + bias = torch.randint(-128, 127, (N,), dtype=torch.int8) + x = torch.randint(-128, 127, (B, M), dtype=torch.int8) + alpha, beta = 0.001, 0.01 + linear = torch.nn.Linear(M, N, bias=True) + linear.weight.data = weight.float() * alpha + linear.bias.data = bias.float() * beta + y_gt = torch.nn.functional.gelu(linear(x.float())) + y_gt = y_gt.clamp(-128, 127).round().long() + y = linear_gelu_a8_w8_b8_o8(x.cuda(), weight.cuda(), + bias.cuda(), alpha, beta).cpu().long() + ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1)) + + if __name__ == '__main__': print('test_quant_linear_a8_w8_b32_o32') test_quant_linear_a8_w8_b32_o32() @@ -96,3 +113,5 @@ def test_quant_linear_relu_a8_w8_b8_o8(): test_quant_linear_a8_w8_b8_o8() print('test_quant_linear_relu_a8_w8_b8_o8') test_quant_linear_relu_a8_w8_b8_o8() + print('test_quant_linear_gelu_a8_w8_b8_o8') + test_quant_linear_gelu_a8_w8_b8_o8() diff --git a/torch_int/kernels/bindings.cpp b/torch_int/kernels/bindings.cpp index 4eaf7bc..bbe3398 100644 --- a/torch_int/kernels/bindings.cpp +++ b/torch_int/kernels/bindings.cpp @@ -5,6 +5,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("linear_relu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, "Linear ReLU (INT8)"); + m.def("linear_gelu_a8_w8_b8_o8", &linear_gelu_a8_w8_b8_o8, + "Linear GELU (INT8)"); m.def("linear_a8_w8_b32_o32", &linear_a8_w8_b32_o32, "Linear (INT32)"); m.def("linear_a8_w8_bfp32_ofp32", &linear_a8_w8_bfp32_ofp32, "Linear (I8-OFP32)"); diff --git a/torch_int/kernels/include/linear.h b/torch_int/kernels/include/linear.h index 5df6ac6..ddccd97 100644 --- a/torch_int/kernels/include/linear.h +++ b/torch_int/kernels/include/linear.h @@ -32,6 +32,14 @@ torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 float beta // FP32 ); +// used by fc1, return INT8 +torch::Tensor linear_gelu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +); + // used by q_proj, k_proj, v_proj, return INT8 torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 torch::Tensor weight, // INT8 diff --git a/torch_int/kernels/linear.cu b/torch_int/kernels/linear.cu index 0e11d7b..d87159c 100644 --- a/torch_int/kernels/linear.cu +++ b/torch_int/kernels/linear.cu @@ -487,5 +487,111 @@ torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8
std::to_string((int)status)); } + return out; +} + +// used by fc1 +torch::Tensor linear_gelu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement, 
status: " + + std::to_string((int)status)); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize, status: " + + std::to_string((int)status)); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run, status: " + + std::to_string((int)status)); + } + return out; } \ No newline at end of file diff --git a/torch_int/models/gptj.py b/torch_int/models/gptj.py new file mode 100644 index 0000000..ba03bf3 --- /dev/null +++ b/torch_int/models/gptj.py @@ -0,0 +1,480 @@ +import torch +from torch import nn +from transformers.models.gptj.modeling_gptj import ( + GPTJConfig, + GPTJForCausalLM, + GPTJModel, + GPTJPreTrainedModel, + GPTJAttention, + GPTJMLP, + GPTJBlock, + BaseModelOutputWithPast +) + +@torch.no_grad() +def quantize_per_tensor_absmax(t): + scale = t.abs().max() / 127 + if not t.is_cuda: + # half rounding is not supported on CPU + t = t.float() + # use inplace operation to save memory + t.div_(scale).round_() + t_q = t.to(torch.int8) + return t_q, scale + +from typing import Optional, Tuple, List +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from torch_int.nn.fused import LayerNormQ +from transformers.utils import logging +from torch_int.nn.bmm import BMM_S8T_S8N_S8T, BMM_S8T_S8N_F32T +from transformers.activations import ACT2FN + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float() + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 
+ """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + + +def apply_rotary_pos_emb(x, sincos, offset=0): + x_ = x.to(torch.float32) + sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + r = ((x_.to(torch.float) * cos) + (rotate_every_two(x_.to(torch.float)) * sin)) + r = r.clamp(-128, 127).to(torch.int8) + return r + + +class Int8GPTJAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): + super().__init__() + max_positions = max_position_embeddings + self.max_position = max_positions + self.embed_dim = n_embd + self.num_attention_heads = n_head + self.head_dim = n_embd // n_head + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + + if (self.head_dim * self.num_attention_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_attention_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + self.k_proj = W8A8B8O8Linear(n_embd, n_embd) + self.v_proj = W8A8B8O8Linear(n_embd, n_embd) + self.q_proj = W8A8B8O8Linear(n_embd, n_embd) + self.out_proj = W8A8BFP32OFP32Linear(n_embd, n_embd) + self.rotary_dim = None + if rotary_dim is not None: + self.rotary_dim = rotary_dim + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2).contiguous() + + @staticmethod + @torch.no_grad() + def from_float(module: GPTJAttention, + input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float): + int8_module = Int8GPTJAttention(module.embed_dim, module.num_attention_heads, module.bias.shape[3], module.rotary_dim) + # Fuse the scaling into the q_proj output scale + # scale_h = module.head_dim**-0.5 + ## scaling + # qoo = q_output_scale + # q_output_scale = q_output_scale * scale_h + # module.q_proj.weight *= scale_h + # qs2 = q_output_scale * scale_h + ## scaling + # TODO: GPTJ has no bias, find a way to elide these later + module.q_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.q_proj.weight.dtype)) + module.v_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.v_proj.weight.dtype)) + module.k_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.k_proj.weight.dtype)) + module.out_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.out_proj.weight.dtype)) + module.cuda() + int8_module.q_proj = W8A8B8O8Linear.from_float( + module.q_proj, input_scale, q_output_scale) + wc = module.k_proj.weight.clone() + int8_module.k_proj = W8A8B8O8Linear.from_float( + module.k_proj, input_scale, k_output_scale) + int8_weight, weight_scale = quantize_per_tensor_absmax(wc) + int8_module.v_proj = W8A8B8O8Linear.from_float( + module.v_proj, input_scale, v_output_scale) + 
int8_module.v_proj.requires_grad = False + int8_module.out_proj = W8A8BFP32OFP32Linear.from_float( + module.out_proj, out_input_scale) + int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale( + q_output_scale, k_output_scale) + # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + # print(f"{v_output_scale}/{out_input_scale}") + int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale( + 1.0 / 127, v_output_scale, out_input_scale) + return int8_module + + def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + if rotary: + return tensor + if len(tensor.shape) == 5: + return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) + elif len(tensor.shape) == 4: + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - + query_length: key_length, :key_length].to(torch.bool).cuda() + + # key = key.transpose(-1, -2) + proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) + key = key.reshape(*proj_shape) + query = query.view(*proj_shape) + query = query.contiguous() + key = key.contiguous() + attn_weights = self.qk_bmm(query, key) + attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor( + mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights.mul_(127).round_() + attn_weights = attn_weights.to(torch.int8) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + attn_weights = attn_weights.view(self.bsz * self.num_attention_heads, -1, self.tgt_len).contiguous() + value = value.transpose(2,3) + value = value.reshape(self.num_attention_heads * self.bsz, self.head_dim, self.tgt_len).contiguous() + attn_output = self.pv_bmm(attn_weights, value) + attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) + attn_output = attn_output.view(self.bsz, self.num_attention_heads, self.tgt_len, self.head_dim) + return attn_output, attn_weights + + def forward( + self, + hidden_states: Optional[torch.Tensor], + attention_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + self.bsz, self.tgt_len, _ = hidden_states.size() + # self.out_proj.cuda() + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads( + query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads( + key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads( + value, self.num_attention_heads, self.head_dim, False) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim:] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim:] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass.to(torch.int8)], dim=-1) + query = torch.cat([q_rot, q_pass.to(torch.int8)], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_dim) + attn_output = attn_output.contiguous() + attn_output = self.out_proj(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class Int8GPTJMLP(nn.Module): 
+ # in MLP: intermediate_size= 4 * embed_dim + def __init__(self, intermediate_size, embed_dim): + super().__init__() + + self.fc1 = W8A8B8O8LinearGELU(embed_dim, intermediate_size) + self.fc2 = W8A8BFP32OFP32Linear(intermediate_size, embed_dim) + + def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor: + # hidden_states = hidden_states.to(torch.float) + hidden_states = self.fc1(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + @staticmethod + def from_float(module: GPTJMLP, fc1_input_scale: float, fc2_input_scale: float): + int8_module = Int8GPTJMLP( + module.fc_in.out_features, module.fc_in.in_features) + int8_module.fc1 = W8A8B8O8LinearGELU.from_float( + module.fc_in, fc1_input_scale, fc2_input_scale) + int8_module.fc2 = W8A8BFP32OFP32Linear.from_float( + module.fc_out, fc2_input_scale) + return int8_module + + +class Int8GPTJBlock(nn.Module): + def __init__(self, inner_dim, n_embd, n_head, max_position_embeddings, rotary_dim = None): + super().__init__() + self.ln_1 = LayerNormQ(n_embd) + self.attn = Int8GPTJAttention(n_embd, n_head, max_position_embeddings, rotary_dim) + self.mlp = Int8GPTJMLP(inner_dim, n_embd) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + return outputs # hidden_states, present, (attentions) + + @staticmethod + def from_float(module, attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + fc1_input_scale: float, + fc2_input_scale: float): + inner_dim = module.mlp.fc_out.in_features + n_embd = module.ln_1.normalized_shape[0] + int8_module = Int8GPTJBlock(inner_dim, n_embd, module.attn.num_attention_heads, module.attn.bias.shape[0], module.attn.rotary_dim) + int8_module.mlp = Int8GPTJMLP.from_float( + module.mlp, fc1_input_scale, fc2_input_scale) + int8_module.ln_1 = LayerNormQ.from_float(module.ln_1, attn_input_scale) + int8_module.attn = Int8GPTJAttention.from_float( + module.attn, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale) + return int8_module + + +class Int8GPTJModel(GPTJPreTrainedModel): + # TODO: have to add padding! 
+ def __init__(self, config): + self.d = {} + super().__init__(config) + n_layer = config.n_layer + inner_dim = 4 * config.n_embd + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Identity() + self.padding_idx = config.pad_token_id + # self.h = nn.ModuleList() + self.h = nn.ModuleList([Int8GPTJBlock(inner_dim, self.embed_dim, config.n_head, config.n_positions, config.rotary_dim) + for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + get_input_embeddings = GPTJModel.get_input_embeddings + set_input_embeddings = GPTJModel.set_input_embeddings + old_forward = GPTJModel.forward + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + from torch.nn.functional import pad + input_len = input_ids.shape[1] + if input_len % 16 != 0: + padding_len = 16 - input_len % 16 + input_ids = pad(input_ids, (0, padding_len), value=self.padding_idx) + if attention_mask is not None: + attention_mask = pad(attention_mask, (0, padding_len), value=0) + output = self.old_forward(input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) + if input_len % 16 != 0: + output.last_hidden_state = output.last_hidden_state[:,:input_len, :] + return output + + @staticmethod + def from_float(module : GPTJModel, decoder_layer_scales, k = None): + config = GPTJConfig(vocab_size=module.vocab_size, n_embd=module.embed_dim, n_layer=len(module.h), rotary_dim=module.h[0].attn.rotary_dim + , n_inner=4*module.embed_dim) + int8_module = Int8GPTJModel(config) + for i, layer in enumerate(module.h): + if k is not None and i in k: + int8_module.h[i] = layer + else: + int8_module.h[i] = Int8GPTJBlock.from_float(layer, **decoder_layer_scales[i]) + int8_module.ln_f = module.ln_f.to(torch.float) + int8_module.wte = module.wte + return int8_module + + +class Int8GPTJForCausalLM(GPTJPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = Int8GPTJModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def from_float(module, decoder_layer_scales, k = None): + int8_module = Int8GPTJForCausalLM(module.config) + int8_module.transformer = Int8GPTJModel.from_float(module.transformer, decoder_layer_scales, k) + int8_module.lm_head = module.lm_head.to(torch.float) + return int8_module + + get_input_embeddings = GPTJForCausalLM.get_input_embeddings + set_input_embeddings = GPTJForCausalLM.set_input_embeddings + get_output_embeddings = GPTJForCausalLM.get_output_embeddings + set_output_embeddings = 
GPTJForCausalLM.set_output_embeddings + forward = GPTJForCausalLM.forward + prepare_inputs_for_generation = GPTJForCausalLM.prepare_inputs_for_generation + _reorder_cache = GPTJForCausalLM._reorder_cache + parallelize = GPTJForCausalLM.parallelize + deparallelize = GPTJForCausalLM.deparallelize diff --git a/torch_int/nn/linear.py b/torch_int/nn/linear.py index 1a6e7b7..2efb247 100644 --- a/torch_int/nn/linear.py +++ b/torch_int/nn/linear.py @@ -1,6 +1,7 @@ import torch from .._CUDA import (linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, + linear_gelu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32 @@ -56,6 +57,49 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale): int8_module.b = beta return int8_module +class W8A8B8O8LinearGELU(torch.nn.Module): + # For fc1 + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('b', torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_gelu_a8_w8_b8_o8(x, self.weight, self.bias, + self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + # TODO: add zero-point to prevent the bit waste + int8_module = W8A8B8O8LinearGELU( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + alpha = input_scale * weight_scale / output_scale + beta = bias_scale / output_scale + int8_module.weight = int8_weight + int8_module.bias = int8_bias.reshape(int8_module.bias.shape) + int8_module.a = alpha + int8_module.b = beta + return int8_module class W8A8B8O8LinearReLU(torch.nn.Module): # For fc1 @@ -222,7 +266,7 @@ def from_float(module: torch.nn.Linear, input_scale): int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) alpha = input_scale * weight_scale int8_module.weight = int8_weight - int8_module.bias = module.bias.to(torch.float32) + int8_module.bias = module.bias.to(torch.float32).reshape(int8_module.bias.shape) int8_module.a = alpha int8_module.input_scale = input_scale int8_module.weight_scale = weight_scale
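
Note on the new INT8 GELU linear: W8A8B8O8LinearGELU.from_float uses the same scale algebra as the ReLU variant, alpha = input_scale * weight_scale / output_scale and beta = bias_scale / output_scale, and the CUTLASS LinearCombinationGELU epilogue is assumed to apply GELU to alpha * accumulator + beta * bias before saturating to INT8. In Int8GPTJMLP, output_scale is the fc2 input scale, so fc_out consumes the INT8 result directly. A floating-point sketch of that contract, useful as a reference when checking the kernel (the function name is illustrative, not part of the patch):

import torch

def ref_linear_gelu_a8_w8_b8_o8(x_q, w_q, b_q, alpha, beta):
    # x_q: (B, K) int8, w_q: (N, K) int8, b_q: (N,) int8; alpha, beta: Python floats
    acc = x_q.double() @ w_q.double().t()                   # exact integer accumulation
    y = torch.nn.functional.gelu(alpha * acc + beta * b_q.double())
    return y.clamp(-128, 127).round().to(torch.int8)        # saturate back to INT8

Against linear_gelu_a8_w8_b8_o8 this reference should agree to within roughly one unit, which is what the atol=1 tolerance in tests/test_linear_kernels.py allows for.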
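
For reference, the per-layer entries in tests/model_dec_scales.json follow the same recipe as tests/test_gptj_model.py: every scale is an activation absmax divided by 127, captured with forward hooks on the float checkpoint. A minimal sketch of how such a file can be regenerated (the calibration prompt and output path are placeholders; a real calibration run should use many more samples):

import json
from functools import partial

import torch
from transformers import AutoTokenizer
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

act_dict = {}

def store_act(module, x, y, name):
    # forward hooks receive (module, inputs, output); keep the first tensor of each
    x = x[0] if isinstance(x, tuple) else x
    y = y[0] if isinstance(y, tuple) else y
    act_dict[name] = (x, y)

model = GPTJForCausalLM.from_pretrained('moyix/codegen-350M-multi-gptj').eval()
tokenizer = AutoTokenizer.from_pretrained('Salesforce/codegen-350M-multi')

# Hook every Linear inside the transformer so q/k/v/out and fc_in/fc_out activations are recorded.
for name, module in model.transformer.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(partial(store_act, name=name))

input_ids = tokenizer('def hello_world():', return_tensors='pt').input_ids
with torch.no_grad():
    model(input_ids)

decoder_layer_scales = []
for idx in range(model.config.n_layer):
    decoder_layer_scales.append({
        'attn_input_scale': (act_dict[f'h.{idx}.attn.q_proj'][0].abs().max() / 127).item(),
        'q_output_scale': (act_dict[f'h.{idx}.attn.q_proj'][1].abs().max() / 127).item(),
        'k_output_scale': (act_dict[f'h.{idx}.attn.k_proj'][1].abs().max() / 127).item(),
        'v_output_scale': (act_dict[f'h.{idx}.attn.v_proj'][1].abs().max() / 127).item(),
        'out_input_scale': (act_dict[f'h.{idx}.attn.out_proj'][0].abs().max() / 127).item(),
        'fc1_input_scale': (act_dict[f'h.{idx}.mlp.fc_in'][0].abs().max() / 127).item(),
        'fc2_input_scale': (act_dict[f'h.{idx}.mlp.fc_out'][0].abs().max() / 127).item(),
    })

with open('tests/model_dec_scales.json', 'w') as fp:
    json.dump(decoder_layer_scales, fp)

The resulting list plugs straight into Int8GPTJForCausalLM.from_float (and Int8GPTJModel.from_float), as done in tests/test_gptj.py.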