From 1d69eac021ae9f1df8ff427f5eec8e4dabaeb306 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 25 Apr 2024 13:57:25 -0500 Subject: [PATCH] Update litserve dependency (#1356) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- litgpt/deploy/serve.py | 15 +++++++++++++-- pyproject.toml | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py index 4a26e0b14f..9cd594230d 100644 --- a/litgpt/deploy/serve.py +++ b/litgpt/deploy/serve.py @@ -1,11 +1,12 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from pathlib import Path -from typing import Dict, Any, Optional, Literal +from typing import Dict, Any, Optional from litgpt.utils import check_valid_checkpoint_dir import lightning as L +from lightning_utilities.core.imports import RequirementCache import torch -from litserve import LitAPI, LitServer + from litgpt.model import GPT from litgpt.config import Config @@ -15,6 +16,13 @@ from litgpt.utils import load_checkpoint, CLI, get_default_supported_precision +_LITSERVE_AVAILABLE = RequirementCache("litserve") +if _LITSERVE_AVAILABLE: + from litserve import LitAPI, LitServer +else: + LitAPI, LitServer = object, object + + class SimpleLitAPI(LitAPI): def __init__(self, checkpoint_dir: Path, @@ -23,6 +31,9 @@ def __init__(self, top_k: int = 50, max_new_tokens: int = 50) -> None: + if not _LITSERVE_AVAILABLE: + raise ImportError(str(_LITSERVE_AVAILABLE)) + super().__init__() self.checkpoint_dir = checkpoint_dir self.precision = precision diff --git a/pyproject.toml b/pyproject.toml index d8d60ff594..ba3bc7c9e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,7 @@ license = { file = "LICENSE" } dependencies = [ "torch>=2.2.0", "lightning==2.3.0.dev20240328", - "jsonargparse[signatures]>=4.27.6", - "litserve>=0.1.0" # imported by litgpt.deploy + "jsonargparse[signatures]>=4.27.6" ] [project.urls] @@ -38,6 +37,7 @@ all = [ "tokenizers>=0.15.2", # pythia, falcon, redpajama "requests>=2.31.0", # litgpt.data "litdata>=0.2.2", # litgpt.data + "litserve>=0.1.0", # litgpt.deploy "zstandard>=0.22.0", # litgpt.data.prepare_slimpajama.py "pandas>=1.9.0", # litgpt.data.prepare_starcoder.py "pyarrow>=15.0.2", # litgpt.data.prepare_starcoder.py