Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Torch Dependency Resolution #1952

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ jobs:
name: "Test integration"
needs: build-python
runs-on: ubuntu-latest-16-cores
timeout-minutes: 10
timeout-minutes: 20
steps:
- uses: actions/checkout@v4
with:
Expand Down
71 changes: 62 additions & 9 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ func (c *TorchCompatibility) TorchvisionVersion() string {
return version.StripModifier(c.Torchvision)
}

func (c *TorchCompatibility) TorchaudioVersion() string {
return version.StripModifier(c.Torchaudio)
}

type CUDABaseImage struct {
Tag string
CUDA string
Expand Down Expand Up @@ -287,12 +291,12 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err
func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchVersion() == ver && compat.CUDA == nil {
return "torch", torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
return TorchPackageName, torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}

// Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions.
return "torch", ver, "", "", nil
return TorchPackageName, ver, "", "", nil
}

func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
Expand Down Expand Up @@ -327,20 +331,20 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
}
if latest == nil {
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
return "torch", ver, "", "", nil
return TorchPackageName, ver, "", "", nil
}

return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
return TorchPackageName, version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
}

func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
return TorchvisionPackageName, torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return "torchvision", ver, "", "", nil
return TorchvisionPackageName, ver, "", "", nil
}

func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
Expand All @@ -350,7 +354,7 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
var latest *TorchCompatibility
for _, compat := range TorchCompatibilityMatrix {
compat := compat
if compat.TorchvisionVersion() != ver || compat.CUDA == nil {
if !version.Matches(compat.TorchvisionVersion(), ver) || compat.CUDA == nil {
continue
}
greater, err := versionGreater(*compat.CUDA, cuda)
Expand All @@ -376,10 +380,59 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
if latest == nil {
// TODO: can we suggest a CUDA version known to be compatible?
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver)
return "torchvision", ver, "", "", nil
return TorchvisionPackageName, ver, "", "", nil
}

return TorchvisionPackageName, version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
}

func torchaudioGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
var latest *TorchCompatibility
for _, compat := range TorchCompatibilityMatrix {
compat := compat
taVersion := compat.TorchaudioVersion()
if taVersion == "" {
continue
}
if !version.Matches(taVersion, ver) || compat.CUDA == nil {
continue
}
greater, err := versionGreater(*compat.CUDA, cuda)
if err != nil {
panic(fmt.Sprintf("Invalid CUDA version: %s", err))
}
if greater {
continue
}
if latest == nil {
latest = &compat
} else {
greater, err := versionGreater(*compat.CUDA, *latest.CUDA)
if err != nil {
// should never happen
panic(fmt.Sprintf("Invalid CUDA version: %s", err))
}
if greater {
latest = &compat
}
}
}
if latest == nil {
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchaudio %s. This might cause CUDA problems.", cuda, ver)
return TorchaudioPackageName, ver, "", "", nil
}

return TorchaudioPackageName, version.StripModifier(latest.Torchaudio), latest.FindLinks, latest.ExtraIndexURL, nil
}

