From 08cb065370a0312e395cdc10c0ab707ecc31ff41 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Fri, 25 Oct 2024 19:40:39 +0800 Subject: [PATCH] hot-fix redundant import funasr (#12277) --- python/llm/src/ipex_llm/transformers/npu_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 877ce6168c6..aca3fb252ab 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -93,6 +93,9 @@ def from_pretrained(cls, *args, **kwargs): warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used") kwargs["torch_dtype"] = torch.float32 + if hasattr(cls, "get_cls_model"): + cls.HF_Model = cls.get_cls_model() + low_bit = kwargs.pop("load_in_low_bit", "sym_int4") qtype_map = { "sym_int4": "sym_int4_rtn", @@ -574,8 +577,6 @@ class AutoModelForTokenClassification(_BaseAutoModelClass): class FunAsrAutoModel(_BaseAutoModelClass): - import funasr - HF_Model = funasr.AutoModel def __init__(self, *args, **kwargs): self.model = self.from_pretrained(*args, **kwargs) @@ -583,6 +584,12 @@ def __init__(self, *args, **kwargs): def __getattr__(self, name): return getattr(self.model, name) + @classmethod + def get_cls_model(cls): + import funasr + cls_model = funasr.AutoModel + return cls_model + @classmethod def optimize_npu_model(cls, *args, **kwargs): from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr