From 03b9b8b4f1621182b6e9c67441e9b7217f79955c Mon Sep 17 00:00:00 2001 From: Daniel J Walsh Date: Tue, 19 Nov 2024 13:58:14 -0500 Subject: [PATCH] Fix handling of ramalama login huggingface Fixes: https://github.com/containers/ramalama/issues/465 Signed-off-by: Daniel J Walsh --- bin/ramalama | 3 +++ ramalama/cli.py | 8 ++++---- ramalama/huggingface.py | 35 ++++++++++------------------------- ramalama/model.py | 26 +++++++++++++++++++------- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/bin/ramalama b/bin/ramalama index 04b097f2..475f2c7c 100755 --- a/bin/ramalama +++ b/bin/ramalama @@ -86,6 +86,9 @@ def main(args): sys.exit(e.returncode) except KeyboardInterrupt: sys.exit(0) + except ValueError as e: + ramalama.perror("Error: " + str(e).strip("'")) + sys.exit(errno.EINVAL) if __name__ == "__main__": diff --git a/ramalama/cli.py b/ramalama/cli.py index 961a1c48..8e1fbeca 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -277,7 +277,7 @@ def login_parser(subparsers): def login_cli(args): registry = args.REGISTRY - if registry != "": + if registry != "" and registry != "ollama" and registry != "hf" and registry != "huggingface": registry = "oci://" + registry model = New(str(registry), args) @@ -294,7 +294,7 @@ def logout_parser(subparsers): def logout_cli(args): - transport = args.transport + transport = args.TRANSPORT model = New(str(transport), args) return model.logout(args) @@ -756,9 +756,9 @@ def run_container(args): def New(model, args): - if model.startswith("huggingface://") or model.startswith("hf://"): + if model.startswith("huggingface") or model.startswith("hf://"): return Huggingface(model) - if model.startswith("ollama://"): + if model.startswith("ollama"): return Ollama(model) if model.startswith("oci://") or model.startswith("docker://"): return OCI(model, args.engine) diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py index 50318c97..841d82d4 100644 --- a/ramalama/huggingface.py +++ b/ramalama/huggingface.py @@ -1,4 +1,5 @@ import os +import pathlib import urllib.request from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum from ramalama.model import Model @@ -37,7 +38,7 @@ def __init__(self, model): model = model.removeprefix("huggingface://") model = model.removeprefix("hf://") super().__init__(model) - self.type = "HuggingFace" + self.type = "huggingface" split = self.model.rsplit("/", 1) self.directory = split[0] if len(split) > 1 else "" self.filename = split[1] if len(split) > 1 else split[0] @@ -50,7 +51,7 @@ def login(self, args): conman_args = ["huggingface-cli", "login"] if args.token: conman_args.extend(["--token", args.token]) - self.exec(conman_args) + self.exec(conman_args, args) def logout(self, args): if not self.hf_cli_available: @@ -59,17 +60,7 @@ def logout(self, args): conman_args = ["huggingface-cli", "logout"] if args.token: conman_args.extend(["--token", args.token]) - self.exec(conman_args) - - def path(self, args): - return self.model_path(args) - - def exists(self, args): - model_path = self.model_path(args) - if not os.path.exists(model_path): - return None - - return model_path + self.exec(conman_args, args) def pull(self, args): model_path = self.model_path(args) @@ -93,13 +84,13 @@ def pull(self, args): if os.path.exists(target_path) and verify_checksum(target_path): relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path)) if not self.check_valid_model_path(relative_target_path, model_path): - run_cmd(["ln", "-sf", relative_target_path, model_path], debug=args.debug) + pathlib.Path(model_path).unlink(missing_ok=True) + os.symlink(relative_target_path, model_path) return model_path # Download the model file to the target path url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}" download_file(url, target_path, headers={}, show_progress=True) - if not verify_checksum(target_path): print(f"Checksum mismatch for {target_path}, retrying download...") os.remove(target_path) @@ -112,8 +103,8 @@ def pull(self, args): # Symlink is already correct, no need to update it return model_path - run_cmd(["ln", "-sf", relative_target_path, model_path], debug=args.debug) - + pathlib.Path(model_path).unlink(missing_ok=True) + os.symlink(relative_target_path, model_path) return model_path def push(self, source, args): @@ -137,14 +128,8 @@ def push(self, source, args): ) return proc.stdout.decode("utf-8") - def model_path(self, args): - return os.path.join(args.store, "models", "huggingface", self.directory, self.filename) - - def check_valid_model_path(self, relative_target_path, model_path): - return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path - - def exec(self, args): + def exec(self, cmd_args, args): try: - exec_cmd(args, args.debug) + exec_cmd(cmd_args, debug=args.debug) except FileNotFoundError as e: print(f"{str(e).strip()}\n{missing_huggingface}") diff --git a/ramalama/model.py b/ramalama/model.py index c8d6dcae..3f4ebf4f 100644 --- a/ramalama/model.py +++ b/ramalama/model.py @@ -17,6 +17,8 @@ from ramalama.kube import Kube from ramalama.common import mnt_dir, mnt_file +model_types = ["oci", "huggingface", "hf", "ollama"] + file_not_found = """\ RamaLama requires the "%s" command to be installed on the host when running with --nocontainer. @@ -48,9 +50,6 @@ def login(self, args): def logout(self, args): raise NotImplementedError(f"ramalama logout for {self.type} not implemented") - def path(self, source, args): - raise NotImplementedError(f"ramalama path for {self.type} not implemented") - def pull(self, args): raise NotImplementedError(f"ramalama pull for {self.type} not implemented") @@ -67,8 +66,7 @@ def is_symlink_to(self, file_path, target_path): return False def garbage_collection(self, args): - repo_paths = ["huggingface", "oci", "ollama"] - for repo in repo_paths: + for repo in model_types: repo_dir = f"{args.store}/repos/{repo}" model_dir = f"{args.store}/models/{repo}" for root, dirs, files in os.walk(repo_dir): @@ -102,8 +100,6 @@ def remove(self, args): self.garbage_collection(args) - def model_path(self, args): - raise NotImplementedError(f"model_path for {self.type} not implemented") def _image(self, args): if args.image != default_image(): @@ -343,6 +339,22 @@ def kube(self, model, args, exec_args): kube = Kube(model, args, exec_args) kube.generate() + def path(self, args): + return self.model_path(args) + + def model_path(self, args): + return os.path.join(args.store, "models", self.type, self.directory, self.filename) + + def exists(self, args): + model_path = self.model_path(args) + if not os.path.exists(model_path): + return None + + return model_path + + def check_valid_model_path(self, relative_target_path, model_path): + return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path + def get_gpu(): i = 0