return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
func torchaudioCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if version.Matches(compat.TorchaudioVersion(), ver) && compat.CUDA == nil {
return TorchaudioPackageName, torchStripCPUSuffixForM1(compat.Torchaudio, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return TorchaudioPackageName, ver, "", "", nil
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
Expand Down
76 changes: 69 additions & 7 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ const (
MinimumMajorPythonVersion int = 3
MinimumMinorPythonVersion int = 8
MinimumMajorCudaVersion int = 11
TorchPackageName = "torch"
TorchvisionPackageName = "torchvision"
TorchaudioPackageName = "torchaudio"
)

type RunItem struct {
Expand Down Expand Up @@ -175,15 +178,46 @@ func (c *Config) CUDABaseImageTag() (string, error) {
}

func (c *Config) TorchVersion() (string, bool) {
return c.pythonPackageVersion("torch")
torchVersion, found := c.pythonPackageVersion(TorchPackageName)
if !found {
// Can we determine the torch version based on other packages related to it?
tvVersion, tvVersionFound := c.TorchvisionVersion()
tvVersion = version.StripModifier(tvVersion)
if tvVersionFound {
for _, compat := range TorchCompatibilityMatrix {
if version.Equal(tvVersion, version.StripModifier(compat.Torchvision)) {
return compat.Torch, true
}
}
}

taVersion, taVersionFound := c.TorchaudioVersion()
taVersion = version.StripModifier(taVersion)
if taVersionFound {
for _, compat := range TorchCompatibilityMatrix {
if version.Equal(taVersion, version.StripModifier(compat.Torchaudio)) {
return compat.Torch, true
}
}
}
}
return torchVersion, found
}

func (c *Config) TorchvisionVersion() (string, bool) {
return c.pythonPackageVersion("torchvision")
tvVersion, found := c.pythonPackageVersion(TorchvisionPackageName)
if found {
tvVersion = version.StripModifier(tvVersion)
}
return tvVersion, found
}

func (c *Config) TorchaudioVersion() (string, bool) {
return c.pythonPackageVersion("torchaudio")
taVersion, found := c.pythonPackageVersion(TorchaudioPackageName)
if found {
taVersion = version.StripModifier(taVersion)
}
return taVersion, found
}

func (c *Config) TensorFlowVersion() (string, bool) {
Expand Down Expand Up @@ -380,8 +414,12 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
// It's not pinned, so just return the line verbatim, unless its one of our special packages
if pkg == TorchPackageName || pkg == TorchvisionPackageName || pkg == TorchaudioPackageName {
name = pkg
} else {
return pkg, []string{}, []string{}, nil
}
}
if len(extraIndexURLs) > 0 {
return name + "==" + version, findLinksList, extraIndexURLs, nil
Expand All @@ -398,7 +436,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
}
}
// There is no CPU case for tensorflow because the default package is just the CPU package, so no transformation of version is needed
case "torch":
case TorchPackageName:
torchVersion, found := c.TorchVersion()
if found {
version = torchVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA)
if err != nil {
Expand All @@ -410,7 +452,11 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
return "", nil, nil, err
}
}
case "torchvision":
case TorchvisionPackageName:
tvVersion, found := c.TorchvisionVersion()
if found {
version = tvVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA)
if err != nil {
Expand All @@ -422,6 +468,22 @@ func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage s
return "", nil, nil, err
}
}
case TorchaudioPackageName:
taVersion, found := c.TorchaudioVersion()
if found {
version = taVersion
}
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchaudioGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
} else {
name, version, findLinks, extraIndexURL, err = torchaudioCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
}
}
pkgWithVersion := name
if version != "" {
Expand Down
86 changes: 86 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,89 @@ torch==2.4.0
torchvision==2.4.0`
require.Equal(t, expected, requirements)
}

func TestResolveTorchVersionWithTorchVisionDependency(t *testing.T) {
config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonPackages: []string{
"torch",
"torchvision==0.19.0+cu121",
},
CUDA: "12.1",
},
}
err := config.ValidateAndComplete("")
require.NoError(t, err)
torchVersion, found := config.TorchVersion()
require.True(t, found)
require.Equal(t, "2.4.0", torchVersion)
}

func TestResolveTorchVersionWithTorchAudioDependency(t *testing.T) {
config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonPackages: []string{
"torch",
"torchaudio==2.4.0",
},
CUDA: "12.1",
},
}
err := config.ValidateAndComplete("")
require.NoError(t, err)
torchVersion, found := config.TorchVersion()
require.True(t, found)
require.Equal(t, "2.4.0", torchVersion)
}

func TestPythonRequirementsWithTorchvisionVersionModifierStripped(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch
torchvision==0.19.0+cu121`), 0o644)
require.NoError(t, err)

config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
err = config.ValidateAndComplete(tmpDir)
require.NoError(t, err)

requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.4.0
torchvision==0.19.0`
require.Equal(t, expected, requirements)
}

func TestPythonRequirementsWithTorchaudio(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`torch
torchaudio==2.4.0`), 0o644)
require.NoError(t, err)

config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
err = config.ValidateAndComplete(tmpDir)
require.NoError(t, err)

requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.4.0
torchaudio==2.4.0`
require.Equal(t, expected, requirements)
}
4 changes: 2 additions & 2 deletions pkg/dockerfile/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (g *BaseImageGenerator) pythonPackages() []string {
continue
}

pkgs = append(pkgs, "torchvision=="+compat.Torchvision)
pkgs = append(pkgs, "torchvision=="+version.StripModifier(compat.Torchvision))
break
}

Expand All @@ -239,7 +239,7 @@ func (g *BaseImageGenerator) pythonPackages() []string {
continue
}

pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio)
pkgs = append(pkgs, "torchaudio=="+version.StripModifier(compat.Torchaudio))
break
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
build:
gpu: true
python_version: "3.9"
python_packages:
- "torch"
- "torchvision==0.19.0+cu121"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, s: str) -> str:
return "hello " + s
Loading