Skip to content

Commit

Permalink
[fix] use torch cpu instead of cuda by default (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehri01 authored Dec 20, 2024
1 parent 09a4de8 commit 8f4d06f
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions internal/backends/python/python.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"regexp"
"runtime"
"strings"

"github.com/BurntSushi/toml"
Expand All @@ -21,6 +22,23 @@ import (

var normalizationPattern = regexp.MustCompile(`[-_.]+`)

type extraIndex struct {
// url is the location of the index
url string
// os is the operating system to override the index for, leave empty
// to override on any operating system
os string
}

var torchCpu = extraIndex{
url: "https://download.pytorch.org/whl/cpu",
os: "linux",
}

var extraIndexMap = map[string][]extraIndex{
"torch": {torchCpu},
}

// this generates a mapping of pypi packages <-> modules
// moduleToPypiPackage pypiPackageToModules are provided
// pypiEntryInfoResponse is a wrapper around pypiEntryInfo
Expand Down Expand Up @@ -751,6 +769,22 @@ func makePythonUvBackend() api.LanguageBackend {

return pkgs
}
addExtraIndexes := func(pkgName string) {
extraIndexes, ok := extraIndexMap[pkgName]
if ok {
uvIndex := os.Getenv("UV_INDEX")

for _, index := range extraIndexes {
if strings.HasPrefix(runtime.GOOS, index.os) {
uvIndex = index.url + " " + uvIndex
}
}

os.Setenv("UV_INDEX", uvIndex)
}
os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match")
}

b := api.LanguageBackend{
Name: "python3-uv",
Specfile: "pyproject.toml",
Expand Down Expand Up @@ -833,7 +867,14 @@ func makePythonUvBackend() api.LanguageBackend {
}

cmd = append(cmd, pep440Join(coords))
addExtraIndexes(string(name))
}

specPkgs := listUvSpecfile()
for pkg := range specPkgs {
addExtraIndexes(string(pkg))
}

util.RunCmd(cmd)
},
Lock: func(ctx context.Context) {
Expand All @@ -858,6 +899,11 @@ func makePythonUvBackend() api.LanguageBackend {
span, ctx := tracer.StartSpanFromContext(ctx, "uv install")
defer span.Finish()

pkgs := listUvSpecfile()
for pkg := range pkgs {
addExtraIndexes(string(pkg))
}

util.RunCmd([]string{"uv", "sync"})
},
ListSpecfile: func(mergeAllGroups bool) map[api.PkgName]api.PkgSpec {
Expand Down

0 comments on commit 8f4d06f

Please sign in to comment.