From 13988c2c37a1d3c2fc4493fc8df1e7f22db1faac Mon Sep 17 00:00:00 2001 From: Mahdi NIKDAN Date: Thu, 27 Jun 2024 10:06:31 +0200 Subject: [PATCH] added support for phi3 --- auto_gptq/modeling/__init__.py | 1 + auto_gptq/modeling/_const.py | 2 ++ auto_gptq/modeling/auto.py | 2 ++ auto_gptq/modeling/phi3.py | 16 ++++++++++++++++ 4 files changed, 21 insertions(+) create mode 100644 auto_gptq/modeling/phi3.py diff --git a/auto_gptq/modeling/__init__.py b/auto_gptq/modeling/__init__.py index ff0bcad..33dc4c8 100644 --- a/auto_gptq/modeling/__init__.py +++ b/auto_gptq/modeling/__init__.py @@ -26,3 +26,4 @@ from .starcoder2 import Starcoder2GPTQForCausalLM from .xverse import XverseGPTQForCausalLM from .yi import YiGPTQForCausalLM +from .phi3 import Phi3GPTQForCausalLM \ No newline at end of file diff --git a/auto_gptq/modeling/_const.py b/auto_gptq/modeling/_const.py index 4461ce8..bfbde61 100644 --- a/auto_gptq/modeling/_const.py +++ b/auto_gptq/modeling/_const.py @@ -44,6 +44,8 @@ SUPPORTED_MODELS.append("gemma") if compare_transformers_version("v4.39.0.dev0", op="ge"): SUPPORTED_MODELS.append("starcoder2") +if compare_transformers_version("v4.40.0.dev", op="ge"): + SUPPORTED_MODELS.append("phi3") EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048 diff --git a/auto_gptq/modeling/auto.py b/auto_gptq/modeling/auto.py index 81b98d8..08955b2 100644 --- a/auto_gptq/modeling/auto.py +++ b/auto_gptq/modeling/auto.py @@ -22,6 +22,7 @@ from .mpt import MPTGPTQForCausalLM from .opt import OPTGPTQForCausalLM from .phi import PhiGPTQForCausalLM +from .phi3 import Phi3GPTQForCausalLM from .qwen import QwenGPTQForCausalLM from .qwen2 import Qwen2GPTQForCausalLM from .rw import RWGPTQForCausalLM @@ -59,6 +60,7 @@ "longllama": LongLlamaGPTQForCausalLM, "gemma": GemmaGPTQForCausalLM, "phi": PhiGPTQForCausalLM, + "phi3": Phi3GPTQForCausalLM, "mpt": MPTGPTQForCausalLM, } diff --git a/auto_gptq/modeling/phi3.py b/auto_gptq/modeling/phi3.py new file mode 100644 index 0000000..2e1c8ab --- /dev/null +++ 
b/auto_gptq/modeling/phi3.py @@ -0,0 +1,16 @@ +from ._base import BaseGPTQForCausalLM + + +class Phi3GPTQForCausalLM(BaseGPTQForCausalLM): + layer_type = "Phi3DecoderLayer" + layers_block_name = "model.layers" + outside_layer_modules = ["model.embed_tokens", "embed_dropout", "model.norm"] + inside_layer_modules = [ + ["self_attn.qkv_proj"], + ["self_attn.o_proj"], + ["mlp.gate_up_proj"], + ["mlp.down_proj"], + ] + + +__all__ = ["Phi3GPTQForCausalLM"]