From c19dfdc5b5f70757a366216a22c48ff68664818a Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Thu, 12 Sep 2024 17:34:10 -0400 Subject: [PATCH 1/2] Add Torch Dependency Resolution * Use our compatibility matrix to automatically resolve the torch version if the torch version is blank but adjacent dependencies are defined. * Make torchaudio capable of determining CUDA version compatibility like torchvision * Rewrite dependency versions to align with resolved versions * Remove version modifiers, this is now satisfied by adding in extra index URLs to the requirements --- pkg/config/compatibility.go | 71 +++++++++++++-- pkg/config/config.go | 76 ++++++++++++++-- pkg/config/config_test.go | 86 +++++++++++++++++++ pkg/dockerfile/base.go | 4 +- .../torch-torchvision-project/cog.yaml | 7 ++ .../torch-torchvision-project/predict.py | 6 ++ .../test_integration/test_build.py | 15 ++++ 7 files changed, 247 insertions(+), 18 deletions(-) create mode 100644 test-integration/test_integration/fixtures/torch-torchvision-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/torch-torchvision-project/predict.py diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index 4f7fc7620e..60bbcf9427 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -68,6 +68,10 @@ func (c *TorchCompatibility) TorchvisionVersion() string { return version.StripModifier(c.Torchvision) } +func (c *TorchCompatibility) TorchaudioVersion() string { + return version.StripModifier(c.Torchaudio) +} + type CUDABaseImage struct { Tag string CUDA string @@ -287,12 +291,12 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { for _, compat := range TorchCompatibilityMatrix { if compat.TorchVersion() == ver && compat.CUDA == nil { - return "torch", torchStripCPUSuffixForM1(compat.Torch, goos, goarch), 
compat.FindLinks, compat.ExtraIndexURL, nil + return TorchPackageName, torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil } } // Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions. - return "torch", ver, "", "", nil + return TorchPackageName, ver, "", "", nil } func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { @@ -327,20 +331,20 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr } if latest == nil { // We've already warned user if they're doing something stupid in validateAndCompleteCUDA() - return "torch", ver, "", "", nil + return TorchPackageName, ver, "", "", nil } - return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil + return TorchPackageName, version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil } func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { for _, compat := range TorchCompatibilityMatrix { if compat.TorchvisionVersion() == ver && compat.CUDA == nil { - return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil + return TorchvisionPackageName, torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil } } // Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions. 
- return "torchvision", ver, "", "", nil + return TorchvisionPackageName, ver, "", "", nil } func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { @@ -350,7 +354,7 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra var latest *TorchCompatibility for _, compat := range TorchCompatibilityMatrix { compat := compat - if compat.TorchvisionVersion() != ver || compat.CUDA == nil { + if !version.Matches(compat.TorchvisionVersion(), ver) || compat.CUDA == nil { continue } greater, err := versionGreater(*compat.CUDA, cuda) @@ -376,10 +380,59 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra if latest == nil { // TODO: can we suggest a CUDA version known to be compatible? console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver) - return "torchvision", ver, "", "", nil + return TorchvisionPackageName, ver, "", "", nil + } + + return TorchvisionPackageName, version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil +} + +func torchaudioGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { + var latest *TorchCompatibility + for _, compat := range TorchCompatibilityMatrix { + compat := compat + taVersion := compat.TorchaudioVersion() + if taVersion == "" { + continue + } + if !version.Matches(taVersion, ver) || compat.CUDA == nil { + continue + } + greater, err := versionGreater(*compat.CUDA, cuda) + if err != nil { + panic(fmt.Sprintf("Invalid CUDA version: %s", err)) + } + if greater { + continue + } + if latest == nil { + latest = &compat + } else { + greater, err := versionGreater(*compat.CUDA, *latest.CUDA) + if err != nil { + // should never happen + panic(fmt.Sprintf("Invalid CUDA version: %s", err)) + } + if greater { + latest = &compat + } + } } + if latest == nil { + console.Warnf("Cog doesn't know if CUDA %s is 
compatible with torchaudio %s. This might cause CUDA problems.", cuda, ver) + return TorchaudioPackageName, ver, "", "", nil + } + + return TorchaudioPackageName, version.StripModifier(latest.Torchaudio), latest.FindLinks, latest.ExtraIndexURL, nil +} - return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil +func torchaudioCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) { + for _, compat := range TorchCompatibilityMatrix { + if version.Matches(compat.TorchaudioVersion(), ver) && compat.CUDA == nil { + return TorchaudioPackageName, torchStripCPUSuffixForM1(compat.Torchaudio, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil + } + } + // Fall back to just installing default version. For older torchaudio versions, they don't have any CPU versions. + return TorchaudioPackageName, ver, "", "", nil } // aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html diff --git a/pkg/config/config.go b/pkg/config/config.go index 4a282deebd..86082568f4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -33,6 +33,9 @@ const ( MinimumMajorPythonVersion int = 3 MinimumMinorPythonVersion int = 8 MinimumMajorCudaVersion int = 11 + TorchPackageName = "torch" + TorchvisionPackageName = "torchvision" + TorchaudioPackageName = "torchaudio" ) type RunItem struct { @@ -175,15 +178,46 @@ func (c *Config) CUDABaseImageTag() (string, error) { } func (c *Config) TorchVersion() (string, bool) { - return c.pythonPackageVersion("torch") + torchVersion, found := c.pythonPackageVersion(TorchPackageName) + if !found { + // Can we determine the torch version based on other packages related to it? 
+ tvVersion, tvVersionFound := c.TorchvisionVersion() + tvVersion = version.StripModifier(tvVersion) + if tvVersionFound { + for _, compat := range TorchCompatibilityMatrix { + if version.Equal(tvVersion, version.StripModifier(compat.Torchvision)) { + return compat.Torch, true + } + } + } + + taVersion, taVersionFound := c.TorchaudioVersion() + taVersion = version.StripModifier(taVersion) + if taVersionFound { + for _, compat := range TorchCompatibilityMatrix { + if version.Equal(taVersion, version.StripModifier(compat.Torchaudio)) { + return compat.Torch, true + } + } + } + } + return torchVersion, found } func (c *Config) TorchvisionVersion() (string, bool) { - return c.pythonPackageVersion("torchvision") + tvVersion, found := c.pythonPackageVersion(TorchvisionPackageName) + if found { + tvVersion = version.StripModifier(tvVersion) + } + return tvVersion, found } func (c *Config) TorchaudioVersion() (string, bool) { - return c.pythonPackageVersion("torchaudio") + taVersion, found := c.pythonPackageVersion(TorchaudioPackageName) + if found { + taVersion = version.StripModifier(taVersion) + } + return taVersion, found } func (c *Config) TensorFlowVersion() (string, bool) { @@ -380,8 +414,12 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) { name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg) if err != nil { - // It's not pinned, so just return the line verbatim - return pkg, []string{}, []string{}, nil + // It's not pinned, so just return the line verbatim, unless it's one of our special packages + if pkg == TorchPackageName || pkg == TorchvisionPackageName || pkg == TorchaudioPackageName { + name = pkg + } else { + return pkg, []string{}, []string{}, nil + } } if len(extraIndexURLs) > 0 { return name + "==" + version, findLinksList, extraIndexURLs, nil 
@@ -398,7 +436,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s } } // There is no CPU case for tensorflow because the default package is just the CPU package, so no transformation of version is needed - case "torch": + case TorchPackageName: + torchVersion, found := c.TorchVersion() + if found { + version = torchVersion + } if c.Build.GPU { name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA) if err != nil { @@ -410,7 +452,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s return "", nil, nil, err } } - case "torchvision": + case TorchvisionPackageName: + tvVersion, found := c.TorchvisionVersion() + if found { + version = tvVersion + } if c.Build.GPU { name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA) if err != nil { @@ -422,6 +468,22 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s return "", nil, nil, err } } + case TorchaudioPackageName: + taVersion, found := c.TorchaudioVersion() + if found { + version = taVersion + } + if c.Build.GPU { + name, version, findLinks, extraIndexURL, err = torchaudioGPUPackage(version, c.Build.CUDA) + if err != nil { + return "", nil, nil, err + } + } else { + name, version, findLinks, extraIndexURL, err = torchaudioCPUPackage(version, goos, goarch) + if err != nil { + return "", nil, nil, err + } + } } pkgWithVersion := name if version != "" { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index d0ec825ab8..a38479b0a4 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -718,3 +718,89 @@ torch==2.4.0 torchvision==2.4.0` require.Equal(t, expected, requirements) } + +func TestResolveTorchVersionWithTorchVisionDependency(t *testing.T) { + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.8", + PythonPackages: []string{ + "torch", + "torchvision==0.19.0+cu121", + }, + CUDA: "12.1", + }, + } + 
err := config.ValidateAndComplete("") + require.NoError(t, err) + torchVersion, found := config.TorchVersion() + require.True(t, found) + require.Equal(t, "2.4.0", torchVersion) +} + +func TestResolveTorchVersionWithTorchAudioDependency(t *testing.T) { + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.8", + PythonPackages: []string{ + "torch", + "torchaudio==2.4.0", + }, + CUDA: "12.1", + }, + } + err := config.ValidateAndComplete("") + require.NoError(t, err) + torchVersion, found := config.TorchVersion() + require.True(t, found) + require.Equal(t, "2.4.0", torchVersion) +} + +func TestPythonRequirementsWithTorchvisionVersionModifierStripped(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch +torchvision==0.19.0+cu121`), 0o644) + require.NoError(t, err) + + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.8", + PythonRequirements: "requirements.txt", + }, + } + err = config.ValidateAndComplete(tmpDir) + require.NoError(t, err) + + requirements, err := config.PythonRequirementsForArch("", "", []string{}) + require.NoError(t, err) + expected := `--extra-index-url https://download.pytorch.org/whl/cu124 +torch==2.4.0 +torchvision==0.19.0` + require.Equal(t, expected, requirements) +} + +func TestPythonRequirementsWithTorchaudio(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch +torchaudio==2.4.0`), 0o644) + require.NoError(t, err) + + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.8", + PythonRequirements: "requirements.txt", + }, + } + err = config.ValidateAndComplete(tmpDir) + require.NoError(t, err) + + requirements, err := config.PythonRequirementsForArch("", "", []string{}) + require.NoError(t, err) + expected := `--extra-index-url https://download.pytorch.org/whl/cu124 +torch==2.4.0 +torchaudio==2.4.0` + require.Equal(t, expected, requirements) +} diff --git 
a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 2ad4b47677..64c9701fdd 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -226,7 +226,7 @@ func (g *BaseImageGenerator) pythonPackages() []string { continue } - pkgs = append(pkgs, "torchvision=="+compat.Torchvision) + pkgs = append(pkgs, "torchvision=="+version.StripModifier(compat.Torchvision)) break } @@ -239,7 +239,7 @@ func (g *BaseImageGenerator) pythonPackages() []string { continue } - pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio) + pkgs = append(pkgs, "torchaudio=="+version.StripModifier(compat.Torchaudio)) break } diff --git a/test-integration/test_integration/fixtures/torch-torchvision-project/cog.yaml b/test-integration/test_integration/fixtures/torch-torchvision-project/cog.yaml new file mode 100644 index 0000000000..276a2f4d81 --- /dev/null +++ b/test-integration/test_integration/fixtures/torch-torchvision-project/cog.yaml @@ -0,0 +1,7 @@ +build: + gpu: true + python_version: "3.9" + python_packages: + - "torch" + - "torchvision==0.19.0+cu121" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/torch-torchvision-project/predict.py b/test-integration/test_integration/fixtures/torch-torchvision-project/predict.py new file mode 100644 index 0000000000..44f6992b01 --- /dev/null +++ b/test-integration/test_integration/fixtures/torch-torchvision-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index a1f94f8171..db394469e0 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -296,3 +296,18 @@ def test_torch_1_13_0_base_image_fail_explicit(docker_image): capture_output=True, ) assert build_process.returncode == 0 + + +def 
test_torch_version_resolution_with_torchvision(docker_image): + project_dir = Path(__file__).parent / "fixtures/torch-torchvision-project" + build_process = subprocess.run( + [ + "cog", + "build", + "-t", + docker_image, + ], + cwd=project_dir, + capture_output=True, + ) + assert build_process.returncode == 0 From 52310527e586a47588bdcd1d468827069338a81d Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 13 Sep 2024 13:09:01 -0400 Subject: [PATCH 2/2] Bump integration test timeout to 20 minutes * We now have enough integration tests that this is becoming an issue. --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index da9a02f18e..56c7f6bc2d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -118,7 +118,7 @@ jobs: name: "Test integration" needs: build-python runs-on: ubuntu-latest-16-cores - timeout-minutes: 10 + timeout-minutes: 20 steps: - uses: actions/checkout@v4 with: