refactor: pass api_url through context (#457)

* refactor: pass `api_url` through context

* 移除 LinkArgs.HttpReq

* pref(alias): 减少不必要下载代理

* 修复bug

* net: 支持1并发 分片下载
This commit is contained in:
j2rong4cn
2025-06-30 15:48:05 +08:00
committed by GitHub
parent f0236522f3
commit 103abc942e
30 changed files with 209 additions and 222 deletions

View File

@ -103,7 +103,12 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, ok := storage.(*Alias); !ok && !args.Redirect { useRawLink := len(common.GetApiUrl(ctx)) == 0 // ftp、s3
if !useRawLink {
_, ok := storage.(*Alias)
useRawLink = !ok && !args.Redirect
}
if useRawLink {
link, _, err := op.Link(ctx, storage, reqActualPath, args) link, _, err := op.Link(ctx, storage, reqActualPath, args)
return link, err return link, err
} }
@ -114,13 +119,10 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs)
if common.ShouldProxy(storage, stdpath.Base(sub)) { if common.ShouldProxy(storage, stdpath.Base(sub)) {
link := &model.Link{ link := &model.Link{
URL: fmt.Sprintf("%s/p%s?sign=%s", URL: fmt.Sprintf("%s/p%s?sign=%s",
common.GetApiUrl(args.HttpReq), common.GetApiUrl(ctx),
utils.EncodePath(reqPath, true), utils.EncodePath(reqPath, true),
sign.Sign(reqPath)), sign.Sign(reqPath)),
} }
if args.HttpReq != nil && d.ProxyRange {
link.RangeReadCloser = common.NoProxyRange
}
return link, nil return link, nil
} }
link, _, err := op.Link(ctx, storage, reqActualPath, args) link, _, err := op.Link(ctx, storage, reqActualPath, args)
@ -201,31 +203,24 @@ func (d *Alias) extract(ctx context.Context, dst, sub string, args model.Archive
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, ok := storage.(driver.ArchiveReader); ok { if _, ok := storage.(driver.ArchiveReader); !ok {
if _, ok := storage.(*Alias); !ok && !args.Redirect { return nil, errs.NotImplement
link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) }
return link, err if args.Redirect && common.ShouldProxy(storage, stdpath.Base(sub)) {
}
_, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if common.ShouldProxy(storage, stdpath.Base(sub)) { link := &model.Link{
link := &model.Link{ URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s",
URL: fmt.Sprintf("%s/ap%s?inner=%s&pass=%s&sign=%s", common.GetApiUrl(ctx),
common.GetApiUrl(args.HttpReq), utils.EncodePath(reqPath, true),
utils.EncodePath(reqPath, true), utils.EncodePath(args.InnerPath, true),
utils.EncodePath(args.InnerPath, true), url.QueryEscape(args.Password),
url.QueryEscape(args.Password), sign.SignArchive(reqPath)),
sign.SignArchive(reqPath)),
}
if args.HttpReq != nil && d.ProxyRange {
link.RangeReadCloser = common.NoProxyRange
}
return link, nil
} }
link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args) return link, nil
return link, err
} }
return nil, errs.NotImplement link, _, err := op.DriverExtract(ctx, storage, reqActualPath, args)
return link, err
} }

View File

@ -163,7 +163,7 @@ func (d *Crypt) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([
if d.Thumbnail && thumb == "" { if d.Thumbnail && thumb == "" {
thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp") thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp")
thumb = fmt.Sprintf("%s/d%s?sign=%s", thumb = fmt.Sprintf("%s/d%s?sign=%s",
common.GetApiUrl(common.GetHttpReq(ctx)), common.GetApiUrl(ctx),
utils.EncodePath(thumbPath, true), utils.EncodePath(thumbPath, true),
sign.Sign(thumbPath)) sign.Sign(thumbPath))
} }

View File

@ -139,7 +139,7 @@ func (d *Local) FileInfoToObj(ctx context.Context, f fs.FileInfo, reqPath string
if d.Thumbnail { if d.Thumbnail {
typeName := utils.GetFileType(f.Name()) typeName := utils.GetFileType(f.Name())
if typeName == conf.IMAGE || typeName == conf.VIDEO { if typeName == conf.IMAGE || typeName == conf.VIDEO {
thumb = common.GetApiUrl(common.GetHttpReq(ctx)) + stdpath.Join("/d", reqPath, f.Name()) thumb = common.GetApiUrl(ctx) + stdpath.Join("/d", reqPath, f.Name())
thumb = utils.EncodePath(thumb, true) thumb = utils.EncodePath(thumb, true)
thumb += "?type=thumb&sign=" + sign.Sign(stdpath.Join(reqPath, f.Name())) thumb += "?type=thumb&sign=" + sign.Sign(stdpath.Join(reqPath, f.Name()))
} }

View File

@ -76,7 +76,7 @@ func (d *NeteaseMusic) Link(ctx context.Context, file model.Obj, args model.Link
if args.Type == "parsed" { if args.Type == "parsed" {
return lrc.getLyricLink(), nil return lrc.getLyricLink(), nil
} else { } else {
return lrc.getProxyLink(args), nil return lrc.getProxyLink(ctx), nil
} }
} }

View File

@ -48,8 +48,8 @@ type LyricObj struct {
lyric string lyric string
} }
func (lrc *LyricObj) getProxyLink(args model.LinkArgs) *model.Link { func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link {
rawURL := common.GetApiUrl(args.HttpReq) + "/p" + lrc.Path rawURL := common.GetApiUrl(ctx) + "/p" + lrc.Path
rawURL = utils.EncodePath(rawURL, true) + "?type=parsed&sign=" + sign.Sign(lrc.Path) rawURL = utils.EncodePath(rawURL, true) + "?type=parsed&sign=" + sign.Sign(lrc.Path)
return &model.Link{URL: rawURL} return &model.Link{URL: rawURL}
} }

View File

@ -2,17 +2,17 @@ package authn
import ( import (
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/conf"
"github.com/OpenListTeam/OpenList/internal/setting" "github.com/OpenListTeam/OpenList/internal/setting"
"github.com/OpenListTeam/OpenList/server/common" "github.com/OpenListTeam/OpenList/server/common"
"github.com/gin-gonic/gin"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
) )
func NewAuthnInstance(r *http.Request) (*webauthn.WebAuthn, error) { func NewAuthnInstance(c *gin.Context) (*webauthn.WebAuthn, error) {
siteUrl, err := url.Parse(common.GetApiUrl(r)) siteUrl, err := url.Parse(common.GetApiUrl(c))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -148,4 +148,5 @@ const (
// ContextKey is the type of context keys. // ContextKey is the type of context keys.
const ( const (
NoTaskKey = "no_task" NoTaskKey = "no_task"
ApiUrlKey = "api_url"
) )

View File

@ -49,7 +49,9 @@ func (t *ArchiveDownloadTask) GetStatus() string {
} }
func (t *ArchiveDownloadTask) Run() error { func (t *ArchiveDownloadTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()
@ -152,7 +154,9 @@ func (t *ArchiveContentUploadTask) GetStatus() string {
} }
func (t *ArchiveContentUploadTask) Run() error { func (t *ArchiveContentUploadTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()

View File

@ -7,15 +7,15 @@ import (
stdpath "path" stdpath "path"
"time" "time"
"github.com/OpenListTeam/OpenList/internal/errs"
"github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/conf"
"github.com/OpenListTeam/OpenList/internal/driver" "github.com/OpenListTeam/OpenList/internal/driver"
"github.com/OpenListTeam/OpenList/internal/errs"
"github.com/OpenListTeam/OpenList/internal/model" "github.com/OpenListTeam/OpenList/internal/model"
"github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/internal/op"
"github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/internal/stream"
"github.com/OpenListTeam/OpenList/internal/task" "github.com/OpenListTeam/OpenList/internal/task"
"github.com/OpenListTeam/OpenList/pkg/utils" "github.com/OpenListTeam/OpenList/pkg/utils"
"github.com/OpenListTeam/OpenList/server/common"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xhofe/tache" "github.com/xhofe/tache"
) )
@ -40,7 +40,9 @@ func (t *CopyTask) GetStatus() string {
} }
func (t *CopyTask) Run() error { func (t *CopyTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()
@ -107,6 +109,7 @@ func _copy(ctx context.Context, srcObjPath, dstDirPath string, lazyCache ...bool
t := &CopyTask{ t := &CopyTask{
TaskExtension: task.TaskExtension{ TaskExtension: task.TaskExtension{
Creator: taskCreator, Creator: taskCreator,
ApiUrl: common.GetApiUrl(ctx),
}, },
srcStorage: srcStorage, srcStorage: srcStorage,
dstStorage: dstStorage, dstStorage: dstStorage,
@ -140,6 +143,7 @@ func copyBetween2Storages(t *CopyTask, srcStorage, dstStorage driver.Driver, src
CopyTaskManager.Add(&CopyTask{ CopyTaskManager.Add(&CopyTask{
TaskExtension: task.TaskExtension{ TaskExtension: task.TaskExtension{
Creator: t.GetCreator(), Creator: t.GetCreator(),
ApiUrl: t.ApiUrl,
}, },
srcStorage: srcStorage, srcStorage: srcStorage,
dstStorage: dstStorage, dstStorage: dstStorage,

View File

@ -7,7 +7,6 @@ import (
"github.com/OpenListTeam/OpenList/internal/model" "github.com/OpenListTeam/OpenList/internal/model"
"github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/internal/op"
"github.com/OpenListTeam/OpenList/server/common" "github.com/OpenListTeam/OpenList/server/common"
"github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -21,9 +20,7 @@ func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, m
return nil, nil, errors.WithMessage(err, "failed link") return nil, nil, errors.WithMessage(err, "failed link")
} }
if l.URL != "" && !strings.HasPrefix(l.URL, "http://") && !strings.HasPrefix(l.URL, "https://") { if l.URL != "" && !strings.HasPrefix(l.URL, "http://") && !strings.HasPrefix(l.URL, "https://") {
if c, ok := ctx.(*gin.Context); ok { l.URL = common.GetApiUrl(ctx) + l.URL
l.URL = common.GetApiUrl(c.Request) + l.URL
}
} }
return l, obj, nil return l, obj, nil
} }

View File

@ -15,26 +15,27 @@ import (
"github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/internal/stream"
"github.com/OpenListTeam/OpenList/internal/task" "github.com/OpenListTeam/OpenList/internal/task"
"github.com/OpenListTeam/OpenList/pkg/utils" "github.com/OpenListTeam/OpenList/pkg/utils"
"github.com/OpenListTeam/OpenList/server/common"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xhofe/tache" "github.com/xhofe/tache"
) )
type MoveTask struct { type MoveTask struct {
task.TaskExtension task.TaskExtension
Status string `json:"-"` Status string `json:"-"`
SrcObjPath string `json:"src_path"` SrcObjPath string `json:"src_path"`
DstDirPath string `json:"dst_path"` DstDirPath string `json:"dst_path"`
srcStorage driver.Driver `json:"-"` srcStorage driver.Driver `json:"-"`
dstStorage driver.Driver `json:"-"` dstStorage driver.Driver `json:"-"`
SrcStorageMp string `json:"src_storage_mp"` SrcStorageMp string `json:"src_storage_mp"`
DstStorageMp string `json:"dst_storage_mp"` DstStorageMp string `json:"dst_storage_mp"`
IsRootTask bool `json:"is_root_task"` IsRootTask bool `json:"is_root_task"`
RootTaskID string `json:"root_task_id"` RootTaskID string `json:"root_task_id"`
TotalFiles int `json:"total_files"` TotalFiles int `json:"total_files"`
CompletedFiles int `json:"completed_files"` CompletedFiles int `json:"completed_files"`
Phase string `json:"phase"` // "copying", "verifying", "deleting", "completed" Phase string `json:"phase"` // "copying", "verifying", "deleting", "completed"
ValidateExistence bool `json:"validate_existence"` ValidateExistence bool `json:"validate_existence"`
mu sync.RWMutex `json:"-"` mu sync.RWMutex `json:"-"`
} }
type MoveProgress struct { type MoveProgress struct {
@ -62,11 +63,11 @@ func (t *MoveTask) GetStatus() string {
func (t *MoveTask) GetProgress() float64 { func (t *MoveTask) GetProgress() float64 {
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() defer t.mu.RUnlock()
if t.TotalFiles == 0 { if t.TotalFiles == 0 {
return 0 return 0
} }
switch t.Phase { switch t.Phase {
case "copying": case "copying":
return float64(t.CompletedFiles*60) / float64(t.TotalFiles) return float64(t.CompletedFiles*60) / float64(t.TotalFiles)
@ -84,9 +85,9 @@ func (t *MoveTask) GetProgress() float64 {
func (t *MoveTask) GetMoveProgress() *MoveProgress { func (t *MoveTask) GetMoveProgress() *MoveProgress {
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() defer t.mu.RUnlock()
progress := int(t.GetProgress()) progress := int(t.GetProgress())
return &MoveProgress{ return &MoveProgress{
TaskID: t.GetID(), TaskID: t.GetID(),
Phase: t.Phase, Phase: t.Phase,
@ -106,16 +107,18 @@ func (t *MoveTask) updateProgress() {
} }
func (t *MoveTask) Run() error { func (t *MoveTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { defer func() {
t.SetEndTime(time.Now()) t.SetEndTime(time.Now())
if t.IsRootTask { if t.IsRootTask {
moveProgressMap.Delete(t.GetID()) moveProgressMap.Delete(t.GetID())
} }
}() }()
var err error var err error
if t.srcStorage == nil { if t.srcStorage == nil {
t.srcStorage, err = op.GetStorageByMountPath(t.SrcStorageMp) t.srcStorage, err = op.GetStorageByMountPath(t.SrcStorageMp)
@ -131,13 +134,13 @@ func (t *MoveTask) Run() error {
t.mu.Lock() t.mu.Lock()
t.Status = "validating source and destination" t.Status = "validating source and destination"
t.mu.Unlock() t.mu.Unlock()
// Check if source exists // Check if source exists
srcObj, err := op.Get(t.Ctx(), t.srcStorage, t.SrcObjPath) srcObj, err := op.Get(t.Ctx(), t.srcStorage, t.SrcObjPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "source file [%s] not found", stdpath.Base(t.SrcObjPath)) return errors.WithMessagef(err, "source file [%s] not found", stdpath.Base(t.SrcObjPath))
} }
// Check if destination already exists (if validation is required) // Check if destination already exists (if validation is required)
if t.ValidateExistence { if t.ValidateExistence {
dstFilePath := stdpath.Join(t.DstDirPath, srcObj.GetName()) dstFilePath := stdpath.Join(t.DstDirPath, srcObj.GetName())
@ -155,7 +158,7 @@ func (t *MoveTask) Run() error {
t.mu.Unlock() t.mu.Unlock()
return t.runRootMoveTask() return t.runRootMoveTask()
} }
// Use safe move logic for files // Use safe move logic for files
return t.safeMoveOperation(srcObj) return t.safeMoveOperation(srcObj)
} }
@ -167,7 +170,7 @@ func (t *MoveTask) runRootMoveTask() error {
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] object", t.SrcObjPath) return errors.WithMessagef(err, "failed get src [%s] object", t.SrcObjPath)
} }
if !srcObj.IsDir() { if !srcObj.IsDir() {
// Source is not a directory, use regular move logic // Source is not a directory, use regular move logic
t.mu.Lock() t.mu.Lock()
@ -175,32 +178,32 @@ func (t *MoveTask) runRootMoveTask() error {
t.mu.Unlock() t.mu.Unlock()
return t.safeMoveOperation(srcObj) return t.safeMoveOperation(srcObj)
} }
// Phase 1: Count total files and create directory structure // Phase 1: Count total files and create directory structure
t.mu.Lock() t.mu.Lock()
t.Phase = "preparing" t.Phase = "preparing"
t.Status = "counting files and preparing directory structure" t.Status = "counting files and preparing directory structure"
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
totalFiles, err := t.countFilesAndCreateDirs(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) totalFiles, err := t.countFilesAndCreateDirs(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to prepare directory structure") return errors.WithMessage(err, "failed to prepare directory structure")
} }
t.mu.Lock() t.mu.Lock()
t.TotalFiles = totalFiles t.TotalFiles = totalFiles
t.Phase = "copying" t.Phase = "copying"
t.Status = "copying files" t.Status = "copying files"
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
// Phase 2: Copy all files // Phase 2: Copy all files
err = t.copyAllFiles(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) err = t.copyAllFiles(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to copy files") return errors.WithMessage(err, "failed to copy files")
} }
// Phase 3: Verify directory structure // Phase 3: Verify directory structure
t.mu.Lock() t.mu.Lock()
t.Phase = "verifying" t.Phase = "verifying"
@ -208,12 +211,12 @@ func (t *MoveTask) runRootMoveTask() error {
t.CompletedFiles = 0 t.CompletedFiles = 0
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
err = t.verifyDirectoryStructure(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath) err = t.verifyDirectoryStructure(t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
if err != nil { if err != nil {
return errors.WithMessage(err, "verification failed") return errors.WithMessage(err, "verification failed")
} }
// Phase 4: Delete source files and directories // Phase 4: Delete source files and directories
t.mu.Lock() t.mu.Lock()
t.Phase = "deleting" t.Phase = "deleting"
@ -221,18 +224,18 @@ func (t *MoveTask) runRootMoveTask() error {
t.CompletedFiles = 0 t.CompletedFiles = 0
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
err = t.deleteSourceRecursively(t.srcStorage, t.SrcObjPath) err = t.deleteSourceRecursively(t.srcStorage, t.SrcObjPath)
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to delete source files") return errors.WithMessage(err, "failed to delete source files")
} }
t.mu.Lock() t.mu.Lock()
t.Phase = "completed" t.Phase = "completed"
t.Status = "completed" t.Status = "completed"
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
return nil return nil
} }
@ -257,11 +260,11 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver,
if err != nil { if err != nil {
return 0, errors.WithMessagef(err, "failed get src [%s] object", srcPath) return 0, errors.WithMessagef(err, "failed get src [%s] object", srcPath)
} }
if !srcObj.IsDir() { if !srcObj.IsDir() {
return 1, nil return 1, nil
} }
// Create destination directory // Create destination directory
dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) dstObjPath := stdpath.Join(dstPath, srcObj.GetName())
err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath) err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath)
@ -271,13 +274,13 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver,
} }
return 0, errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath) return 0, errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath)
} }
// List and count files recursively // List and count files recursively
objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{})
if err != nil { if err != nil {
return 0, errors.WithMessagef(err, "failed list src [%s] objs", srcPath) return 0, errors.WithMessagef(err, "failed list src [%s] objs", srcPath)
} }
totalFiles := 0 totalFiles := 0
for _, obj := range objs { for _, obj := range objs {
if utils.IsCanceled(t.Ctx()) { if utils.IsCanceled(t.Ctx()) {
@ -290,7 +293,7 @@ func (t *MoveTask) countFilesAndCreateDirs(srcStorage, dstStorage driver.Driver,
} }
totalFiles += subCount totalFiles += subCount
} }
return totalFiles, nil return totalFiles, nil
} }
@ -300,27 +303,27 @@ func (t *MoveTask) copyAllFiles(srcStorage, dstStorage driver.Driver, srcPath, d
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] object", srcPath) return errors.WithMessagef(err, "failed get src [%s] object", srcPath)
} }
if !srcObj.IsDir() { if !srcObj.IsDir() {
// Copy single file // Copy single file
err := t.copyFile(srcStorage, dstStorage, srcPath, dstPath) err := t.copyFile(srcStorage, dstStorage, srcPath, dstPath)
if err != nil { if err != nil {
return err return err
} }
t.mu.Lock() t.mu.Lock()
t.CompletedFiles++ t.CompletedFiles++
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
return nil return nil
} }
// Copy directory contents // Copy directory contents
objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed list src [%s] objs", srcPath) return errors.WithMessagef(err, "failed list src [%s] objs", srcPath)
} }
dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) dstObjPath := stdpath.Join(dstPath, srcObj.GetName())
for _, obj := range objs { for _, obj := range objs {
if utils.IsCanceled(t.Ctx()) { if utils.IsCanceled(t.Ctx()) {
@ -332,7 +335,7 @@ func (t *MoveTask) copyAllFiles(srcStorage, dstStorage driver.Driver, srcPath, d
return err return err
} }
} }
return nil return nil
} }
@ -342,24 +345,24 @@ func (t *MoveTask) copyFile(srcStorage, dstStorage driver.Driver, srcFilePath, d
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath)
} }
link, _, err := op.Link(t.Ctx(), srcStorage, srcFilePath, model.LinkArgs{ link, _, err := op.Link(t.Ctx(), srcStorage, srcFilePath, model.LinkArgs{
Header: http.Header{}, Header: http.Header{},
}) })
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
} }
fs := stream.FileStream{ fs := stream.FileStream{
Obj: srcFile, Obj: srcFile,
Ctx: t.Ctx(), Ctx: t.Ctx(),
} }
ss, err := stream.NewSeekableStream(fs, link) ss, err := stream.NewSeekableStream(fs, link)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath) return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath)
} }
return op.Put(t.Ctx(), dstStorage, dstDirPath, ss, nil, true) return op.Put(t.Ctx(), dstStorage, dstDirPath, ss, nil, true)
} }
@ -369,7 +372,7 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] object", srcPath) return errors.WithMessagef(err, "failed get src [%s] object", srcPath)
} }
if !srcObj.IsDir() { if !srcObj.IsDir() {
// Verify single file // Verify single file
dstFilePath := stdpath.Join(dstPath, srcObj.GetName()) dstFilePath := stdpath.Join(dstPath, srcObj.GetName())
@ -377,27 +380,27 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver
if err != nil { if err != nil {
return errors.WithMessagef(err, "verification failed: destination file [%s] not found", dstFilePath) return errors.WithMessagef(err, "verification failed: destination file [%s] not found", dstFilePath)
} }
t.mu.Lock() t.mu.Lock()
t.CompletedFiles++ t.CompletedFiles++
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
return nil return nil
} }
// Verify directory // Verify directory
dstObjPath := stdpath.Join(dstPath, srcObj.GetName()) dstObjPath := stdpath.Join(dstPath, srcObj.GetName())
_, err = op.Get(t.Ctx(), dstStorage, dstObjPath) _, err = op.Get(t.Ctx(), dstStorage, dstObjPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "verification failed: destination directory [%s] not found", dstObjPath) return errors.WithMessagef(err, "verification failed: destination directory [%s] not found", dstObjPath)
} }
// Verify directory contents // Verify directory contents
srcObjs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) srcObjs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed list src [%s] objs for verification", srcPath) return errors.WithMessagef(err, "failed list src [%s] objs for verification", srcPath)
} }
for _, obj := range srcObjs { for _, obj := range srcObjs {
if utils.IsCanceled(t.Ctx()) { if utils.IsCanceled(t.Ctx()) {
return nil return nil
@ -408,7 +411,7 @@ func (t *MoveTask) verifyDirectoryStructure(srcStorage, dstStorage driver.Driver
return err return err
} }
} }
return nil return nil
} }
@ -418,27 +421,27 @@ func (t *MoveTask) deleteSourceRecursively(srcStorage driver.Driver, srcPath str
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] object for deletion", srcPath) return errors.WithMessagef(err, "failed get src [%s] object for deletion", srcPath)
} }
if !srcObj.IsDir() { if !srcObj.IsDir() {
// Delete single file // Delete single file
err := op.Remove(t.Ctx(), srcStorage, srcPath) err := op.Remove(t.Ctx(), srcStorage, srcPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to delete src [%s] file", srcPath) return errors.WithMessagef(err, "failed to delete src [%s] file", srcPath)
} }
t.mu.Lock() t.mu.Lock()
t.CompletedFiles++ t.CompletedFiles++
t.mu.Unlock() t.mu.Unlock()
t.updateProgress() t.updateProgress()
return nil return nil
} }
// Delete directory contents first // Delete directory contents first
objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{}) objs, err := op.List(t.Ctx(), srcStorage, srcPath, model.ListArgs{})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed list src [%s] objs for deletion", srcPath) return errors.WithMessagef(err, "failed list src [%s] objs for deletion", srcPath)
} }
for _, obj := range objs { for _, obj := range objs {
if utils.IsCanceled(t.Ctx()) { if utils.IsCanceled(t.Ctx()) {
return nil return nil
@ -449,13 +452,13 @@ func (t *MoveTask) deleteSourceRecursively(srcStorage driver.Driver, srcPath str
return err return err
} }
} }
// Delete the directory itself // Delete the directory itself
err = op.Remove(t.Ctx(), srcStorage, srcPath) err = op.Remove(t.Ctx(), srcStorage, srcPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to delete src [%s] directory", srcPath) return errors.WithMessagef(err, "failed to delete src [%s] directory", srcPath)
} }
return nil return nil
} }
@ -465,14 +468,14 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcObjPath) return errors.WithMessagef(err, "failed get src [%s] file", srcObjPath)
} }
if srcObj.IsDir() { if srcObj.IsDir() {
t.Status = "src object is dir, listing objs" t.Status = "src object is dir, listing objs"
objs, err := op.List(t.Ctx(), srcStorage, srcObjPath, model.ListArgs{}) objs, err := op.List(t.Ctx(), srcStorage, srcObjPath, model.ListArgs{})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed list src [%s] objs", srcObjPath) return errors.WithMessagef(err, "failed list src [%s] objs", srcObjPath)
} }
dstObjPath := stdpath.Join(dstDirPath, srcObj.GetName()) dstObjPath := stdpath.Join(dstDirPath, srcObj.GetName())
t.Status = "creating destination directory" t.Status = "creating destination directory"
err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath) err = op.MakeDir(t.Ctx(), dstStorage, dstObjPath)
@ -483,7 +486,7 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src
} }
return errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath) return errors.WithMessagef(err, "failed to create destination directory [%s] in storage [%s]", dstObjPath, dstStorage.GetStorage().MountPath)
} }
for _, obj := range objs { for _, obj := range objs {
if utils.IsCanceled(t.Ctx()) { if utils.IsCanceled(t.Ctx()) {
return nil return nil
@ -492,6 +495,7 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src
MoveTaskManager.Add(&MoveTask{ MoveTaskManager.Add(&MoveTask{
TaskExtension: task.TaskExtension{ TaskExtension: task.TaskExtension{
Creator: t.GetCreator(), Creator: t.GetCreator(),
ApiUrl: t.ApiUrl,
}, },
srcStorage: srcStorage, srcStorage: srcStorage,
dstStorage: dstStorage, dstStorage: dstStorage,
@ -515,13 +519,13 @@ func moveBetween2Storages(t *MoveTask, srcStorage, dstStorage driver.Driver, src
} }
} }
func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Driver, srcFilePath, dstDirPath string) error { func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Driver, srcFilePath, dstDirPath string) error {
tsk.Status = "copying file to destination" tsk.Status = "copying file to destination"
copyTask := &CopyTask{ copyTask := &CopyTask{
TaskExtension: task.TaskExtension{ TaskExtension: task.TaskExtension{
Creator: tsk.GetCreator(), Creator: tsk.GetCreator(),
ApiUrl: tsk.ApiUrl,
}, },
srcStorage: srcStorage, srcStorage: srcStorage,
dstStorage: dstStorage, dstStorage: dstStorage,
@ -530,10 +534,8 @@ func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Drive
SrcStorageMp: srcStorage.GetStorage().MountPath, SrcStorageMp: srcStorage.GetStorage().MountPath,
DstStorageMp: dstStorage.GetStorage().MountPath, DstStorageMp: dstStorage.GetStorage().MountPath,
} }
copyTask.SetCtx(tsk.Ctx()) copyTask.SetCtx(tsk.Ctx())
err := copyBetween2Storages(copyTask, srcStorage, dstStorage, srcFilePath, dstDirPath) err := copyBetween2Storages(copyTask, srcStorage, dstStorage, srcFilePath, dstDirPath)
if err != nil { if err != nil {
@ -543,21 +545,20 @@ func moveFileBetween2Storages(tsk *MoveTask, srcStorage, dstStorage driver.Drive
} }
return errors.WithMessagef(err, "failed to copy [%s] to destination storage [%s]", srcFilePath, dstStorage.GetStorage().MountPath) return errors.WithMessagef(err, "failed to copy [%s] to destination storage [%s]", srcFilePath, dstStorage.GetStorage().MountPath)
} }
tsk.SetProgress(50) tsk.SetProgress(50)
tsk.Status = "deleting source file" tsk.Status = "deleting source file"
err = op.Remove(tsk.Ctx(), srcStorage, srcFilePath) err = op.Remove(tsk.Ctx(), srcStorage, srcFilePath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to delete src [%s] file from storage [%s] after successful copy", srcFilePath, srcStorage.GetStorage().MountPath) return errors.WithMessagef(err, "failed to delete src [%s] file from storage [%s] after successful copy", srcFilePath, srcStorage.GetStorage().MountPath)
} }
tsk.SetProgress(100) tsk.SetProgress(100)
tsk.Status = "completed" tsk.Status = "completed"
return nil return nil
} }
// safeMoveOperation ensures copy-then-delete sequence for safe move operations // safeMoveOperation ensures copy-then-delete sequence for safe move operations
func (t *MoveTask) safeMoveOperation(srcObj model.Obj) error { func (t *MoveTask) safeMoveOperation(srcObj model.Obj) error {
if srcObj.IsDir() { if srcObj.IsDir() {
@ -592,12 +593,13 @@ func _moveWithValidation(ctx context.Context, srcObjPath, dstDirPath string, val
} }
taskCreator, _ := ctx.Value("user").(*model.User) taskCreator, _ := ctx.Value("user").(*model.User)
// Create task immediately without any synchronous checks to avoid blocking frontend // Create task immediately without any synchronous checks to avoid blocking frontend
// All validation and type checking will be done asynchronously in the Run method // All validation and type checking will be done asynchronously in the Run method
t := &MoveTask{ t := &MoveTask{
TaskExtension: task.TaskExtension{ TaskExtension: task.TaskExtension{
Creator: taskCreator, Creator: taskCreator,
ApiUrl: common.GetApiUrl(ctx),
}, },
srcStorage: srcStorage, srcStorage: srcStorage,
dstStorage: dstStorage, dstStorage: dstStorage,
@ -608,7 +610,7 @@ func _moveWithValidation(ctx context.Context, srcObjPath, dstDirPath string, val
ValidateExistence: validateExistence, ValidateExistence: validateExistence,
Phase: "initializing", Phase: "initializing",
} }
MoveTaskManager.Add(t) MoveTaskManager.Add(t)
return t, nil return t, nil
} }

View File

@ -20,7 +20,6 @@ type LinkArgs struct {
IP string IP string
Header http.Header Header http.Header
Type string Type string
HttpReq *http.Request
Redirect bool Redirect bool
} }

View File

@ -171,7 +171,7 @@ func (d *downloader) download() (io.ReadCloser, error) {
log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency) log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency)
if d.cfg.Concurrency == 1 { if maxPart == 1 {
if d.cfg.ConcurrencyLimit != nil { if d.cfg.ConcurrencyLimit != nil {
go func() { go func() {
<-d.ctx.Done() <-d.ctx.Done()

View File

@ -28,7 +28,9 @@ type DownloadTask struct {
} }
func (t *DownloadTask) Run() error { func (t *DownloadTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()

View File

@ -33,7 +33,9 @@ type TransferTask struct {
} }
func (t *TransferTask) Run() error { func (t *TransferTask) Run() error {
t.ReinitCtx() if err := t.ReinitCtx(); err != nil {
return err
}
t.ClearEndTime() t.ClearEndTime()
t.SetStartTime(time.Now()) t.SetStartTime(time.Now())
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()

View File

@ -19,7 +19,7 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl
return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link") return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link")
} }
rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) { rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) {
if link.Concurrency != 0 || link.PartSize != 0 { if link.Concurrency > 0 || link.PartSize > 0 {
header := net.ProcessHeader(nil, link.Header) header := net.ProcessHeader(nil, link.Header)
down := net.NewDownloader(func(d *net.Downloader) { down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency d.Concurrency = link.Concurrency

View File

@ -2,7 +2,6 @@ package task
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/conf"
@ -12,12 +11,21 @@ import (
type TaskExtension struct { type TaskExtension struct {
tache.Base tache.Base
ctx context.Context Creator *model.User
ctxInitMutex sync.Mutex startTime *time.Time
Creator *model.User endTime *time.Time
startTime *time.Time totalBytes int64
endTime *time.Time ApiUrl string
totalBytes int64 }
func (t *TaskExtension) SetCtx(ctx context.Context) {
if t.Creator != nil {
ctx = context.WithValue(ctx, "user", t.Creator)
}
if len(t.ApiUrl) > 0 {
ctx = context.WithValue(ctx, conf.ApiUrlKey, t.ApiUrl)
}
t.Base.SetCtx(ctx)
} }
func (t *TaskExtension) SetCreator(creator *model.User) { func (t *TaskExtension) SetCreator(creator *model.User) {
@ -57,29 +65,18 @@ func (t *TaskExtension) GetTotalBytes() int64 {
return t.totalBytes return t.totalBytes
} }
func (t *TaskExtension) Ctx() context.Context { func (t *TaskExtension) ReinitCtx() error {
if t.ctx == nil {
t.ctxInitMutex.Lock()
if t.ctx == nil {
t.ctx = context.WithValue(t.Base.Ctx(), "user", t.Creator)
}
t.ctxInitMutex.Unlock()
}
return t.ctx
}
func (t *TaskExtension) ReinitCtx() {
if !conf.Conf.Tasks.AllowRetryCanceled {
return
}
select { select {
case <-t.Base.Ctx().Done(): case <-t.Ctx().Done():
if !conf.Conf.Tasks.AllowRetryCanceled {
return t.Ctx().Err()
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
t.SetCtx(ctx) t.SetCtx(ctx)
t.SetCancelFunc(cancel) t.SetCancelFunc(cancel)
t.ctx = nil
default: default:
} }
return nil
} }
type TaskExtensionInfo interface { type TaskExtensionInfo interface {

View File

@ -1,6 +1,8 @@
package task package task
import "github.com/xhofe/tache" import (
"github.com/xhofe/tache"
)
type Manager[T tache.Task] interface { type Manager[T tache.Task] interface {
Add(task T) Add(task T)

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
stdpath "path" stdpath "path"
@ -9,7 +10,7 @@ import (
"github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/conf"
) )
func GetApiUrl(r *http.Request) string { func GetApiUrlFormRequest(r *http.Request) string {
api := conf.Conf.SiteURL api := conf.Conf.SiteURL
if strings.HasPrefix(api, "http") { if strings.HasPrefix(api, "http") {
return strings.TrimSuffix(api, "/") return strings.TrimSuffix(api, "/")
@ -28,3 +29,11 @@ func GetApiUrl(r *http.Request) string {
api = strings.TrimSuffix(api, "/") api = strings.TrimSuffix(api, "/")
return api return api
} }
func GetApiUrl(ctx context.Context) string {
val := ctx.Value(conf.ApiUrlKey)
if api, ok := val.(string); ok {
return api
}
return ""
}

View File

@ -1,8 +1,6 @@
package common package common
import ( import (
"context"
"net/http"
"strings" "strings"
"github.com/OpenListTeam/OpenList/cmd/flags" "github.com/OpenListTeam/OpenList/cmd/flags"
@ -90,10 +88,3 @@ func Pluralize(count int, singular, plural string) string {
} }
return plural return plural
} }
func GetHttpReq(ctx context.Context) *http.Request {
if c, ok := ctx.(*gin.Context); ok {
return c.Request
}
return nil
}

View File

@ -15,7 +15,6 @@ import (
"github.com/OpenListTeam/OpenList/internal/stream" "github.com/OpenListTeam/OpenList/internal/stream"
"github.com/OpenListTeam/OpenList/pkg/http_range" "github.com/OpenListTeam/OpenList/pkg/http_range"
"github.com/OpenListTeam/OpenList/pkg/utils" "github.com/OpenListTeam/OpenList/pkg/utils"
log "github.com/sirupsen/logrus"
) )
func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error {
@ -42,7 +41,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
RangeReadCloserIF: link.RangeReadCloser, RangeReadCloserIF: link.RangeReadCloser,
Limiter: stream.ServerDownloadLimit, Limiter: stream.ServerDownloadLimit,
}) })
} else if link.Concurrency != 0 || link.PartSize != 0 { } else if link.Concurrency > 0 || link.PartSize > 0 {
attachHeader(w, file) attachHeader(w, file)
size := file.GetSize() size := file.GetSize()
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
@ -110,21 +109,16 @@ func GetEtag(file model.Obj) string {
return fmt.Sprintf(`"%x-%x"`, file.ModTime().Unix(), file.GetSize()) return fmt.Sprintf(`"%x-%x"`, file.ModTime().Unix(), file.GetSize())
} }
var NoProxyRange = &model.RangeReadCloser{} func ProxyRange(ctx context.Context, link *model.Link, size int64) {
func ProxyRange(link *model.Link, size int64) {
if link.MFile != nil { if link.MFile != nil {
return return
} }
if link.RangeReadCloser == nil { if link.RangeReadCloser == nil && !strings.HasPrefix(link.URL, GetApiUrl(ctx)+"/") {
var rrc, err = stream.GetRangeReadCloserFromLink(size, link) var rrc, err = stream.GetRangeReadCloserFromLink(size, link)
if err != nil { if err != nil {
log.Warnf("ProxyRange error: %s", err)
return return
} }
link.RangeReadCloser = rrc link.RangeReadCloser = rrc
} else if link.RangeReadCloser == NoProxyRange {
link.RangeReadCloser = nil
} }
} }

View File

@ -101,9 +101,8 @@ func FsArchiveMeta(c *gin.Context) {
} }
archiveArgs := model.ArchiveArgs{ archiveArgs := model.ArchiveArgs{
LinkArgs: model.LinkArgs{ LinkArgs: model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}, },
Password: req.ArchivePass, Password: req.ArchivePass,
} }
@ -132,7 +131,7 @@ func FsArchiveMeta(c *gin.Context) {
IsEncrypted: ret.IsEncrypted(), IsEncrypted: ret.IsEncrypted(),
Content: toContentResp(ret.GetTree()), Content: toContentResp(ret.GetTree()),
Sort: ret.Sort, Sort: ret.Sort,
RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c.Request), api, utils.EncodePath(reqPath, true)), RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c), api, utils.EncodePath(reqPath, true)),
Sign: s, Sign: s,
}) })
} }
@ -181,9 +180,8 @@ func FsArchiveList(c *gin.Context) {
ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveInnerArgs: model.ArchiveInnerArgs{
ArchiveArgs: model.ArchiveArgs{ ArchiveArgs: model.ArchiveArgs{
LinkArgs: model.LinkArgs{ LinkArgs: model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}, },
Password: req.ArchivePass, Password: req.ArchivePass,
}, },
@ -266,9 +264,8 @@ func FsArchiveDecompress(c *gin.Context) {
ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveInnerArgs: model.ArchiveInnerArgs{
ArchiveArgs: model.ArchiveArgs{ ArchiveArgs: model.ArchiveArgs{
LinkArgs: model.LinkArgs{ LinkArgs: model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}, },
Password: req.ArchivePass, Password: req.ArchivePass,
}, },
@ -314,7 +311,6 @@ func ArchiveDown(c *gin.Context) {
IP: c.ClientIP(), IP: c.ClientIP(),
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
Redirect: true, Redirect: true,
}, },
Password: password, Password: password,
@ -344,9 +340,8 @@ func ArchiveProxy(c *gin.Context) {
link, file, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{ link, file, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{
ArchiveArgs: model.ArchiveArgs{ ArchiveArgs: model.ArchiveArgs{
LinkArgs: model.LinkArgs{ LinkArgs: model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}, },
Password: password, Password: password,
}, },
@ -370,9 +365,8 @@ func ArchiveInternalExtract(c *gin.Context) {
rc, size, err := fs.ArchiveInternalExtract(c, archiveRawPath, model.ArchiveInnerArgs{ rc, size, err := fs.ArchiveInternalExtract(c, archiveRawPath, model.ArchiveInnerArgs{
ArchiveArgs: model.ArchiveArgs{ ArchiveArgs: model.ArchiveArgs{
LinkArgs: model.LinkArgs{ LinkArgs: model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}, },
Password: password, Password: password,
}, },

View File

@ -38,7 +38,6 @@ func Down(c *gin.Context) {
IP: c.ClientIP(), IP: c.ClientIP(),
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
Redirect: true, Redirect: true,
}) })
if err != nil { if err != nil {
@ -71,9 +70,8 @@ func Proxy(c *gin.Context) {
} }
} }
link, file, err := fs.Link(c, rawPath, model.LinkArgs{ link, file, err := fs.Link(c, rawPath, model.LinkArgs{
Header: c.Request.Header, Header: c.Request.Header,
Type: c.Query("type"), Type: c.Query("type"),
HttpReq: c.Request,
}) })
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
@ -126,7 +124,7 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo
} }
} }
if proxyRange { if proxyRange {
common.ProxyRange(link, file.GetSize()) common.ProxyRange(c, link, file.GetSize())
} }
Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer} Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer}

