forked from bodaay/HuggingFaceModelDownloader
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
203 lines (172 loc) · 6.75 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package main
import (
"errors"
"fmt"
hfd "hfdownloader/hfdownloader"
"io"
"os/exec"
"path"
"path/filepath"
"runtime"
"time"
"log"
"os"
"github.com/joho/godotenv"
"github.com/spf13/cobra"
)
// VERSION is the release version shown in the command's help text.
const VERSION = "1.2.9"

// main wires up the hfdownloader CLI: it parses flags with cobra, optionally
// installs the running binary (--install), and otherwise downloads the
// requested model or dataset via hfd.DownloadModel, retrying failed attempts.
func main() {
	var (
		modelName                     string
		datasetName                   string
		branch                        string
		storage                       string
		numberOfConcurrentConnections int
		hfToken                       string
		oneFolderPerFilter            bool
		skipSHA                       bool
		install                       bool
		installPath                   string
		maxRetries                    int
		retryInterval                 int
	)

	short := fmt.Sprintf("a Simple HuggingFace Models Downloader Utility\nVersion: %s", VERSION)
	if exePath, err := os.Executable(); err != nil {
		log.Printf("Failed to get executable path, %s", err)
	} else if exePath != "" {
		short = fmt.Sprintf("%s\nRunning on: %s", short, exePath)
	}

	rootCmd := &cobra.Command{
		Use:   "hfdownloader",
		Short: short,
		RunE: func(cmd *cobra.Command, args []string) error {
			// Model names are deliberately not validated against an
			// "Author/Name" pattern: some repositories do not follow it.
			if install {
				return installBinary(installPath)
			}

			if (modelName == "" && datasetName == "") || (modelName != "" && datasetName != "") {
				_ = cmd.Help() // best effort; the validation error below is what matters
				return errors.New("you must set either modelName or datasetName, not both or neither")
			}

			isDataset := datasetName != ""
			repoName := modelName
			if isDataset {
				repoName = datasetName
				fmt.Println("Dataset:", datasetName)
			} else {
				fmt.Println("Model:", modelName)
			}

			// A missing .env file is fine: the token may come from the flag
			// or from the real environment instead.
			_ = godotenv.Load()
			if hfToken == "" {
				hfToken = os.Getenv("HUGGING_FACE_HUB_TOKEN")
			}

			fmt.Println("Branch:", branch)
			fmt.Println("Storage:", storage)
			fmt.Println("NumberOfConcurrentConnections:", numberOfConcurrentConnections)
			fmt.Println("Append Filter Names to Folder:", oneFolderPerFilter)
			fmt.Println("Skip SHA256 Check:", skipSHA)
			// Never echo the full secret: console output frequently ends up in logs.
			fmt.Println("Token:", maskToken(hfToken))

			// Always make at least one attempt; with maxRetries <= 0 the old
			// loop never ran and the command falsely reported success.
			attempts := maxRetries
			if attempts < 1 {
				attempts = 1
			}
			err := retryDownload(attempts, time.Duration(retryInterval)*time.Second, func() error {
				return hfd.DownloadModel(repoName, oneFolderPerFilter, skipSHA, isDataset, storage, branch, numberOfConcurrentConnections, hfToken)
			})
			if err != nil {
				return fmt.Errorf("failed to download %s after %d attempts: %w", repoName, attempts, err)
			}
			fmt.Printf("\nDownload of %s completed successfully\n", repoName)
			return nil
		},
	}
	// Help is printed manually during validation, and errors are reported
	// once by main below rather than also by cobra.
	rootCmd.SilenceUsage = true
	rootCmd.SilenceErrors = true
	rootCmd.Flags().SortFlags = false

	// Command-line flags.
	rootCmd.Flags().StringVarP(&modelName, "model", "m", "", "Model/Dataset name (required if dataset not set)\nYou can supply filters for required LFS model files\nex: ModelName:q4_0,q8_1\nex: TheBloke/WizardLM-Uncensored-Falcon-7B-GGML:fp16")
	rootCmd.Flags().StringVarP(&datasetName, "dataset", "d", "", "Model/Dataset name (required if model not set)")
	rootCmd.Flags().StringVarP(&branch, "branch", "b", "main", "Model/Dataset branch (optional)")
	rootCmd.Flags().StringVarP(&storage, "storage", "s", "Storage", "Storage path (optional)")
	rootCmd.Flags().BoolVarP(&skipSHA, "skipSHA", "k", false, "Skip SHA256 Hash Check, sometimes you just need to download missing files without wasting time waiting (optional)")
	rootCmd.Flags().BoolVarP(&oneFolderPerFilter, "appendFilterFolder", "f", false, "This will append the filter name to the folder, use it for GGML quantized filtered downloads only (optional)")
	rootCmd.Flags().IntVarP(&numberOfConcurrentConnections, "concurrent", "c", 5, "Number of LFS concurrent connections (optional)")
	rootCmd.Flags().StringVarP(&hfToken, "token", "t", "", "HuggingFace Access Token, this can be automatically supplied by env variable 'HUGGING_FACE_HUB_TOKEN' or .env file, required for some Models/Datasets, you still need to manually accept agreement if model requires it (optional)")
	rootCmd.Flags().BoolVarP(&install, "install", "i", false, "Install the binary to the OS default bin folder, Unix-like operating systems only")
	rootCmd.Flags().StringVarP(&installPath, "installPath", "p", "/usr/local/bin/", "install Path (optional)")
	rootCmd.Flags().IntVar(&maxRetries, "maxRetries", 3, "Max number of retries (optional)")
	rootCmd.Flags().IntVar(&retryInterval, "retryInterval", 5, "Retry interval in seconds (optional)")

	if err := rootCmd.Execute(); err != nil {
		log.Fatalln("Error:", err)
	}
}

// maskToken returns a display-safe form of an access token: empty stays
// empty, short tokens are fully hidden, and longer ones keep only their first
// and last four characters.
func maskToken(token string) string {
	if token == "" {
		return ""
	}
	if len(token) <= 8 {
		return "****"
	}
	return token[:4] + "****" + token[len(token)-4:]
}

// retryDownload calls download up to attempts times (at least once), sleeping
// interval between consecutive failures but not after the last one. It returns
// nil on the first success, otherwise the error from the final attempt.
func retryDownload(attempts int, interval time.Duration, download func() error) error {
	if attempts < 1 {
		attempts = 1
	}
	var err error
	for i := 1; i <= attempts; i++ {
		if err = download(); err == nil {
			return nil
		}
		fmt.Printf("warning: attempt %d / %d failed, error: %v\n", i, attempts, err)
		if i < attempts {
			time.Sleep(interval)
		}
	}
	return err
}
// installBinary copies the currently running executable into installPath,
// keeping its file name. If the current user cannot write there, it retries
// the copy with `sudo cp`. It returns an error on Windows, where installation
// is not supported, and is a no-op if the binary already runs from the target.
func installBinary(installPath string) error {
	if runtime.GOOS == "windows" {
		return errors.New("the install command is not supported on Windows")
	}
	exePath, err := os.Executable()
	if err != nil {
		return err
	}
	// path.Join is adequate here: Windows, the only platform whose separator
	// is not "/", was rejected above.
	dst := path.Join(installPath, filepath.Base(exePath))

	// Open the source BEFORE touching dst. If the binary is already running
	// from dst, removing dst first would delete the very file we need to copy.
	srcFile, err := os.Open(exePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	if dstInfo, statErr := os.Stat(dst); statErr == nil {
		if srcInfo, err := srcFile.Stat(); err == nil && os.SameFile(srcInfo, dstInfo) {
			log.Printf("The binary is already installed at %s", dst)
			return nil
		}
		// Best effort: if removal fails (e.g. no permission), copyFile below
		// reports the real error and triggers the sudo fallback.
		_ = os.Remove(dst)
	}

	if err := copyFile(dst, srcFile); err != nil {
		if !errors.Is(err, os.ErrPermission) {
			return err
		}
		// No write access to installPath: retry with elevated privileges.
		fmt.Printf("Require sudo privileges to install to: %s\n", installPath)
		cmd := exec.Command("sudo", "cp", exePath, dst)
		// Attach the terminal so sudo's prompt and any error output reach the user.
		cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("sudo cp %s %s: %w", exePath, dst, err)
		}
	}
	log.Printf("The binary has been copied to %s", dst)
	return nil
}
// copyFile copies everything readable from src into the file at dst. It
// creates dst with permission bits 0755 if it does not exist and truncates
// it otherwise. Errors are returned unwrapped (the *os.PathError from the
// failing call) so callers can classify them, e.g. with os.IsPermission.
func copyFile(dst string, src *os.File) (err error) {
	dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
	if err != nil {
		return err
	}
	// Close can fail on a file that was written to (its data may not have
	// been persisted), so report that failure unless an earlier error won.
	defer func() {
		if cerr := dstFile.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(dstFile, src)
	return err
}