diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 0836870e29..a23a0f5f65 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -642,7 +642,7 @@ def load_checkpoint( self.init_models() if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) - self.load_state_dict(load_fsspec(model_path)["model"], strict=strict) + self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict) if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)