View File

@ -97,7 +97,7 @@ func FsMove(c *gin.Context) {
} }
} }
} }
// Create all tasks immediately without any synchronous validation // Create all tasks immediately without any synchronous validation
// All validation will be done asynchronously in the background // All validation will be done asynchronously in the background
var addedTasks []task.TaskExtensionInfo var addedTasks []task.TaskExtensionInfo
@ -111,12 +111,12 @@ func FsMove(c *gin.Context) {
return return
} }
} }
// Return immediately with task information // Return immediately with task information
if len(addedTasks) > 0 { if len(addedTasks) > 0 {
common.SuccessResp(c, gin.H{ common.SuccessResp(c, gin.H{
"message": fmt.Sprintf("Successfully created %d move task(s)", len(addedTasks)), "message": fmt.Sprintf("Successfully created %d move task(s)", len(addedTasks)),
"tasks": getTaskInfos(addedTasks), "tasks": getTaskInfos(addedTasks),
}) })
} else { } else {
common.SuccessResp(c, gin.H{ common.SuccessResp(c, gin.H{
@ -159,7 +159,7 @@ func FsCopy(c *gin.Context) {
} }
} }
} }
// Create all tasks immediately without any synchronous validation // Create all tasks immediately without any synchronous validation
// All validation will be done asynchronously in the background // All validation will be done asynchronously in the background
var addedTasks []task.TaskExtensionInfo var addedTasks []task.TaskExtensionInfo
@ -173,12 +173,12 @@ func FsCopy(c *gin.Context) {
return return
} }
} }
// Return immediately with task information // Return immediately with task information
if len(addedTasks) > 0 { if len(addedTasks) > 0 {
common.SuccessResp(c, gin.H{ common.SuccessResp(c, gin.H{
"message": fmt.Sprintf("Successfully created %d copy task(s)", len(addedTasks)), "message": fmt.Sprintf("Successfully created %d copy task(s)", len(addedTasks)),
"tasks": getTaskInfos(addedTasks), "tasks": getTaskInfos(addedTasks),
}) })
} else { } else {
common.SuccessResp(c, gin.H{ common.SuccessResp(c, gin.H{
@ -379,13 +379,13 @@ func Link(c *gin.Context) {
if storage.Config().OnlyLocal { if storage.Config().OnlyLocal {
common.SuccessResp(c, model.Link{ common.SuccessResp(c, model.Link{
URL: fmt.Sprintf("%s/p%s?d&sign=%s", URL: fmt.Sprintf("%s/p%s?d&sign=%s",
common.GetApiUrl(c.Request), common.GetApiUrl(c),
utils.EncodePath(rawPath, true), utils.EncodePath(rawPath, true),
sign.Sign(rawPath)), sign.Sign(rawPath)),
}) })
return return
} }
link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, HttpReq: c.Request}) link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header})
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return

View File

@ -296,7 +296,7 @@ func FsGet(c *gin.Context) {
sign.Sign(reqPath)) sign.Sign(reqPath))
} else { } else {
rawURL = fmt.Sprintf("%s/p%s%s", rawURL = fmt.Sprintf("%s/p%s%s",
common.GetApiUrl(c.Request), common.GetApiUrl(c),
utils.EncodePath(reqPath, true), utils.EncodePath(reqPath, true),
query) query)
} }
@ -309,7 +309,6 @@ func FsGet(c *gin.Context) {
link, _, err := fs.Link(c, reqPath, model.LinkArgs{ link, _, err := fs.Link(c, reqPath, model.LinkArgs{
IP: c.ClientIP(), IP: c.ClientIP(),
Header: c.Request.Header, Header: c.Request.Header,
HttpReq: c.Request,
Redirect: true, Redirect: true,
}) })
if err != nil { if err != nil {

View File

@ -48,9 +48,9 @@ func verifyState(clientID, ip, state string) bool {
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string { func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
if useCompatibility { if useCompatibility {
return common.GetApiUrl(c.Request) + "/api/auth/" + method return common.GetApiUrl(c) + "/api/auth/" + method
} else { } else {
return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method return common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + method
} }
} }
@ -236,7 +236,7 @@ func OIDCLoginCallback(c *gin.Context) {
} }
if method == "get_sso_id" { if method == "get_sso_id" {
if useCompatibility { if useCompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID)
return return
} }
html := fmt.Sprintf(`<!DOCTYPE html> html := fmt.Sprintf(`<!DOCTYPE html>
@ -263,7 +263,7 @@ func OIDCLoginCallback(c *gin.Context) {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
} }
if useCompatibility { if useCompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token)
return return
} }
html := fmt.Sprintf(`<!DOCTYPE html> html := fmt.Sprintf(`<!DOCTYPE html>
@ -364,9 +364,9 @@ func SSOLoginCallback(c *gin.Context) {
} else { } else {
var redirect_uri string var redirect_uri string
if usecompatibility { if usecompatibility {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument redirect_uri = common.GetApiUrl(c) + "/api/auth/" + argument
} else { } else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument redirect_uri = common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + argument
} }
resp, err = ssoClient.R().SetHeader("Accept", "application/json"). resp, err = ssoClient.R().SetHeader("Accept", "application/json").
SetFormData(map[string]string{ SetFormData(map[string]string{
@ -401,7 +401,7 @@ func SSOLoginCallback(c *gin.Context) {
} }
if argument == "get_sso_id" { if argument == "get_sso_id" {
if usecompatibility { if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID)
return return
} }
html := fmt.Sprintf(`<!DOCTYPE html> html := fmt.Sprintf(`<!DOCTYPE html>
@ -429,7 +429,7 @@ func SSOLoginCallback(c *gin.Context) {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
} }
if usecompatibility { if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token)
return return
} }
html := fmt.Sprintf(`<!DOCTYPE html> html := fmt.Sprintf(`<!DOCTYPE html>

View File

@ -24,7 +24,7 @@ func BeginAuthnLogin(c *gin.Context) {
common.ErrorStrResp(c, "WebAuthn is not enabled", 403) common.ErrorStrResp(c, "WebAuthn is not enabled", 403)
return return
} }
authnInstance, err := authn.NewAuthnInstance(c.Request) authnInstance, err := authn.NewAuthnInstance(c)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
@ -65,7 +65,7 @@ func FinishAuthnLogin(c *gin.Context) {
common.ErrorStrResp(c, "WebAuthn is not enabled", 403) common.ErrorStrResp(c, "WebAuthn is not enabled", 403)
return return
} }
authnInstance, err := authn.NewAuthnInstance(c.Request) authnInstance, err := authn.NewAuthnInstance(c)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
@ -127,7 +127,7 @@ func BeginAuthnRegistration(c *gin.Context) {
} }
user := c.MustGet("user").(*model.User) user := c.MustGet("user").(*model.User)
authnInstance, err := authn.NewAuthnInstance(c.Request) authnInstance, err := authn.NewAuthnInstance(c)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
} }
@ -158,7 +158,7 @@ func FinishAuthnRegistration(c *gin.Context) {
user := c.MustGet("user").(*model.User) user := c.MustGet("user").(*model.User)
sessionDataString := c.GetHeader("Session") sessionDataString := c.GetHeader("Session")
authnInstance, err := authn.NewAuthnInstance(c.Request) authnInstance, err := authn.NewAuthnInstance(c)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return

View File

@ -10,9 +10,7 @@ import (
) )
func StoragesLoaded(c *gin.Context) { func StoragesLoaded(c *gin.Context) {
if conf.StoragesLoaded { if !conf.StoragesLoaded {
c.Next()
} else {
if utils.SliceContains([]string{"", "/", "/favicon.ico"}, c.Request.URL.Path) { if utils.SliceContains([]string{"", "/", "/favicon.ico"}, c.Request.URL.Path) {
c.Next() c.Next()
return return
@ -26,5 +24,8 @@ func StoragesLoaded(c *gin.Context) {
} }
common.ErrorStrResp(c, "Loading storage, please wait", 500) common.ErrorStrResp(c, "Loading storage, please wait", 500)
c.Abort() c.Abort()
return
} }
c.Set(conf.ApiUrlKey, common.GetApiUrlFormRequest(c.Request))
c.Next()
} }

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"crypto/subtle" "crypto/subtle"
"net/http" "net/http"
"path" "path"
@ -11,7 +10,6 @@ import (
"github.com/OpenListTeam/OpenList/server/middlewares" "github.com/OpenListTeam/OpenList/server/middlewares"
"github.com/OpenListTeam/OpenList/internal/conf" "github.com/OpenListTeam/OpenList/internal/conf"
"github.com/OpenListTeam/OpenList/internal/model"
"github.com/OpenListTeam/OpenList/internal/op" "github.com/OpenListTeam/OpenList/internal/op"
"github.com/OpenListTeam/OpenList/internal/setting" "github.com/OpenListTeam/OpenList/internal/setting"
"github.com/OpenListTeam/OpenList/server/webdav" "github.com/OpenListTeam/OpenList/server/webdav"
@ -45,9 +43,7 @@ func WebDav(dav *gin.RouterGroup) {
} }
func ServeWebDAV(c *gin.Context) { func ServeWebDAV(c *gin.Context) {
user := c.MustGet("user").(*model.User) handler.ServeHTTP(c.Writer, c.Request.WithContext(c))
ctx := context.WithValue(c.Request.Context(), "user", user)
handler.ServeHTTP(c.Writer, c.Request.WithContext(ctx))
} }
func WebDAVAuth(c *gin.Context) { func WebDAVAuth(c *gin.Context) {

View File

@ -241,12 +241,12 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
storage, _ := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) storage, _ := fs.GetStorage(reqPath, &fs.GetStoragesArgs{})
downProxyUrl := storage.GetStorage().DownProxyUrl downProxyUrl := storage.GetStorage().DownProxyUrl
if storage.GetStorage().WebdavNative() || (storage.GetStorage().WebdavProxy() && downProxyUrl == "") { if storage.GetStorage().WebdavNative() || (storage.GetStorage().WebdavProxy() && downProxyUrl == "") {
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header, HttpReq: r}) link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header})
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
if storage.GetStorage().ProxyRange { if storage.GetStorage().ProxyRange {
common.ProxyRange(link, fi.GetSize()) common.ProxyRange(ctx, link, fi.GetSize())
} }
err = common.Proxy(w, r, link, fi) err = common.Proxy(w, r, link, fi)
if err != nil { if err != nil {
@ -260,7 +260,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
http.Redirect(w, r, u, http.StatusFound) http.Redirect(w, r, u, http.StatusFound)
} else { } else {
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r, Redirect: true}) link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, Redirect: true})
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }