diff --git a/core/downloader.go b/core/downloader.go index f82ddc8..8f73c6d 100644 --- a/core/downloader.go +++ b/core/downloader.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "io" "net/http" @@ -48,9 +49,12 @@ type FileDownloader struct { Headers map[string]string DownloadTaskList []*DownloadTask progressCallback ProgressCallback + ctx context.Context + cancelFunc context.CancelFunc } func NewFileDownloader(url, filename string, totalTasks int, headers map[string]string) *FileDownloader { + ctx, cancelFunc := context.WithCancel(context.Background()) return &FileDownloader{ Url: url, FileName: filename, @@ -60,6 +64,8 @@ func NewFileDownloader(url, filename string, totalTasks int, headers map[string] TotalSize: 0, Headers: headers, DownloadTaskList: make([]*DownloadTask, 0), + ctx: ctx, + cancelFunc: cancelFunc, } } @@ -271,11 +277,21 @@ func (fd *FileDownloader) startDownloadTask(wg *sync.WaitGroup, progressChan cha return } + if strings.Contains(err.Error(), "cancelled") { + errorChan <- err + return + } + task.err = err globalLogger.Warn().Msgf("Task %d failed (attempt %d/%d): %v", task.taskID, retries+1, MaxRetries, err) if retries < MaxRetries-1 { - time.Sleep(RetryDelay) + select { + case <-fd.ctx.Done(): + errorChan <- fmt.Errorf("task %d cancelled during retry", task.taskID) + return + case <-time.After(RetryDelay): + } } } @@ -283,7 +299,13 @@ func (fd *FileDownloader) startDownloadTask(wg *sync.WaitGroup, progressChan cha } func (fd *FileDownloader) doDownloadTask(progressChan chan ProgressChan, task *DownloadTask) error { - request, err := http.NewRequest("GET", fd.Url, nil) + select { + case <-fd.ctx.Done(): + return fmt.Errorf("download cancelled") + default: + } + + request, err := http.NewRequestWithContext(fd.ctx, "GET", fd.Url, nil) if err != nil { return fmt.Errorf("create request failed: %w", err) } @@ -310,6 +332,12 @@ func (fd *FileDownloader) doDownloadTask(progressChan chan ProgressChan, task *D buf := make([]byte, 32*1024) for { + select { + case <-fd.ctx.Done(): + return fmt.Errorf("download cancelled") + default: + } + n, err := resp.Body.Read(buf) if n > 0 { writeSize := int64(n) @@ -353,9 +381,9 @@ func (fd *FileDownloader) verifyDownload() error { return nil } -func (fd *FileDownloader) Start() (*FileDownloader, error) { +func (fd *FileDownloader) Start() error { if err := fd.init(); err != nil { - return nil, err + return err } fd.createDownloadTasks() @@ -365,5 +393,19 @@ func (fd *FileDownloader) Start() (*FileDownloader, error) { fd.File.Close() } - return fd, err + return err +} + +func (fd *FileDownloader) Cancel() { + if fd.cancelFunc != nil { + fd.cancelFunc() + } + + if fd.File != nil { + fd.File.Close() + } + + if fd.FileName != "" { + _ = os.Remove(fd.FileName) + } } diff --git a/core/http.go b/core/http.go index d5613bd..e4f03f8 100644 --- a/core/http.go +++ b/core/http.go @@ -345,6 +345,24 @@ func (h *HttpServer) download(w http.ResponseWriter, r *http.Request) { h.success(w) } +func (h *HttpServer) cancel(w http.ResponseWriter, r *http.Request) { + var data struct { + MediaInfo + } + + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.error(w, err.Error()) + return + } + + err := resourceOnce.cancel(data.Id) + if err != nil { + h.error(w, err.Error()) + return + } + h.success(w) +} + func (h *HttpServer) wxFileDecode(w http.ResponseWriter, r *http.Request) { var data struct { MediaInfo diff --git a/core/middleware.go b/core/middleware.go index 6dbcbfd..4aedc1c 100644 --- a/core/middleware.go +++ b/core/middleware.go @@ -56,6 +56,8 @@ func HandleApi(w http.ResponseWriter, r *http.Request) bool { httpServerOnce.delete(w, r) case "/api/download": httpServerOnce.download(w, r) + case "/api/cancel": + httpServerOnce.cancel(w, r) case "/api/wx-file-decode": httpServerOnce.wxFileDecode(w, r) case "/api/batch-export": diff --git a/core/resource.go b/core/resource.go index 47036d6..d6d3b37 100644 --- a/core/resource.go +++ b/core/resource.go @@ -3,6 +3,7 @@ package core import ( "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net/url" @@ -22,6 +23,7 @@ type WxFileDecodeResult struct { type Resource struct { mediaMark sync.Map + tasks sync.Map resType map[string]bool resTypeMux sync.RWMutex } @@ -86,6 +88,15 @@ func (r *Resource) delete(sign string) { r.mediaMark.Delete(sign) } +func (r *Resource) cancel(id string) error { + if d, ok := r.tasks.Load(id); ok { + d.(*FileDownloader).Cancel() + r.tasks.Delete(id) // 可选:取消后清理 + return nil + } + return errors.New("task not found") +} + func (r *Resource) download(mediaInfo MediaInfo, decodeStr string) { if globalConfig.SaveDirectory == "" { return @@ -149,10 +160,13 @@ func (r *Resource) download(mediaInfo MediaInfo, decodeStr string) { downloader.progressCallback = func(totalDownloaded, totalSize float64, taskID int, taskProgress float64) { r.progressEventsEmit(mediaInfo, strconv.Itoa(int(totalDownloaded*100/totalSize))+"%", shared.DownloadStatusRunning) } - fd, err := downloader.Start() - mediaInfo.SavePath = fd.FileName + r.tasks.Store(mediaInfo.Id, downloader) + err := downloader.Start() + mediaInfo.SavePath = downloader.FileName if err != nil { - r.progressEventsEmit(mediaInfo, err.Error()) + if !strings.Contains(err.Error(), "cancelled") { + r.progressEventsEmit(mediaInfo, err.Error()) + } return } if decodeStr != "" { diff --git a/frontend/src/api/app.ts b/frontend/src/api/app.ts index e98581c..2ab17aa 100644 --- a/frontend/src/api/app.ts +++ b/frontend/src/api/app.ts @@ -92,6 +92,13 @@ export default { data: data }) }, + cancel(data: object) { + return request({ + url: 'api/cancel', + method: 'post', + data: data + }) + }, download(data: object) { return request({ url: 'api/download', diff --git a/frontend/src/components/Action.vue b/frontend/src/components/Action.vue index bca5cf4..cb00908 100644 --- a/frontend/src/components/Action.vue +++ b/frontend/src/components/Action.vue @@ -23,6 +23,16 @@