Skip to content

Commit

Permalink
Fix handling of ramalama login huggingface
Browse files Browse the repository at this point in the history
Fixes: #465

Signed-off-by: Daniel J Walsh <[email protected]>
  • Loading branch information
rhatdan committed Nov 19, 2024
1 parent c6041e7 commit 03b9b8b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
3 changes: 3 additions & 0 deletions bin/ramalama
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def main(args):
sys.exit(e.returncode)
except KeyboardInterrupt:
sys.exit(0)
except ValueError as e:
ramalama.perror("Error: " + str(e).strip("'"))
sys.exit(errno.EINVAL)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def login_parser(subparsers):

def login_cli(args):
registry = args.REGISTRY
if registry != "":
if registry != "" and registry != "ollama" and registry != "hf" and registry != "huggingface":
registry = "oci://" + registry

model = New(str(registry), args)
Expand All @@ -294,7 +294,7 @@ def logout_parser(subparsers):


def logout_cli(args):
transport = args.transport
transport = args.TRANSPORT
model = New(str(transport), args)
return model.logout(args)

Expand Down Expand Up @@ -756,9 +756,9 @@ def run_container(args):


def New(model, args):
if model.startswith("huggingface://") or model.startswith("hf://"):
if model.startswith("huggingface") or model.startswith("hf://"):
return Huggingface(model)
if model.startswith("ollama://"):
if model.startswith("ollama"):
return Ollama(model)
if model.startswith("oci://") or model.startswith("docker://"):
return OCI(model, args.engine)
Expand Down
35 changes: 10 additions & 25 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib
import urllib.request
from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum
from ramalama.model import Model
Expand Down Expand Up @@ -37,7 +38,7 @@ def __init__(self, model):
model = model.removeprefix("huggingface://")
model = model.removeprefix("hf://")
super().__init__(model)
self.type = "HuggingFace"
self.type = "huggingface"
split = self.model.rsplit("/", 1)
self.directory = split[0] if len(split) > 1 else ""
self.filename = split[1] if len(split) > 1 else split[0]
Expand All @@ -50,7 +51,7 @@ def login(self, args):
conman_args = ["huggingface-cli", "login"]
if args.token:
conman_args.extend(["--token", args.token])
self.exec(conman_args)
self.exec(conman_args, args)

def logout(self, args):
if not self.hf_cli_available:
Expand All @@ -59,17 +60,7 @@ def logout(self, args):
conman_args = ["huggingface-cli", "logout"]
if args.token:
conman_args.extend(["--token", args.token])
self.exec(conman_args)

def path(self, args):
return self.model_path(args)

def exists(self, args):
model_path = self.model_path(args)
if not os.path.exists(model_path):
return None

return model_path
self.exec(conman_args, args)

def pull(self, args):
model_path = self.model_path(args)
Expand All @@ -93,13 +84,13 @@ def pull(self, args):
if os.path.exists(target_path) and verify_checksum(target_path):
relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
if not self.check_valid_model_path(relative_target_path, model_path):
run_cmd(["ln", "-sf", relative_target_path, model_path], debug=args.debug)
pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

# Download the model file to the target path
url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
download_file(url, target_path, headers={}, show_progress=True)

if not verify_checksum(target_path):
print(f"Checksum mismatch for {target_path}, retrying download...")
os.remove(target_path)
Expand All @@ -112,8 +103,8 @@ def pull(self, args):
# Symlink is already correct, no need to update it
return model_path

run_cmd(["ln", "-sf", relative_target_path, model_path], debug=args.debug)

pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

def push(self, source, args):
Expand All @@ -137,14 +128,8 @@ def push(self, source, args):
)
return proc.stdout.decode("utf-8")

def model_path(self, args):
return os.path.join(args.store, "models", "huggingface", self.directory, self.filename)

def check_valid_model_path(self, relative_target_path, model_path):
return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path

def exec(self, args):
def exec(self, cmd_args, args):
try:
exec_cmd(args, args.debug)
exec_cmd(cmd_args, debug=args.debug)
except FileNotFoundError as e:
print(f"{str(e).strip()}\n{missing_huggingface}")
26 changes: 19 additions & 7 deletions ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from ramalama.kube import Kube
from ramalama.common import mnt_dir, mnt_file

model_types = ["oci", "huggingface", "hf", "ollama"]


file_not_found = """\
RamaLama requires the "%s" command to be installed on the host when running with --nocontainer.
Expand Down Expand Up @@ -48,9 +50,6 @@ def login(self, args):
def logout(self, args):
raise NotImplementedError(f"ramalama logout for {self.type} not implemented")

def path(self, source, args):
raise NotImplementedError(f"ramalama path for {self.type} not implemented")

def pull(self, args):
raise NotImplementedError(f"ramalama pull for {self.type} not implemented")

Expand All @@ -67,8 +66,7 @@ def is_symlink_to(self, file_path, target_path):
return False

def garbage_collection(self, args):
repo_paths = ["huggingface", "oci", "ollama"]
for repo in repo_paths:
for repo in model_types:
repo_dir = f"{args.store}/repos/{repo}"
model_dir = f"{args.store}/models/{repo}"
for root, dirs, files in os.walk(repo_dir):
Expand Down Expand Up @@ -102,8 +100,6 @@ def remove(self, args):

self.garbage_collection(args)

def model_path(self, args):
raise NotImplementedError(f"model_path for {self.type} not implemented")

def _image(self, args):
if args.image != default_image():
Expand Down Expand Up @@ -343,6 +339,22 @@ def kube(self, model, args, exec_args):
kube = Kube(model, args, exec_args)
kube.generate()

def path(self, args):
return self.model_path(args)

def model_path(self, args):
return os.path.join(args.store, "models", self.type, self.directory, self.filename)

def exists(self, args):
model_path = self.model_path(args)
if not os.path.exists(model_path):
return None

return model_path

def check_valid_model_path(self, relative_target_path, model_path):
return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path


def get_gpu():
i = 0
Expand Down

0 comments on commit 03b9b8b

Please sign in to comment.