Skip to content

Commit

Permalink
feat: add batch delete tasks api (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
monkeyWie authored Jan 19, 2024
1 parent a1a8abc commit 38ff711
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 14 deletions.
64 changes: 64 additions & 0 deletions pkg/download/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,13 @@ func (d *Downloader) Delete(id string, force bool) (err error) {
for i, t := range d.tasks {
if t.ID == id {
d.tasks = append(d.tasks[:i], d.tasks[i+1:]...)
break
}
}
for i, t := range d.waitTasks {
if t.ID == id {
d.waitTasks = append(d.waitTasks[:i], d.waitTasks[i+1:]...)
break
}
}
}()
Expand All @@ -460,6 +467,47 @@ func (d *Downloader) Delete(id string, force bool) (err error) {
return
}

func (d *Downloader) DeleteByStatues(statues []base.Status, force bool) (err error) {
deleteTasks := d.GetTasksByStatues(statues)
if len(deleteTasks) == 0 {
return
}

deleteIds := make([]string, 0)
for _, task := range deleteTasks {
deleteIds = append(deleteIds, task.ID)
}
func() {
d.lock.Lock()
defer d.lock.Unlock()

for _, id := range deleteIds {
for i, t := range d.tasks {
if t.ID == id {
d.tasks = append(d.tasks[:i], d.tasks[i+1:]...)
break
}
}
for i, t := range d.waitTasks {
if t.ID == id {
d.waitTasks = append(d.waitTasks[:i], d.waitTasks[i+1:]...)
break
}
}
}
}()

for _, task := range deleteTasks {
err = d.doDelete(task, force)
if err != nil {
return
}
}

d.notifyRunning()
return
}

func (d *Downloader) doDelete(task *Task, force bool) (err error) {
err = func() error {
if task.fetcher != nil {
Expand Down Expand Up @@ -548,6 +596,22 @@ func (d *Downloader) GetTasks() []*Task {
return d.tasks
}

func (d *Downloader) GetTasksByStatues(statues []base.Status) []*Task {
if len(statues) == 0 {
return d.tasks
}
tasks := make([]*Task, 0)
for _, task := range d.tasks {
for _, status := range statues {
if task.Status == status {
tasks = append(tasks, task)
break
}
}
}
return tasks
}

func (d *Downloader) GetConfig() (*DownloaderStoreConfig, error) {
return d.cfg.DownloaderStoreConfig, nil
}
Expand Down
35 changes: 21 additions & 14 deletions pkg/rest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ func DeleteTask(w http.ResponseWriter, r *http.Request) {
WriteJson(w, model.NewNilResult())
}

func DeleteTasks(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
status := r.Form["status"]
force := r.FormValue("force")
if err := Downloader.DeleteByStatues(convertStatues(status), force == "true"); err != nil {
WriteJson(w, model.NewErrorResult(err.Error()))
return
}
WriteJson(w, model.NewNilResult())
}

func GetTask(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
taskId := vars["id"]
Expand All @@ -122,20 +133,8 @@ func GetTask(w http.ResponseWriter, r *http.Request) {
func GetTasks(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
status := r.Form["status"]
tasks := Downloader.GetTasks()
if len(status) == 0 {
WriteJson(w, model.NewOkResult(tasks))
return
}
result := make([]*download.Task, 0)
for _, task := range tasks {
for _, s := range status {
if task.Status == base.Status(s) {
result = append(result, task)
}
}
}
WriteJson(w, model.NewOkResult(result))
tasks := Downloader.GetTasksByStatues(convertStatues(status))
WriteJson(w, model.NewOkResult(tasks))
}

func GetConfig(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -282,6 +281,14 @@ func DoProxy(w http.ResponseWriter, r *http.Request) {
w.Write(buf)
}

func convertStatues(statues []string) []base.Status {
result := make([]base.Status, 0)
for _, status := range statues {
result = append(result, base.Status(status))
}
return result
}

func writeError(w http.ResponseWriter, msg string) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(msg))
Expand Down
1 change: 1 addition & 0 deletions pkg/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func BuildServer(startCfg *model.StartConfig) (*http.Server, net.Listener, error
r.Methods(http.MethodPut).Path("/api/v1/tasks/pause").HandlerFunc(PauseAllTask)
r.Methods(http.MethodPut).Path("/api/v1/tasks/continue").HandlerFunc(ContinueAllTask)
r.Methods(http.MethodDelete).Path("/api/v1/tasks/{id}").HandlerFunc(DeleteTask)
r.Methods(http.MethodDelete).Path("/api/v1/tasks").HandlerFunc(DeleteTasks)
r.Methods(http.MethodGet).Path("/api/v1/tasks/{id}").HandlerFunc(GetTask)
r.Methods(http.MethodGet).Path("/api/v1/tasks").HandlerFunc(GetTasks)
r.Methods(http.MethodGet).Path("/api/v1/config").HandlerFunc(GetConfig)
Expand Down
52 changes: 52 additions & 0 deletions pkg/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,58 @@ func TestDeleteTaskForce(t *testing.T) {
})
}

func TestDeleteAllTasks(t *testing.T) {
doTest(func() {
taskCount := 3

var wg sync.WaitGroup
wg.Add(taskCount)
Downloader.Listener(func(event *download.Event) {
if event.Key == download.EventKeyFinally {
wg.Done()
}
})

for i := 0; i < taskCount; i++ {
httpRequestCheckOk[string](http.MethodPost, "/api/v1/tasks", createReq)
}

wg.Wait()

httpRequestCheckOk[any](http.MethodDelete, "/api/v1/tasks?force=true", nil)
tasks := httpRequestCheckOk[[]*download.Task](http.MethodGet, "/api/v1/tasks", nil)
if len(tasks) != 0 {
t.Errorf("DeleteTasks() got = %v, want %v", len(tasks), 0)
}
})
}

func TestDeleteTasksByStatues(t *testing.T) {
doTest(func() {
taskCount := 3

var wg sync.WaitGroup
wg.Add(taskCount)
Downloader.Listener(func(event *download.Event) {
if event.Key == download.EventKeyFinally {
wg.Done()
}
})

for i := 0; i < taskCount; i++ {
httpRequestCheckOk[string](http.MethodPost, "/api/v1/tasks", createReq)
}

wg.Wait()

httpRequestCheckOk[any](http.MethodDelete, fmt.Sprintf("/api/v1/tasks?status=%s&force=true", base.DownloadStatusDone), nil)
tasks := httpRequestCheckOk[[]*download.Task](http.MethodGet, "/api/v1/tasks", nil)
if len(tasks) != 0 {
t.Errorf("DeleteTasks() got = %v, want %v", len(tasks), 0)
}
})
}

func TestGetTasks(t *testing.T) {
doTest(func() {
var wg sync.WaitGroup
Expand Down

0 comments on commit 38ff711

Please sign in to comment.