Skip to content

Commit

Permalink
Add Torch Dependency Resolution
Browse files Browse the repository at this point in the history
* Use our compatibility matrix to automatically
resolve the torch version if the torch version is
blank but adjacent dependencies are defined.
* Make torchaudio capable of determining CUDA
version compatibility like torchvision
* Rewrite dependency versions to align with
resolved versions
* Remove version modifiers, this is now satisfied
by adding in extra index URLs to the requirements
  • Loading branch information
8W9aG committed Sep 12, 2024
1 parent 9d16cfc commit ecf7413
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 18 deletions.
71 changes: 62 additions & 9 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ func (c *TorchCompatibility) TorchvisionVersion() string {
return version.StripModifier(c.Torchvision)
}

func (c *TorchCompatibility) TorchaudioVersion() string {
return version.StripModifier(c.Torchaudio)
}

type CUDABaseImage struct {
Tag string
CUDA string
Expand Down Expand Up @@ -287,12 +291,12 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err
func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchVersion() == ver && compat.CUDA == nil {
return "torch", torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
return TorchPackageName, torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}

// Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions.
return "torch", ver, "", "", nil
return TorchPackageName, ver, "", "", nil
}

func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
Expand Down Expand Up @@ -327,20 +331,20 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
}
if latest == nil {
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
return "torch", ver, "", "", nil
return TorchPackageName, ver, "", "", nil
}

return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
return TorchPackageName, version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
}

func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
return TorchvisionPackageName, torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return "torchvision", ver, "", "", nil
return TorchvisionPackageName, ver, "", "", nil
}

func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
Expand All @@ -350,7 +354,7 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
var latest *TorchCompatibility
for _, compat := range TorchCompatibilityMatrix {
compat := compat
if compat.TorchvisionVersion() != ver || compat.CUDA == nil {
if !version.Matches(compat.TorchvisionVersion(), ver) || compat.CUDA == nil {
continue
}
greater, err := versionGreater(*compat.CUDA, cuda)
Expand All @@ -376,10 +380,59 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
if latest == nil {
// TODO: can we suggest a CUDA version known to be compatible?
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver)
return "torchvision", ver, "", "", nil
return TorchvisionPackageName, ver, "", "", nil
}

return TorchvisionPackageName, version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
}

func torchaudioGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
var latest *TorchCompatibility
for _, compat := range TorchCompatibilityMatrix {
compat := compat
taVersion := compat.TorchaudioVersion()
if taVersion == "" {
continue
}
if !version.Matches(taVersion, ver) || compat.CUDA == nil {
continue
}
greater, err := versionGreater(*compat.CUDA, cuda)
if err != nil {
panic(fmt.Sprintf("Invalid CUDA version: %s", err))
}
if greater {
continue
}
if latest == nil {
latest = &compat
} else {
greater, err := versionGreater(*compat.CUDA, *latest.CUDA)
if err != nil {
// should never happen
panic(fmt.Sprintf("Invalid CUDA version: %s", err))
}
if greater {
latest = &compat
}
}
}
if latest == nil {
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchaudio %s. This might cause CUDA problems.", cuda, ver)
return TorchaudioPackageName, ver, "", "", nil
}

return TorchaudioPackageName, version.StripModifier(latest.Torchaudio), latest.FindLinks, latest.ExtraIndexURL, nil
}

return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
func torchaudioCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if version.Matches(compat.TorchaudioVersion(), ver) && compat.CUDA == nil {
return TorchaudioPackageName, torchStripCPUSuffixForM1(compat.Torchaudio, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return TorchaudioPackageName, ver, "", "", nil
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
Expand Down
76 changes: 69 additions & 7 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ const (
MinimumMajorPythonVersion int = 3
MinimumMinorPythonVersion int = 8
MinimumMajorCudaVersion int = 11
TorchPackageName = "torch"
TorchvisionPackageName = "torchvision"
TorchaudioPackageName = "torchaudio"
)

type RunItem struct {
Expand Down Expand Up @@ -175,15 +178,46 @@ func (c *Config) CUDABaseImageTag() (string, error) {
}

func (c *Config) TorchVersion() (string, bool) {
return c.pythonPackageVersion("torch")
torchVersion, found := c.pythonPackageVersion(TorchPackageName)
if !found {
// Can we determine the torch version based on other packages related to it?
tvVersion, tvVersionFound := c.TorchvisionVersion()
tvVersion = version.StripModifier(tvVersion)
if tvVersionFound {
for _, compat := range TorchCompatibilityMatrix {
if version.Equal(tvVersion, version.StripModifier(compat.Torchvision)) {
return compat.Torch, true
}
}
}

taVersion, taVersionFound := c.TorchaudioVersion()
taVersion = version.StripModifier(taVersion)
if taVersionFound {
for _, compat := range TorchCompatibilityMatrix {
if version.Equal(taVersion, version.StripModifier(compat.Torchaudio)) {
return compat.Torch, true
}
}
}
}
return torchVersion, found
}

func (c *Config) TorchvisionVersion() (string, bool) {
return c.pythonPackageVersion("torchvision")
tvVersion, found := c.pythonPackageVersion(TorchvisionPackageName)
if found {
tvVersion = version.StripModifier(tvVersion)
}
return tvVersion, found
}

func (c *Config) TorchaudioVersion() (string, bool) {
return c.pythonPackageVersion("torchaudio")
taVersion, found := c.pythonPackageVersion(TorchaudioPackageName)
if found {
taVersion = version.StripModifier(taVersion)
}
return taVersion, found
}

func (c *Config) TensorFlowVersion() (string, bool) {
Expand Down Expand Up @@ -380,8 +414,12 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
// It's not pinned, so just return the line verbatim, unless its one of our special packages
if pkg == TorchPackageName || pkg == TorchvisionPackageName || pkg == TorchaudioPackageName {
name = pkg
} else {
return pkg, []string{}, []string{}, nil
}
}
if len(extraIndexURLs) > 0 {
return name + "==" + version, findLinksList, extraIndexURLs, nil
Expand All @@ -398,7 +436,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
}
}
// There is no CPU case for tensorflow because the default package is just the CPU package, so no transformation of version is needed
case "torch":
case TorchPackageName:
torchVersion, found := c.TorchVersion()
if found {
version = torchVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA)
if err != nil {
Expand All @@ -410,7 +452,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
return "", nil, nil, err
}
}
case "torchvision":
case TorchvisionPackageName:
tvVersion, found := c.TorchvisionVersion()
if found {
version = tvVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA)
if err != nil {
Expand All @@ -422,6 +468,22 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
return "", nil, nil, err
}
}
case TorchaudioPackageName:
taVersion, found := c.TorchaudioVersion()
if found {
version = taVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchaudioGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
} else {
name, version, findLinks, extraIndexURL, err = torchaudioCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
}
}
pkgWithVersion := name
if version != "" {
Expand Down
86 changes: 86 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,89 @@ torch==2.4.0
torchvision==2.4.0`
require.Equal(t, expected, requirements)
}

func TestResolveTorchVersionWithTorchVisionDependency(t *testing.T) {
config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonPackages: []string{
"torch",
"torchvision==0.19.0+cu121",
},
CUDA: "12.1",
},
}
err := config.ValidateAndComplete("")
require.NoError(t, err)
torchVersion, found := config.TorchVersion()
require.True(t, found)
require.Equal(t, "2.4.0", torchVersion)
}

func TestResolveTorchVersionWithTorchAudioDependency(t *testing.T) {
config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonPackages: []string{
"torch",
"torchaudio==2.4.0",
},
CUDA: "12.1",
},
}
err := config.ValidateAndComplete("")
require.NoError(t, err)
torchVersion, found := config.TorchVersion()
require.True(t, found)
require.Equal(t, "2.4.0", torchVersion)
}

func TestPythonRequirementsWithTorchvisionVersionModifierStripped(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch
torchvision==0.19.0+cu121`), 0o644)
require.NoError(t, err)

config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
err = config.ValidateAndComplete(tmpDir)
require.NoError(t, err)

requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.4.0
torchvision==0.19.0`
require.Equal(t, expected, requirements)
}

func TestPythonRequirementsWithTorchaudio(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch
torchaudio==2.4.0`), 0o644)
require.NoError(t, err)

config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
err = config.ValidateAndComplete(tmpDir)
require.NoError(t, err)

requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.4.0
torchaudio==2.4.0`
require.Equal(t, expected, requirements)
}
4 changes: 2 additions & 2 deletions pkg/dockerfile/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (g *BaseImageGenerator) pythonPackages() []string {
continue
}

pkgs = append(pkgs, "torchvision=="+compat.Torchvision)
pkgs = append(pkgs, "torchvision=="+version.StripModifier(compat.Torchvision))
break
}

Expand All @@ -239,7 +239,7 @@ func (g *BaseImageGenerator) pythonPackages() []string {
continue
}

pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio)
pkgs = append(pkgs, "torchaudio=="+version.StripModifier(compat.Torchaudio))
break
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
build:
gpu: true
python_version: "3.9"
python_packages:
- "torch"
- "torchvision==0.19.0+cu121"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, s: str) -> str:
return "hello " + s
15 changes: 15 additions & 0 deletions test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,18 @@ def test_torch_1_13_0_base_image_fail_explicit(docker_image):
capture_output=True,
)
assert build_process.returncode == 0


def test_torch_version_resolution_with_torchvision(docker_image):
project_dir = Path(__file__).parent / "fixtures/torch-torchvision-project"
build_process = subprocess.run(
[
"cog",
"build",
"-t",
docker_image,
],
cwd=project_dir,
capture_output=True,
)
assert build_process.returncode == 0

0 comments on commit ecf7413

Please sign in to comment.