feat: add cancel download

This commit is contained in:
putyy
2025-09-13 22:19:42 +08:00
committed by putyy
parent 55d3f06cb6
commit 2d75bbb5c3
10 changed files with 129 additions and 8 deletions

View File

@@ -1,6 +1,7 @@
package core
import (
"context"
"fmt"
"io"
"net/http"
@@ -48,9 +49,12 @@ type FileDownloader struct {
Headers map[string]string
DownloadTaskList []*DownloadTask
progressCallback ProgressCallback
ctx context.Context
cancelFunc context.CancelFunc
}
func NewFileDownloader(url, filename string, totalTasks int, headers map[string]string) *FileDownloader {
ctx, cancelFunc := context.WithCancel(context.Background())
return &FileDownloader{
Url: url,
FileName: filename,
@@ -60,6 +64,8 @@ func NewFileDownloader(url, filename string, totalTasks int, headers map[string]
TotalSize: 0,
Headers: headers,
DownloadTaskList: make([]*DownloadTask, 0),
ctx: ctx,
cancelFunc: cancelFunc,
}
}
@@ -271,11 +277,21 @@ func (fd *FileDownloader) startDownloadTask(wg *sync.WaitGroup, progressChan cha
return
}
if strings.Contains(err.Error(), "cancelled") {
errorChan <- err
return
}
task.err = err
globalLogger.Warn().Msgf("Task %d failed (attempt %d/%d): %v", task.taskID, retries+1, MaxRetries, err)
if retries < MaxRetries-1 {
time.Sleep(RetryDelay)
select {
case <-fd.ctx.Done():
errorChan <- fmt.Errorf("task %d cancelled during retry", task.taskID)
return
case <-time.After(RetryDelay):
}
}
}
@@ -283,7 +299,13 @@ func (fd *FileDownloader) startDownloadTask(wg *sync.WaitGroup, progressChan cha
}
func (fd *FileDownloader) doDownloadTask(progressChan chan ProgressChan, task *DownloadTask) error {
request, err := http.NewRequest("GET", fd.Url, nil)
select {
case <-fd.ctx.Done():
return fmt.Errorf("download cancelled")
default:
}
request, err := http.NewRequestWithContext(fd.ctx, "GET", fd.Url, nil)
if err != nil {
return fmt.Errorf("create request failed: %w", err)
}
@@ -310,6 +332,12 @@ func (fd *FileDownloader) doDownloadTask(progressChan chan ProgressChan, task *D
buf := make([]byte, 32*1024)
for {
select {
case <-fd.ctx.Done():
return fmt.Errorf("download cancelled")
default:
}
n, err := resp.Body.Read(buf)
if n > 0 {
writeSize := int64(n)
@@ -353,9 +381,9 @@ func (fd *FileDownloader) verifyDownload() error {
return nil
}
func (fd *FileDownloader) Start() (*FileDownloader, error) {
func (fd *FileDownloader) Start() error {
if err := fd.init(); err != nil {
return nil, err
return err
}
fd.createDownloadTasks()
@@ -365,5 +393,19 @@ func (fd *FileDownloader) Start() (*FileDownloader, error) {
fd.File.Close()
}
return fd, err
return err
}
func (fd *FileDownloader) Cancel() {
if fd.cancelFunc != nil {
fd.cancelFunc()
}
if fd.File != nil {
fd.File.Close()
}
if fd.FileName != "" {
_ = os.Remove(fd.FileName)
}
}

View File

@@ -345,6 +345,24 @@ func (h *HttpServer) download(w http.ResponseWriter, r *http.Request) {
h.success(w)
}
func (h *HttpServer) cancel(w http.ResponseWriter, r *http.Request) {
var data struct {
MediaInfo
}
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.error(w, err.Error())
return
}
err := resourceOnce.cancel(data.Id)
if err != nil {
h.error(w, err.Error())
return
}
h.success(w)
}
func (h *HttpServer) wxFileDecode(w http.ResponseWriter, r *http.Request) {
var data struct {
MediaInfo

View File

@@ -56,6 +56,8 @@ func HandleApi(w http.ResponseWriter, r *http.Request) bool {
httpServerOnce.delete(w, r)
case "/api/download":
httpServerOnce.download(w, r)
case "/api/cancel":
httpServerOnce.cancel(w, r)
case "/api/wx-file-decode":
httpServerOnce.wxFileDecode(w, r)
case "/api/batch-export":

View File

@@ -3,6 +3,7 @@ package core
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -22,6 +23,7 @@ type WxFileDecodeResult struct {
type Resource struct {
mediaMark sync.Map
tasks sync.Map
resType map[string]bool
resTypeMux sync.RWMutex
}
@@ -86,6 +88,15 @@ func (r *Resource) delete(sign string) {
r.mediaMark.Delete(sign)
}
func (r *Resource) cancel(id string) error {
if d, ok := r.tasks.Load(id); ok {
d.(*FileDownloader).Cancel()
r.tasks.Delete(id) // 可选:取消后清理
return nil
}
return errors.New("task not found")
}
func (r *Resource) download(mediaInfo MediaInfo, decodeStr string) {
if globalConfig.SaveDirectory == "" {
return
@@ -149,10 +160,13 @@ func (r *Resource) download(mediaInfo MediaInfo, decodeStr string) {
downloader.progressCallback = func(totalDownloaded, totalSize float64, taskID int, taskProgress float64) {
r.progressEventsEmit(mediaInfo, strconv.Itoa(int(totalDownloaded*100/totalSize))+"%", shared.DownloadStatusRunning)
}
fd, err := downloader.Start()
mediaInfo.SavePath = fd.FileName
r.tasks.Store(mediaInfo.Id, downloader)
err := downloader.Start()
mediaInfo.SavePath = downloader.FileName
if err != nil {
r.progressEventsEmit(mediaInfo, err.Error())
if !strings.Contains(err.Error(), "cancelled") {
r.progressEventsEmit(mediaInfo, err.Error())
}
return
}
if decodeStr != "" {