From 0c461991f939a2d2718aeb102cee20eada902344 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Mon, 14 Jul 2025 23:55:17 +0800 Subject: [PATCH] chore: standardize context keys with custom ContextKey type (#697) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: standardize context keys with custom ContextKey type * fix bug * 使用Request.Context --- drivers/github/util.go | 3 +- drivers/local/driver.go | 13 ------- internal/authn/authn.go | 2 +- internal/conf/const.go | 16 +++++++-- internal/fs/archive.go | 7 ++-- internal/fs/copy.go | 2 +- internal/fs/list.go | 5 +-- internal/fs/move.go | 3 +- internal/fs/put.go | 3 +- internal/fs/walk.go | 3 +- internal/net/util.go | 2 -- internal/offline_download/tool/add.go | 2 +- internal/offline_download/tool/download.go | 2 +- internal/offline_download/tool/transfer.go | 5 +-- internal/search/build.go | 2 +- internal/stream/util.go | 5 +-- internal/task/base.go | 2 +- server/common/base.go | 9 ++--- server/common/common.go | 18 ++++++++++ server/common/proxy.go | 3 +- server/debug.go | 11 +++--- server/ftp.go | 18 +++++----- server/ftp/afero.go | 3 +- server/ftp/fsmanage.go | 7 ++-- server/ftp/fsread.go | 27 +++++++------- server/ftp/fsup.go | 4 +-- server/handles/archive.go | 28 +++++++-------- server/handles/auth.go | 9 ++--- server/handles/down.go | 8 ++--- server/handles/fsbatch.go | 27 +++++++------- server/handles/fsmanage.go | 41 +++++++++++----------- server/handles/fsread.go | 28 +++++++-------- server/handles/fsup.go | 17 ++++----- server/handles/offline_download.go | 2 +- server/handles/search.go | 3 +- server/handles/sshkey.go | 7 ++-- server/handles/storage.go | 10 +++--- server/handles/task.go | 5 +-- server/handles/webauthn.go | 8 ++--- server/middlewares/auth.go | 16 ++++----- server/middlewares/check.go | 4 ++- server/middlewares/down.go | 4 +-- server/middlewares/fsup.go | 3 +- server/router.go | 1 + server/s3/backend.go | 11 +++--- server/s3/utils.go | 4 +-- server/sftp.go | 16 ++++----- server/webdav.go | 15 ++++---- server/webdav/file.go | 4 +-- server/webdav/prop.go | 3 +- server/webdav/webdav.go | 21 +++++------ 51 files changed, 253 insertions(+), 219 deletions(-) diff --git a/drivers/github/util.go b/drivers/github/util.go index 0a0d1fbd..2e19986b 100644 --- a/drivers/github/util.go +++ b/drivers/github/util.go @@ -9,6 +9,7 @@ import ( "text/template" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/ProtonMail/go-crypto/openpgp" @@ -96,7 +97,7 @@ func getPathCommonAncestor(a, b string) (ancestor, aChildName, bChildName, aRest } func getUsername(ctx context.Context) string { - user, ok := ctx.Value("user").(*model.User) + user, ok := ctx.Value(conf.UserKey).(*model.User) if !ok { return "" } diff --git a/drivers/local/driver.go b/drivers/local/driver.go index 575f3603..39c571dd 100644 --- a/drivers/local/driver.go +++ b/drivers/local/driver.go @@ -173,19 +173,6 @@ func (d *Local) FileInfoToObj(ctx context.Context, f fs.FileInfo, reqPath string } return &file } -func (d *Local) GetMeta(ctx context.Context, path string) (model.Obj, error) { - f, err := os.Stat(path) - if err != nil { - return nil, err - } - file := d.FileInfoToObj(ctx, f, path, path) - //h := "123123" - //if s, ok := f.(model.SetHash); ok && file.GetHash() == ("","") { - // s.SetHash(h,"SHA1") - //} - return file, nil - -} func (d *Local) Get(ctx context.Context, path string) (model.Obj, error) { path = filepath.Join(d.GetRootPath(), path) diff --git a/internal/authn/authn.go b/internal/authn/authn.go index 21757b23..a57823d1 100644 --- a/internal/authn/authn.go +++ b/internal/authn/authn.go @@ -12,7 +12,7 @@ import ( ) func NewAuthnInstance(c *gin.Context) (*webauthn.WebAuthn, error) { - siteUrl, err := url.Parse(common.GetApiUrl(c)) + siteUrl, err := url.Parse(common.GetApiUrl(c.Request.Context())) if err != nil { return nil, err } diff --git a/internal/conf/const.go b/internal/conf/const.go index 865f60f4..d7fe66b7 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -149,7 +149,19 @@ const ( ) // ContextKey is the type of context keys. +type ContextKey int + const ( - NoTaskKey = "no_task" - ApiUrlKey = "api_url" + _ ContextKey = iota + + NoTaskKey + ApiUrlKey + UserKey + MetaKey + MetaPassKey + ClientIPKey + ProxyHeaderKey + RequestHeaderKey + UserAgentKey + PathKey ) diff --git a/internal/fs/archive.go b/internal/fs/archive.go index ea8c5e84..fb047b82 100644 --- a/internal/fs/archive.go +++ b/internal/fs/archive.go @@ -7,7 +7,6 @@ import ( "io" "math/rand" "mime" - "net/http" "os" stdpath "path" "path/filepath" @@ -68,9 +67,7 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT if t.srcStorage == nil { t.srcStorage, err = op.GetStorageByMountPath(t.SrcStorageMp) } - srcObj, tool, ss, err := op.GetArchiveToolAndStream(t.Ctx(), t.srcStorage, t.SrcObjPath, model.LinkArgs{ - Header: http.Header{}, - }) + srcObj, tool, ss, err := op.GetArchiveToolAndStream(t.Ctx(), t.srcStorage, t.SrcObjPath, model.LinkArgs{}) if err != nil { return nil, err } @@ -355,7 +352,7 @@ func archiveDecompress(ctx context.Context, srcObjPath, dstDirPath string, args return nil, err } } - taskCreator, _ := ctx.Value("user").(*model.User) + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) tsk := &ArchiveDownloadTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 64186aa6..53bfc929 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -102,7 +102,7 @@ func _copy(ctx context.Context, srcObjPath, dstDirPath string, lazyCache ...bool } } // not in the same storage - taskCreator, _ := ctx.Value("user").(*model.User) + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) t := &CopyTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, diff --git a/internal/fs/list.go b/internal/fs/list.go index 07a22113..cfc13229 100644 --- a/internal/fs/list.go +++ b/internal/fs/list.go @@ -3,6 +3,7 @@ package fs import ( "context" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/pkg/utils" @@ -12,8 +13,8 @@ import ( // List files func list(ctx context.Context, path string, args *ListArgs) ([]model.Obj, error) { - meta, _ := ctx.Value("meta").(*model.Meta) - user, _ := ctx.Value("user").(*model.User) + meta, _ := ctx.Value(conf.MetaKey).(*model.Meta) + user, _ := ctx.Value(conf.UserKey).(*model.User) virtualFiles := op.GetStorageVirtualFilesByPath(path) storage, actualPath, err := op.GetStorageAndActualPath(path) if err != nil && len(virtualFiles) == 0 { diff --git a/internal/fs/move.go b/internal/fs/move.go index bc9b4ed9..5764c032 100644 --- a/internal/fs/move.go +++ b/internal/fs/move.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -586,7 +587,7 @@ func _moveWithValidation(ctx context.Context, srcObjPath, dstDirPath string, val } } - taskCreator, _ := ctx.Value("user").(*model.User) + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) // Create task immediately without any synchronous checks to avoid blocking frontend // All validation and type checking will be done asynchronously in the Run method diff --git a/internal/fs/put.go b/internal/fs/put.go index bc59c244..c5872de0 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -55,7 +56,7 @@ func putAsTask(ctx context.Context, dstDirPath string, file model.FileStreamer) //file.SetReader(tempFile) //file.SetTmpFile(tempFile) } - taskCreator, _ := ctx.Value("user").(*model.User) // taskCreator is nil when convert failed + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) // taskCreator is nil when convert failed t := &UploadTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, diff --git a/internal/fs/walk.go b/internal/fs/walk.go index 22af8506..a534dc4d 100644 --- a/internal/fs/walk.go +++ b/internal/fs/walk.go @@ -5,6 +5,7 @@ import ( "path" "path/filepath" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" ) @@ -28,7 +29,7 @@ func WalkFS(ctx context.Context, depth int, name string, info model.Obj, walkFn } meta, _ := op.GetNearestMeta(name) // Read directory names. - objs, err := List(context.WithValue(ctx, "meta", meta), name, &ListArgs{}) + objs, err := List(context.WithValue(ctx, conf.MetaKey, meta), name, &ListArgs{}) if err != nil { return walkFnErr } diff --git a/internal/net/util.go b/internal/net/util.go index fc5921ad..40b5e145 100644 --- a/internal/net/util.go +++ b/internal/net/util.go @@ -350,5 +350,3 @@ func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.Rea // return an io.ReadCloser that is limited to `length` bytes. return &LimitedReadCloser{readCloser, length_int}, nil } - -type RequestHeaderKey struct{} diff --git a/internal/offline_download/tool/add.go b/internal/offline_download/tool/add.go index 3b42a050..153d376d 100644 --- a/internal/offline_download/tool/add.go +++ b/internal/offline_download/tool/add.go @@ -122,7 +122,7 @@ func AddURL(ctx context.Context, args *AddURLArgs) (task.TaskExtensionInfo, erro } } - taskCreator, _ := ctx.Value("user").(*model.User) // taskCreator is nil when convert failed + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) // taskCreator is nil when convert failed t := &DownloadTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, diff --git a/internal/offline_download/tool/download.go b/internal/offline_download/tool/download.go index dcc8062e..ce36d189 100644 --- a/internal/offline_download/tool/download.go +++ b/internal/offline_download/tool/download.go @@ -181,7 +181,7 @@ func (t *DownloadTask) Transfer() error { if err != nil { return errors.WithMessage(err, "failed get dst storage") } - taskCreator, _ := t.Ctx().Value("user").(*model.User) + taskCreator, _ := t.Ctx().Value(conf.UserKey).(*model.User) task := &TransferTask{ TaskExtension: task.TaskExtension{ Creator: taskCreator, diff --git a/internal/offline_download/tool/transfer.go b/internal/offline_download/tool/transfer.go index 2264e1d9..fa9b5737 100644 --- a/internal/offline_download/tool/transfer.go +++ b/internal/offline_download/tool/transfer.go @@ -8,6 +8,7 @@ import ( "path/filepath" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" @@ -116,7 +117,7 @@ func transferStd(ctx context.Context, tempDir, dstDirPath string, deletePolicy D if err != nil { return err } - taskCreator, _ := ctx.Value("user").(*model.User) + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) for _, entry := range entries { t := &TransferTask{ TaskExtension: task.TaskExtension{ @@ -216,7 +217,7 @@ func transferObj(ctx context.Context, tempDir, dstDirPath string, deletePolicy D if err != nil { return errors.WithMessagef(err, "failed list src [%s] objs", tempDir) } - taskCreator, _ := ctx.Value("user").(*model.User) // taskCreator is nil when convert failed + taskCreator, _ := ctx.Value(conf.UserKey).(*model.User) // taskCreator is nil when convert failed for _, obj := range objs { t := &TransferTask{ TaskExtension: task.TaskExtension{ diff --git a/internal/search/build.go b/internal/search/build.go index 48b46c8e..c5c74e09 100644 --- a/internal/search/build.go +++ b/internal/search/build.go @@ -179,7 +179,7 @@ func BuildIndex(ctx context.Context, indexPaths, ignorePaths []string, maxDepth return err } // TODO: run walkFS concurrently - err = fs.WalkFS(context.WithValue(ctx, "user", admin), maxDepth, indexPath, fi, walkFn) + err = fs.WalkFS(context.WithValue(ctx, conf.UserKey, admin), maxDepth, indexPath, fi, walkFn) if err != nil { return err } diff --git a/internal/stream/util.go b/internal/stream/util.go index 5971860c..aee5c603 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -8,6 +8,7 @@ import ( "io" "net/http" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" @@ -38,7 +39,7 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, Size: size, } } else { - requestHeader, _ := ctx.Value(net.RequestHeaderKey{}).(http.Header) + requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header) header := net.ProcessHeader(requestHeader, link.Header) req = &net.HttpRequestParams{ Range: httpRange, @@ -67,7 +68,7 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, if httpRange.Length < 0 || httpRange.Start+httpRange.Length > size { httpRange.Length = size - httpRange.Start } - requestHeader, _ := ctx.Value(net.RequestHeaderKey{}).(http.Header) + requestHeader, _ := ctx.Value(conf.RequestHeaderKey).(http.Header) header := net.ProcessHeader(requestHeader, link.Header) header = http_range.ApplyRangeToHttpHeader(httpRange, header) diff --git a/internal/task/base.go b/internal/task/base.go index 3ffcee32..8976ed90 100644 --- a/internal/task/base.go +++ b/internal/task/base.go @@ -20,7 +20,7 @@ type TaskExtension struct { func (t *TaskExtension) SetCtx(ctx context.Context) { if t.Creator != nil { - ctx = context.WithValue(ctx, "user", t.Creator) + ctx = context.WithValue(ctx, conf.UserKey, t.Creator) } if len(t.ApiUrl) > 0 { ctx = context.WithValue(ctx, conf.ApiUrlKey, t.ApiUrl) diff --git a/server/common/base.go b/server/common/base.go index 8dcce11f..8aa669e3 100644 --- a/server/common/base.go +++ b/server/common/base.go @@ -10,7 +10,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" ) -func GetApiUrlFormRequest(r *http.Request) string { +func GetApiUrlFromRequest(r *http.Request) string { api := conf.Conf.SiteURL if strings.HasPrefix(api, "http") { return strings.TrimSuffix(api, "/") @@ -31,9 +31,6 @@ func GetApiUrlFormRequest(r *http.Request) string { } func GetApiUrl(ctx context.Context) string { - val := ctx.Value(conf.ApiUrlKey) - if api, ok := val.(string); ok { - return api - } - return "" + api, _ := ctx.Value(conf.ApiUrlKey).(string) + return api } diff --git a/server/common/common.go b/server/common/common.go index 1ec84403..6a29757b 100644 --- a/server/common/common.go +++ b/server/common/common.go @@ -1,6 +1,7 @@ package common import ( + "context" "strings" "github.com/OpenListTeam/OpenList/v4/cmd/flags" @@ -88,3 +89,20 @@ func Pluralize(count int, singular, plural string) string { } return plural } + +func GinWithValue(c *gin.Context, keyAndValue ...any) { + c.Request = c.Request.WithContext( + ContentWithValue(c.Request.Context(), keyAndValue...), + ) +} + +func ContentWithValue(ctx context.Context, keyAndValue ...any) context.Context { + if len(keyAndValue) < 1 || len(keyAndValue)%2 != 0 { + panic("keyAndValue must be an even number of arguments (key, value, ...)") + } + for len(keyAndValue) > 0 { + ctx = context.WithValue(ctx, keyAndValue[0], keyAndValue[1]) + keyAndValue = keyAndValue[2:] + } + return ctx +} diff --git a/server/common/proxy.go b/server/common/proxy.go index 5f065fff..5cae84fb 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -9,6 +9,7 @@ import ( "maps" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/internal/stream" @@ -26,7 +27,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. attachHeader(w, file, link.Header) rrf, _ := stream.GetRangeReaderFromLink(file.GetSize(), link) if link.RangeReader == nil { - r = r.WithContext(context.WithValue(r.Context(), net.RequestHeaderKey{}, r.Header)) + r = r.WithContext(context.WithValue(r.Context(), conf.RequestHeaderKey, r.Header)) } return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{ RangeReader: rrf, diff --git a/server/debug.go b/server/debug.go index 2e4cbe42..86096b0e 100644 --- a/server/debug.go +++ b/server/debug.go @@ -5,6 +5,7 @@ import ( _ "net/http/pprof" "runtime" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/sign" "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/middlewares" @@ -16,14 +17,14 @@ func _pprof(g *gin.RouterGroup) { } func debug(g *gin.RouterGroup) { - g.GET("/path/*path", middlewares.Down(sign.Verify), func(ctx *gin.Context) { - rawPath := ctx.MustGet("path").(string) - ctx.JSON(200, gin.H{ + g.GET("/path/*path", middlewares.Down(sign.Verify), func(c *gin.Context) { + rawPath := c.Request.Context().Value(conf.PathKey).(string) + c.JSON(200, gin.H{ "path": rawPath, }) }) - g.GET("/hide_privacy", func(ctx *gin.Context) { - common.ErrorStrResp(ctx, "This is ip: 1.1.1.1", 400) + g.GET("/hide_privacy", func(c *gin.Context) { + common.ErrorStrResp(c, "This is ip: 1.1.1.1", 400) }) g.GET("/gc", func(c *gin.Context) { runtime.GC() diff --git a/server/ftp.go b/server/ftp.go index 6a7ce622..40fb3716 100644 --- a/server/ftp.go +++ b/server/ftp.go @@ -24,7 +24,7 @@ import ( type FtpMainDriver struct { settings *ftpserver.Settings - proxyHeader *http.Header + proxyHeader http.Header clients map[uint32]ftpserver.ClientContext shutdownLock sync.RWMutex isShutdown bool @@ -32,8 +32,6 @@ type FtpMainDriver struct { } func NewMainDriver() (*FtpMainDriver, error) { - header := &http.Header{} - header.Add("User-Agent", setting.GetStr(conf.FTPProxyUserAgent)) transferType := ftpserver.TransferTypeASCII if conf.Conf.FTP.DefaultTransferBinary { transferType = ftpserver.TransferTypeBinary @@ -80,7 +78,9 @@ func NewMainDriver() (*FtpMainDriver, error) { ActiveConnectionsCheck: activeConnCheck, PasvConnectionsCheck: pasvConnCheck, }, - proxyHeader: header, + proxyHeader: http.Header{ + "User-Agent": {setting.GetStr(conf.FTPProxyUserAgent)}, + }, clients: make(map[uint32]ftpserver.ClientContext), shutdownLock: sync.RWMutex{}, isShutdown: false, @@ -132,14 +132,14 @@ func (d *FtpMainDriver) AuthUser(cc ftpserver.ClientContext, user, pass string) } ctx := context.Background() - ctx = context.WithValue(ctx, "user", userObj) + ctx = context.WithValue(ctx, conf.UserKey, userObj) if user == "anonymous" || user == "guest" { - ctx = context.WithValue(ctx, "meta_pass", pass) + ctx = context.WithValue(ctx, conf.MetaPassKey, pass) } else { - ctx = context.WithValue(ctx, "meta_pass", "") + ctx = context.WithValue(ctx, conf.MetaPassKey, "") } - ctx = context.WithValue(ctx, "client_ip", cc.RemoteAddr().String()) - ctx = context.WithValue(ctx, "proxy_header", d.proxyHeader) + ctx = context.WithValue(ctx, conf.ClientIPKey, cc.RemoteAddr().String()) + ctx = context.WithValue(ctx, conf.ProxyHeaderKey, d.proxyHeader) return ftp.NewAferoAdapter(ctx), nil } diff --git a/server/ftp/afero.go b/server/ftp/afero.go index faf23f35..f5bfda5f 100644 --- a/server/ftp/afero.go +++ b/server/ftp/afero.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -91,7 +92,7 @@ func (a *AferoAdapter) GetHandle(name string, flags int, offset int64) (ftpserve if (flags & os.O_APPEND) != 0 { return nil, errs.NotSupport } - user := a.ctx.Value("user").(*model.User) + user := a.ctx.Value(conf.UserKey).(*model.User) path, err := user.JoinPath(name) if err != nil { return nil, err diff --git a/server/ftp/fsmanage.go b/server/ftp/fsmanage.go index 82bbccca..b00e779f 100644 --- a/server/ftp/fsmanage.go +++ b/server/ftp/fsmanage.go @@ -5,6 +5,7 @@ import ( "fmt" stdpath "path" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -14,7 +15,7 @@ import ( ) func Mkdir(ctx context.Context, path string) error { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(path) if err != nil { return err @@ -34,7 +35,7 @@ func Mkdir(ctx context.Context, path string) error { } func Remove(ctx context.Context, path string) error { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) if !user.CanRemove() || !user.CanFTPManage() { return errs.PermissionDenied } @@ -46,7 +47,7 @@ func Remove(ctx context.Context, path string) error { } func Rename(ctx context.Context, oldPath, newPath string) error { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) srcPath, err := user.JoinPath(oldPath) if err != nil { return err diff --git a/server/ftp/fsread.go b/server/ftp/fsread.go index c52510ec..61244c01 100644 --- a/server/ftp/fsread.go +++ b/server/ftp/fsread.go @@ -8,6 +8,7 @@ import ( "os" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -24,24 +25,22 @@ type FileDownloadProxy struct { } func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownloadProxy, error) { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) meta, err := op.GetNearestMeta(reqPath) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { return nil, err } } - ctx = context.WithValue(ctx, "meta", meta) - if !common.CanAccess(user, meta, reqPath, ctx.Value("meta_pass").(string)) { + ctx = context.WithValue(ctx, conf.MetaKey, meta) + if !common.CanAccess(user, meta, reqPath, ctx.Value(conf.MetaPassKey).(string)) { return nil, errs.PermissionDenied } // directly use proxy - header := *(ctx.Value("proxy_header").(*http.Header)) - link, obj, err := fs.Link(ctx, reqPath, model.LinkArgs{ - IP: ctx.Value("client_ip").(string), - Header: header, - }) + header, _ := ctx.Value(conf.ProxyHeaderKey).(http.Header) + ip, _ := ctx.Value(conf.ClientIPKey).(string) + link, obj, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: ip, Header: header}) if err != nil { return nil, err } @@ -116,7 +115,7 @@ func (o *OsFileInfoAdapter) Sys() any { } func Stat(ctx context.Context, path string) (os.FileInfo, error) { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(path) if err != nil { return nil, err @@ -127,8 +126,8 @@ func Stat(ctx context.Context, path string) (os.FileInfo, error) { return nil, err } } - ctx = context.WithValue(ctx, "meta", meta) - if !common.CanAccess(user, meta, reqPath, ctx.Value("meta_pass").(string)) { + ctx = context.WithValue(ctx, conf.MetaKey, meta) + if !common.CanAccess(user, meta, reqPath, ctx.Value(conf.MetaPassKey).(string)) { return nil, errs.PermissionDenied } obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{}) @@ -139,7 +138,7 @@ func Stat(ctx context.Context, path string) (os.FileInfo, error) { } func List(ctx context.Context, path string) ([]os.FileInfo, error) { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(path) if err != nil { return nil, err @@ -150,8 +149,8 @@ func List(ctx context.Context, path string) ([]os.FileInfo, error) { return nil, err } } - ctx = context.WithValue(ctx, "meta", meta) - if !common.CanAccess(user, meta, reqPath, ctx.Value("meta_pass").(string)) { + ctx = context.WithValue(ctx, conf.MetaKey, meta) + if !common.CanAccess(user, meta, reqPath, ctx.Value(conf.MetaPassKey).(string)) { return nil, errs.PermissionDenied } objs, err := fs.List(ctx, reqPath, &fs.ListArgs{}) diff --git a/server/ftp/fsup.go b/server/ftp/fsup.go index 30f5dda4..7b6241ff 100644 --- a/server/ftp/fsup.go +++ b/server/ftp/fsup.go @@ -29,14 +29,14 @@ type FileUploadProxy struct { } func uploadAuth(ctx context.Context, path string) error { - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) meta, err := op.GetNearestMeta(stdpath.Dir(path)) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { return err } } - if !(common.CanAccess(user, meta, path, ctx.Value("meta_pass").(string)) && + if !(common.CanAccess(user, meta, path, ctx.Value(conf.MetaPassKey).(string)) && ((user.CanFTPManage() && user.CanWrite()) || common.CanWrite(meta, stdpath.Dir(path)))) { return errs.PermissionDenied } diff --git a/server/handles/archive.go b/server/handles/archive.go index 0db4190e..c10cca24 100644 --- a/server/handles/archive.go +++ b/server/handles/archive.go @@ -77,7 +77,7 @@ func FsArchiveMeta(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanReadArchives() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -94,7 +94,7 @@ func FsArchiveMeta(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -106,7 +106,7 @@ func FsArchiveMeta(c *gin.Context) { }, Password: req.ArchivePass, } - ret, err := fs.ArchiveMeta(c, reqPath, model.ArchiveMetaArgs{ + ret, err := fs.ArchiveMeta(c.Request.Context(), reqPath, model.ArchiveMetaArgs{ ArchiveArgs: archiveArgs, Refresh: req.Refresh, }) @@ -154,7 +154,7 @@ func FsArchiveList(c *gin.Context) { return } req.Validate() - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanReadArchives() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -171,12 +171,12 @@ func FsArchiveList(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return } - objs, err := fs.ArchiveList(c, reqPath, model.ArchiveListArgs{ + objs, err := fs.ArchiveList(c.Request.Context(), reqPath, model.ArchiveListArgs{ ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ @@ -239,7 +239,7 @@ func FsArchiveDecompress(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanDecompress() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -260,7 +260,7 @@ func FsArchiveDecompress(c *gin.Context) { } tasks := make([]task.TaskExtensionInfo, 0, len(srcPaths)) for _, srcPath := range srcPaths { - t, e := fs.ArchiveDecompress(c, srcPath, dstDir, model.ArchiveDecompressArgs{ + t, e := fs.ArchiveDecompress(c.Request.Context(), srcPath, dstDir, model.ArchiveDecompressArgs{ ArchiveInnerArgs: model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ @@ -292,7 +292,7 @@ func FsArchiveDecompress(c *gin.Context) { } func ArchiveDown(c *gin.Context) { - archiveRawPath := c.MustGet("path").(string) + archiveRawPath := c.Request.Context().Value(conf.PathKey).(string) innerPath := utils.FixAndCleanPath(c.Query("inner")) password := c.Query("pass") filename := stdpath.Base(innerPath) @@ -305,7 +305,7 @@ func ArchiveDown(c *gin.Context) { ArchiveProxy(c) return } else { - link, _, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{ + link, _, err := fs.ArchiveDriverExtract(c.Request.Context(), archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ IP: c.ClientIP(), @@ -326,7 +326,7 @@ func ArchiveDown(c *gin.Context) { } func ArchiveProxy(c *gin.Context) { - archiveRawPath := c.MustGet("path").(string) + archiveRawPath := c.Request.Context().Value(conf.PathKey).(string) innerPath := utils.FixAndCleanPath(c.Query("inner")) password := c.Query("pass") filename := stdpath.Base(innerPath) @@ -337,7 +337,7 @@ func ArchiveProxy(c *gin.Context) { } if canProxy(storage, filename) { // TODO: Support external download proxy URL - link, file, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{ + link, file, err := fs.ArchiveDriverExtract(c.Request.Context(), archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ Header: c.Request.Header, @@ -359,10 +359,10 @@ func ArchiveProxy(c *gin.Context) { } func ArchiveInternalExtract(c *gin.Context) { - archiveRawPath := c.MustGet("path").(string) + archiveRawPath := c.Request.Context().Value(conf.PathKey).(string) innerPath := utils.FixAndCleanPath(c.Query("inner")) password := c.Query("pass") - rc, size, err := fs.ArchiveInternalExtract(c, archiveRawPath, model.ArchiveInnerArgs{ + rc, size, err := fs.ArchiveInternalExtract(c.Request.Context(), archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ Header: c.Request.Header, diff --git a/server/handles/auth.go b/server/handles/auth.go index 28d4f488..35776ba6 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "image/png" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/server/common" @@ -87,7 +88,7 @@ type UserResp struct { // CurrentUser get current user by token // if token is empty, return guest user func CurrentUser(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) userResp := UserResp{ User: *user, } @@ -104,7 +105,7 @@ func UpdateCurrent(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() { common.ErrorStrResp(c, "Guest user can not update profile", 403) return @@ -122,7 +123,7 @@ func UpdateCurrent(c *gin.Context) { } func Generate2FA(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() { common.ErrorStrResp(c, "Guest user can not generate 2FA code", 403) return @@ -161,7 +162,7 @@ func Verify2FA(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() { common.ErrorStrResp(c, "Guest user can not generate 2FA code", 403) return diff --git a/server/handles/down.go b/server/handles/down.go index 686e2b2c..df98686b 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -24,7 +24,7 @@ import ( ) func Down(c *gin.Context) { - rawPath := c.MustGet("path").(string) + rawPath := c.Request.Context().Value(conf.PathKey).(string) filename := stdpath.Base(rawPath) storage, err := fs.GetStorage(rawPath, &fs.GetStoragesArgs{}) if err != nil { @@ -35,7 +35,7 @@ func Down(c *gin.Context) { Proxy(c) return } else { - link, _, err := fs.Link(c, rawPath, model.LinkArgs{ + link, _, err := fs.Link(c.Request.Context(), rawPath, model.LinkArgs{ IP: c.ClientIP(), Header: c.Request.Header, Type: c.Query("type"), @@ -50,7 +50,7 @@ func Down(c *gin.Context) { } func Proxy(c *gin.Context) { - rawPath := c.MustGet("path").(string) + rawPath := c.Request.Context().Value(conf.PathKey).(string) filename := stdpath.Base(rawPath) storage, err := fs.GetStorage(rawPath, &fs.GetStoragesArgs{}) if err != nil { @@ -70,7 +70,7 @@ func Proxy(c *gin.Context) { return } } - link, file, err := fs.Link(c, rawPath, model.LinkArgs{ + link, file, err := fs.Link(c.Request.Context(), rawPath, model.LinkArgs{ Header: c.Request.Header, Type: c.Query("type"), }) diff --git a/server/handles/fsbatch.go b/server/handles/fsbatch.go index 8b2a3fb8..84a1a119 100644 --- a/server/handles/fsbatch.go +++ b/server/handles/fsbatch.go @@ -5,6 +5,7 @@ import ( "regexp" "slices" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -28,7 +29,7 @@ func FsRecursiveMove(c *gin.Context) { return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanMove() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -51,9 +52,9 @@ func FsRecursiveMove(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) - rootFiles, err := fs.List(c, srcDir, &fs.ListArgs{}) + rootFiles, err := fs.List(c.Request.Context(), srcDir, &fs.ListArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -61,7 +62,7 @@ func FsRecursiveMove(c *gin.Context) { var existingFileNames []string if req.ConflictPolicy != OVERWRITE { - dstFiles, err := fs.List(c, dstDir, &fs.ListArgs{}) + dstFiles, err := fs.List(c.Request.Context(), dstDir, &fs.ListArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -89,7 +90,7 @@ func FsRecursiveMove(c *gin.Context) { if movingFile.IsDir() { // directory, recursive move subFilePath := movingFileName - subFiles, err := fs.List(c, movingFileName, &fs.ListArgs{Refresh: true}) + subFiles, err := fs.List(c.Request.Context(), movingFileName, &fs.ListArgs{Refresh: true}) if err != nil { common.ErrorResp(c, err, 500) return @@ -123,7 +124,7 @@ func FsRecursiveMove(c *gin.Context) { var count = 0 for i, fileName := range movingFileNames { // move - err := fs.Move(c, fileName, dstDir, len(movingFileNames) > i+1) + err := fs.Move(c.Request.Context(), fileName, dstDir, len(movingFileNames) > i+1) if err != nil { common.ErrorResp(c, err, 500) return @@ -148,7 +149,7 @@ func FsBatchRename(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanRename() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -167,13 +168,13 @@ func FsBatchRename(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) for _, renameObject := range req.RenameObjects { if renameObject.SrcName == "" || renameObject.NewName == "" { continue } filePath := fmt.Sprintf("%s/%s", reqPath, renameObject.SrcName) - if err := fs.Rename(c, filePath, renameObject.NewName); err != nil { + if err := fs.Rename(c.Request.Context(), filePath, renameObject.NewName); err != nil { common.ErrorResp(c, err, 500) return } @@ -193,7 +194,7 @@ func FsRegexRename(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanRename() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -212,7 +213,7 @@ func FsRegexRename(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) srcRegexp, err := regexp.Compile(req.SrcNameRegex) if err != nil { @@ -220,7 +221,7 @@ func FsRegexRename(c *gin.Context) { return } - files, err := fs.List(c, reqPath, &fs.ListArgs{}) + files, err := fs.List(c.Request.Context(), reqPath, &fs.ListArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -231,7 +232,7 @@ func FsRegexRename(c *gin.Context) { if srcRegexp.MatchString(file.GetName()) { filePath := fmt.Sprintf("%s/%s", reqPath, file.GetName()) newFileName := srcRegexp.ReplaceAllString(file.GetName(), req.NewNameRegex) - if err := fs.Rename(c, filePath, newFileName); err != nil { + if err := fs.Rename(c.Request.Context(), filePath, newFileName); err != nil { common.ErrorResp(c, err, 500) return } diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index 865717bb..1b79cdc4 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -4,6 +4,7 @@ import ( "fmt" stdpath "path" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/task" "github.com/OpenListTeam/OpenList/v4/internal/errs" @@ -28,7 +29,7 @@ func FsMkdir(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(req.Path) if err != nil { common.ErrorResp(c, err, 403) @@ -47,7 +48,7 @@ func FsMkdir(c *gin.Context) { return } } - if err := fs.MakeDir(c, reqPath); err != nil { + if err := fs.MakeDir(c.Request.Context(), reqPath); err != nil { common.ErrorResp(c, err, 500) return } @@ -71,7 +72,7 @@ func FsMove(c *gin.Context) { common.ErrorStrResp(c, "Empty file names", 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanMove() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -89,7 +90,7 @@ func FsMove(c *gin.Context) { if !req.Overwrite { for _, name := range req.Names { - if res, _ := fs.Get(c, stdpath.Join(dstDir, name), &fs.GetArgs{NoLog: true}); res != nil { + if res, _ := fs.Get(c.Request.Context(), stdpath.Join(dstDir, name), &fs.GetArgs{NoLog: true}); res != nil { common.ErrorStrResp(c, fmt.Sprintf("file [%s] exists", name), 403) return } @@ -100,7 +101,7 @@ func FsMove(c *gin.Context) { // All validation will be done asynchronously in the background var addedTasks []task.TaskExtensionInfo for i, name := range req.Names { - t, err := fs.MoveWithTaskAndValidation(c, stdpath.Join(srcDir, name), dstDir, !req.Overwrite, len(req.Names) > i+1) + t, err := fs.MoveWithTaskAndValidation(c.Request.Context(), stdpath.Join(srcDir, name), dstDir, !req.Overwrite, len(req.Names) > i+1) if t != nil { addedTasks = append(addedTasks, t) } @@ -133,7 +134,7 @@ func FsCopy(c *gin.Context) { common.ErrorStrResp(c, "Empty file names", 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanCopy() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -151,7 +152,7 @@ func FsCopy(c *gin.Context) { if !req.Overwrite { for _, name := range req.Names { - if res, _ := fs.Get(c, stdpath.Join(dstDir, name), &fs.GetArgs{NoLog: true}); res != nil { + if res, _ := fs.Get(c.Request.Context(), stdpath.Join(dstDir, name), &fs.GetArgs{NoLog: true}); res != nil { common.ErrorStrResp(c, fmt.Sprintf("file [%s] exists", name), 403) return } @@ -162,7 +163,7 @@ func FsCopy(c *gin.Context) { // All validation will be done asynchronously in the background var addedTasks []task.TaskExtensionInfo for i, name := range req.Names { - t, err := fs.Copy(c, stdpath.Join(srcDir, name), dstDir, len(req.Names) > i+1) + t, err := fs.Copy(c.Request.Context(), stdpath.Join(srcDir, name), dstDir, len(req.Names) > i+1) if t != nil { addedTasks = append(addedTasks, t) } @@ -197,7 +198,7 @@ func FsRename(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanRename() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -210,13 +211,13 @@ func FsRename(c *gin.Context) { if !req.Overwrite { dstPath := stdpath.Join(stdpath.Dir(reqPath), req.Name) if dstPath != reqPath { - if res, _ := fs.Get(c, dstPath, &fs.GetArgs{NoLog: true}); res != nil { + if res, _ := fs.Get(c.Request.Context(), dstPath, &fs.GetArgs{NoLog: true}); res != nil { common.ErrorStrResp(c, fmt.Sprintf("file [%s] exists", req.Name), 403) return } } } - if err := fs.Rename(c, reqPath, req.Name); err != nil { + if err := fs.Rename(c.Request.Context(), reqPath, req.Name); err != nil { common.ErrorResp(c, err, 500) return } @@ -238,7 +239,7 @@ func FsRemove(c *gin.Context) { common.ErrorStrResp(c, "Empty file names", 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanRemove() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -249,7 +250,7 @@ func FsRemove(c *gin.Context) { return } for _, name := range req.Names { - err := fs.Remove(c, stdpath.Join(reqDir, name)) + err := fs.Remove(c.Request.Context(), stdpath.Join(reqDir, name)) if err != nil { common.ErrorResp(c, err, 500) return @@ -270,7 +271,7 @@ func FsRemoveEmptyDirectory(c *gin.Context) { return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanRemove() { common.ErrorResp(c, errs.PermissionDenied, 403) return @@ -288,9 +289,9 @@ func FsRemoveEmptyDirectory(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) - rootFiles, err := fs.List(c, srcDir, &fs.ListArgs{}) + rootFiles, err := fs.List(c.Request.Context(), srcDir, &fs.ListArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -321,7 +322,7 @@ func FsRemoveEmptyDirectory(c *gin.Context) { continue } - subFiles, err := fs.List(c, removingFilePath, &fs.ListArgs{Refresh: true}) + subFiles, err := fs.List(c.Request.Context(), removingFilePath, &fs.ListArgs{Refresh: true}) if err != nil { common.ErrorResp(c, err, 500) return @@ -329,7 +330,7 @@ func FsRemoveEmptyDirectory(c *gin.Context) { if len(subFiles) == 0 { // remove empty directory - err = fs.Remove(c, removingFilePath) + err = fs.Remove(c.Request.Context(), removingFilePath) removedFiles[removingFilePath] = true if err != nil { common.ErrorResp(c, err, 500) @@ -365,7 +366,7 @@ func Link(c *gin.Context) { common.ErrorResp(c, err, 400) return } - //user := c.MustGet("user").(*model.User) + //user := c.Request.Context().Value(conf.UserKey).(*model.User) //rawPath := stdpath.Join(user.BasePath, req.Path) // why need not join base_path? because it's always the full path rawPath := req.Path @@ -383,7 +384,7 @@ func Link(c *gin.Context) { }) return } - link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, Redirect: true}) + link, _, err := fs.Link(c.Request.Context(), rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, Redirect: true}) if err != nil { common.ErrorResp(c, err, 500) return diff --git a/server/handles/fsread.go b/server/handles/fsread.go index b5608750..501fbcf4 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -63,7 +63,7 @@ func FsList(c *gin.Context) { return } req.Validate() - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(req.Path) if err != nil { common.ErrorResp(c, err, 403) @@ -76,7 +76,7 @@ func FsList(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -85,7 +85,7 @@ func FsList(c *gin.Context) { common.ErrorStrResp(c, "Refresh without permission", 403) return } - objs, err := fs.List(c, reqPath, &fs.ListArgs{Refresh: req.Refresh}) + objs, err := fs.List(c.Request.Context(), reqPath, &fs.ListArgs{Refresh: req.Refresh}) if err != nil { common.ErrorResp(c, err, 500) return @@ -112,7 +112,7 @@ func FsDirs(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) reqPath := req.Path if req.ForceRoot { if !user.IsAdmin() { @@ -134,12 +134,12 @@ func FsDirs(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return } - objs, err := fs.List(c, reqPath, &fs.ListArgs{}) + objs, err := fs.List(c.Request.Context(), reqPath, &fs.ListArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -249,7 +249,7 @@ func FsGet(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) reqPath, err := user.JoinPath(req.Path) if err != nil { common.ErrorResp(c, err, 403) @@ -262,12 +262,12 @@ func FsGet(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return } - obj, err := fs.Get(c, reqPath, &fs.GetArgs{}) + obj, err := fs.Get(c.Request.Context(), reqPath, &fs.GetArgs{}) if err != nil { common.ErrorResp(c, err, 500) return @@ -306,7 +306,7 @@ func FsGet(c *gin.Context) { rawURL = url } else { // if storage is not proxy, use raw url by fs.Link - link, _, err := fs.Link(c, reqPath, model.LinkArgs{ + link, _, err := fs.Link(c.Request.Context(), reqPath, model.LinkArgs{ IP: c.ClientIP(), Header: c.Request.Header, Redirect: true, @@ -322,7 +322,7 @@ func FsGet(c *gin.Context) { } var related []model.Obj parentPath := stdpath.Dir(reqPath) - sameLevelFiles, err := fs.List(c, parentPath, &fs.ListArgs{}) + sameLevelFiles, err := fs.List(c.Request.Context(), parentPath, &fs.ListArgs{}) if err == nil { related = filterRelated(sameLevelFiles, obj) } @@ -376,7 +376,7 @@ func FsOther(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) var err error req.Path, err = user.JoinPath(req.Path) if err != nil { @@ -390,12 +390,12 @@ func FsOther(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, req.Path, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return } - res, err := fs.Other(c, req.FsOtherArgs) + res, err := fs.Other(c.Request.Context(), req.FsOtherArgs) if err != nil { common.ErrorResp(c, err, 500) return diff --git a/server/handles/fsup.go b/server/handles/fsup.go index ec47ebdd..087a58a9 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -7,6 +7,7 @@ import ( "strconv" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/stream" @@ -42,14 +43,14 @@ func FsStream(c *gin.Context) { } asTask := c.GetHeader("As-Task") == "true" overwrite := c.GetHeader("Overwrite") != "false" - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) path, err = user.JoinPath(path) if err != nil { common.ErrorResp(c, err, 403) return } if !overwrite { - if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { + if res, _ := fs.Get(c.Request.Context(), path, &fs.GetArgs{NoLog: true}); res != nil { common.ErrorStrResp(c, "file exists", 403) return } @@ -91,9 +92,9 @@ func FsStream(c *gin.Context) { } var t task.TaskExtensionInfo if asTask { - t, err = fs.PutAsTask(c, dir, s) + t, err = fs.PutAsTask(c.Request.Context(), dir, s) } else { - err = fs.PutDirectly(c, dir, s, true) + err = fs.PutDirectly(c.Request.Context(), dir, s, true) } if err != nil { common.ErrorResp(c, err, 500) @@ -123,14 +124,14 @@ func FsForm(c *gin.Context) { } asTask := c.GetHeader("As-Task") == "true" overwrite := c.GetHeader("Overwrite") != "false" - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) path, err = user.JoinPath(path) if err != nil { common.ErrorResp(c, err, 403) return } if !overwrite { - if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { + if res, _ := fs.Get(c.Request.Context(), path, &fs.GetArgs{NoLog: true}); res != nil { common.ErrorStrResp(c, "file exists", 403) return } @@ -186,9 +187,9 @@ func FsForm(c *gin.Context) { s.Reader = struct { io.Reader }{f} - t, err = fs.PutAsTask(c, dir, s) + t, err = fs.PutAsTask(c.Request.Context(), dir, s) } else { - err = fs.PutDirectly(c, dir, s, true) + err = fs.PutDirectly(c.Request.Context(), dir, s, true) } if err != nil { common.ErrorResp(c, err, 500) diff --git a/server/handles/offline_download.go b/server/handles/offline_download.go index 5ceccbb6..0ccd5302 100644 --- a/server/handles/offline_download.go +++ b/server/handles/offline_download.go @@ -343,7 +343,7 @@ type AddOfflineDownloadReq struct { } func AddOfflineDownload(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.CanAddOfflineDownloadTasks() { common.ErrorStrResp(c, "permission denied", 403) return diff --git a/server/handles/search.go b/server/handles/search.go index d53c808c..32e4e955 100644 --- a/server/handles/search.go +++ b/server/handles/search.go @@ -4,6 +4,7 @@ import ( "path" "strings" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" @@ -33,7 +34,7 @@ func Search(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) req.Parent, err = user.JoinPath(req.Parent) if err != nil { common.ErrorResp(c, err, 400) diff --git a/server/handles/sshkey.go b/server/handles/sshkey.go index 79116ffe..d4521b9d 100644 --- a/server/handles/sshkey.go +++ b/server/handles/sshkey.go @@ -4,6 +4,7 @@ import ( "strconv" "strings" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/server/common" @@ -16,7 +17,7 @@ type SSHKeyAddReq struct { } func AddMyPublicKey(c *gin.Context) { - userObj, ok := c.Value("user").(*model.User) + userObj, ok := c.Request.Context().Value(conf.UserKey).(*model.User) if !ok || userObj.IsGuest() { common.ErrorStrResp(c, "user invalid", 401) return @@ -47,7 +48,7 @@ func AddMyPublicKey(c *gin.Context) { } func ListMyPublicKey(c *gin.Context) { - userObj, ok := c.Value("user").(*model.User) + userObj, ok := c.Request.Context().Value(conf.UserKey).(*model.User) if !ok || userObj.IsGuest() { common.ErrorStrResp(c, "user invalid", 401) return @@ -56,7 +57,7 @@ func ListMyPublicKey(c *gin.Context) { } func DeleteMyPublicKey(c *gin.Context) { - userObj, ok := c.Value("user").(*model.User) + userObj, ok := c.Request.Context().Value(conf.UserKey).(*model.User) if !ok || userObj.IsGuest() { common.ErrorStrResp(c, "user invalid", 401) return diff --git a/server/handles/storage.go b/server/handles/storage.go index 0f313d46..70b9e1ad 100644 --- a/server/handles/storage.go +++ b/server/handles/storage.go @@ -38,7 +38,7 @@ func CreateStorage(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if id, err := op.CreateStorage(c, req); err != nil { + if id, err := op.CreateStorage(c.Request.Context(), req); err != nil { common.ErrorWithDataResp(c, err, 500, gin.H{ "id": id, }, true) @@ -55,7 +55,7 @@ func UpdateStorage(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := op.UpdateStorage(c, req); err != nil { + if err := op.UpdateStorage(c.Request.Context(), req); err != nil { common.ErrorResp(c, err, 500, true) } else { common.SuccessResp(c) @@ -69,7 +69,7 @@ func DeleteStorage(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := op.DeleteStorageById(c, uint(id)); err != nil { + if err := op.DeleteStorageById(c.Request.Context(), uint(id)); err != nil { common.ErrorResp(c, err, 500, true) return } @@ -83,7 +83,7 @@ func DisableStorage(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := op.DisableStorage(c, uint(id)); err != nil { + if err := op.DisableStorage(c.Request.Context(), uint(id)); err != nil { common.ErrorResp(c, err, 500, true) return } @@ -97,7 +97,7 @@ func EnableStorage(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := op.EnableStorage(c, uint(id)); err != nil { + if err := op.EnableStorage(c.Request.Context(), uint(id)); err != nil { common.ErrorResp(c, err, 500, true) return } diff --git a/server/handles/task.go b/server/handles/task.go index c9ad2190..032f363a 100644 --- a/server/handles/task.go +++ b/server/handles/task.go @@ -4,6 +4,7 @@ import ( "math" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/task" @@ -11,8 +12,8 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/offline_download/tool" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server/common" - "github.com/gin-gonic/gin" "github.com/OpenListTeam/tache" + "github.com/gin-gonic/gin" ) type TaskInfo struct { @@ -69,7 +70,7 @@ func argsContains[T comparable](v T, slice ...T) bool { } func getUserInfo(c *gin.Context) (bool, uint, bool) { - if user, ok := c.Value("user").(*model.User); ok { + if user, ok := c.Request.Context().Value(conf.UserKey).(*model.User); ok { return user.IsAdmin(), user.ID, true } else { return false, 0, false diff --git a/server/handles/webauthn.go b/server/handles/webauthn.go index 9b79e16b..c7ad4edf 100644 --- a/server/handles/webauthn.go +++ b/server/handles/webauthn.go @@ -125,7 +125,7 @@ func BeginAuthnRegistration(c *gin.Context) { common.ErrorStrResp(c, "WebAuthn is not enabled", 403) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) authnInstance, err := authn.NewAuthnInstance(c) if err != nil { @@ -155,7 +155,7 @@ func FinishAuthnRegistration(c *gin.Context) { common.ErrorStrResp(c, "WebAuthn is not enabled", 403) return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) sessionDataString := c.GetHeader("Session") authnInstance, err := authn.NewAuthnInstance(c) @@ -196,7 +196,7 @@ func FinishAuthnRegistration(c *gin.Context) { } func DeleteAuthnLogin(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) type DeleteAuthnReq struct { ID string `json:"id"` } @@ -224,7 +224,7 @@ func GetAuthnCredentials(c *gin.Context) { ID []byte `json:"id"` FingerPrint string `json:"fingerprint"` } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) credentials := user.WebAuthnCredentials() res := make([]WebAuthnCredentials, 0, len(credentials)) for _, v := range credentials { diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 379a5d07..44a93140 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -23,7 +23,7 @@ func Auth(c *gin.Context) { c.Abort() return } - c.Set("user", admin) + common.GinWithValue(c, conf.UserKey, admin) log.Debugf("use admin token: %+v", admin) c.Next() return @@ -40,7 +40,7 @@ func Auth(c *gin.Context) { c.Abort() return } - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) log.Debugf("use empty token: %+v", guest) c.Next() return @@ -68,7 +68,7 @@ func Auth(c *gin.Context) { c.Abort() return } - c.Set("user", user) + common.GinWithValue(c, conf.UserKey, user) log.Debugf("use login token: %+v", user) c.Next() } @@ -82,7 +82,7 @@ func Authn(c *gin.Context) { c.Abort() return } - c.Set("user", admin) + common.GinWithValue(c, conf.UserKey, admin) log.Debugf("use admin token: %+v", admin) c.Next() return @@ -94,7 +94,7 @@ func Authn(c *gin.Context) { c.Abort() return } - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) log.Debugf("use empty token: %+v", guest) c.Next() return @@ -122,13 +122,13 @@ func Authn(c *gin.Context) { c.Abort() return } - c.Set("user", user) + common.GinWithValue(c, conf.UserKey, user) log.Debugf("use login token: %+v", user) c.Next() } func AuthNotGuest(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() { common.ErrorStrResp(c, "You are a guest", 403) c.Abort() @@ -138,7 +138,7 @@ func AuthNotGuest(c *gin.Context) { } func AuthAdmin(c *gin.Context) { - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) if !user.IsAdmin() { common.ErrorStrResp(c, "You are not an admin", 403) c.Abort() diff --git a/server/middlewares/check.go b/server/middlewares/check.go index 09858cba..a1011de3 100644 --- a/server/middlewares/check.go +++ b/server/middlewares/check.go @@ -26,6 +26,8 @@ func StoragesLoaded(c *gin.Context) { c.Abort() return } - c.Set(conf.ApiUrlKey, common.GetApiUrlFormRequest(c.Request)) + common.GinWithValue(c, + conf.ApiUrlKey, common.GetApiUrlFromRequest(c.Request), + ) c.Next() } diff --git a/server/middlewares/down.go b/server/middlewares/down.go index 57512382..ee4815c2 100644 --- a/server/middlewares/down.go +++ b/server/middlewares/down.go @@ -18,7 +18,7 @@ import ( func Down(verifyFunc func(string, string) error) func(c *gin.Context) { return func(c *gin.Context) { rawPath := parsePath(c.Param("path")) - c.Set("path", rawPath) + common.GinWithValue(c, conf.PathKey, rawPath) meta, err := op.GetNearestMeta(rawPath) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { @@ -26,7 +26,7 @@ func Down(verifyFunc func(string, string) error) func(c *gin.Context) { return } } - c.Set("meta", meta) + common.GinWithValue(c, conf.MetaKey, meta) // verify sign if needSign(meta, rawPath) { s := c.Query("sign") diff --git a/server/middlewares/fsup.go b/server/middlewares/fsup.go index 6fd71d53..08b160ee 100644 --- a/server/middlewares/fsup.go +++ b/server/middlewares/fsup.go @@ -4,6 +4,7 @@ import ( "net/url" stdpath "path" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" @@ -21,7 +22,7 @@ func FsUp(c *gin.Context) { c.Abort() return } - user := c.MustGet("user").(*model.User) + user := c.Request.Context().Value(conf.UserKey).(*model.User) path, err = user.JoinPath(path) if err != nil { common.ErrorResp(c, err, 403) diff --git a/server/router.go b/server/router.go index 27594b0d..4cef4106 100644 --- a/server/router.go +++ b/server/router.go @@ -16,6 +16,7 @@ import ( ) func Init(e *gin.Engine) { + e.ContextWithFallback = true if !utils.SliceContains([]string{"", "/"}, conf.URL.Path) { e.GET("/", func(c *gin.Context) { c.Redirect(302, conf.URL.Path) diff --git a/server/s3/backend.go b/server/s3/backend.go index e79d57a9..cc92ab5e 100644 --- a/server/s3/backend.go +++ b/server/s3/backend.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -108,7 +109,7 @@ func (b *s3Backend) HeadObject(ctx context.Context, bucketName, objectName strin fp := path.Join(bucketPath, objectName) fmeta, _ := op.GetNearestMeta(fp) - node, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{}) + node, err := fs.Get(context.WithValue(ctx, conf.MetaKey, fmeta), fp, &fs.GetArgs{}) if err != nil { return nil, gofakes3.KeyNotFound(objectName) } @@ -151,7 +152,7 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string fp := path.Join(bucketPath, objectName) fmeta, _ := op.GetNearestMeta(fp) - node, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{}) + node, err := fs.Get(context.WithValue(ctx, conf.MetaKey, fmeta), fp, &fs.GetArgs{}) if err != nil { return nil, gofakes3.KeyNotFound(objectName) } @@ -247,7 +248,7 @@ func (b *s3Backend) PutObject( } log.Debugf("reqPath: %s", reqPath) fmeta, _ := op.GetNearestMeta(fp) - ctx = context.WithValue(ctx, "meta", fmeta) + ctx = context.WithValue(ctx, conf.MetaKey, fmeta) _, err = fs.Get(ctx, reqPath, &fs.GetArgs{}) if err != nil { @@ -341,7 +342,7 @@ func (b *s3Backend) deleteObject(ctx context.Context, bucketName, objectName str fmeta, _ := op.GetNearestMeta(fp) // S3 does not report an error when attemping to delete a key that does not exist, so // we need to skip IsNotExist errors. - if _, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{}); err != nil && !errs.IsObjectNotFound(err) { + if _, err := fs.Get(context.WithValue(ctx, conf.MetaKey, fmeta), fp, &fs.GetArgs{}); err != nil && !errs.IsObjectNotFound(err) { return err } @@ -388,7 +389,7 @@ func (b *s3Backend) CopyObject(ctx context.Context, srcBucket, srcKey, dstBucket srcFp := path.Join(srcBucketPath, srcKey) fmeta, _ := op.GetNearestMeta(srcFp) - srcNode, err := fs.Get(context.WithValue(ctx, "meta", fmeta), srcFp, &fs.GetArgs{}) + srcNode, err := fs.Get(context.WithValue(ctx, conf.MetaKey, fmeta), srcFp, &fs.GetArgs{}) c, err := b.GetObject(ctx, srcBucket, srcKey, nil) if err != nil { diff --git a/server/s3/utils.go b/server/s3/utils.go index c47b1073..eb3dd568 100644 --- a/server/s3/utils.go +++ b/server/s3/utils.go @@ -45,7 +45,7 @@ func getBucketByName(name string) (Bucket, error) { func getDirEntries(path string) ([]model.Obj, error) { ctx := context.Background() meta, _ := op.GetNearestMeta(path) - fi, err := fs.Get(context.WithValue(ctx, "meta", meta), path, &fs.GetArgs{}) + fi, err := fs.Get(context.WithValue(ctx, conf.MetaKey, meta), path, &fs.GetArgs{}) if errs.IsNotFoundError(err) { return nil, gofakes3.ErrNoSuchKey } else if err != nil { @@ -56,7 +56,7 @@ func getDirEntries(path string) ([]model.Obj, error) { return nil, gofakes3.ErrNoSuchKey } - dirEntries, err := fs.List(context.WithValue(ctx, "meta", meta), path, &fs.ListArgs{}) + dirEntries, err := fs.List(context.WithValue(ctx, conf.MetaKey, meta), path, &fs.ListArgs{}) if err != nil { return nil, err } diff --git a/server/sftp.go b/server/sftp.go index 529bf5ab..0f9d9125 100644 --- a/server/sftp.go +++ b/server/sftp.go @@ -18,16 +18,16 @@ import ( ) type SftpDriver struct { - proxyHeader *http.Header + proxyHeader http.Header config *sftpd.Config } func NewSftpDriver() (*SftpDriver, error) { sftp.InitHostKey() - header := &http.Header{} - header.Add("User-Agent", setting.GetStr(conf.FTPProxyUserAgent)) return &SftpDriver{ - proxyHeader: header, + proxyHeader: http.Header{ + "User-Agent": {setting.GetStr(conf.FTPProxyUserAgent)}, + }, }, nil } @@ -61,10 +61,10 @@ func (d *SftpDriver) GetFileSystem(sc *ssh.ServerConn) (sftpd.FileSystem, error) return nil, err } ctx := context.Background() - ctx = context.WithValue(ctx, "user", userObj) - ctx = context.WithValue(ctx, "meta_pass", "") - ctx = context.WithValue(ctx, "client_ip", sc.RemoteAddr().String()) - ctx = context.WithValue(ctx, "proxy_header", d.proxyHeader) + ctx = context.WithValue(ctx, conf.UserKey, userObj) + ctx = context.WithValue(ctx, conf.MetaPassKey, "") + ctx = context.WithValue(ctx, conf.ClientIPKey, sc.RemoteAddr().String()) + ctx = context.WithValue(ctx, conf.ProxyHeaderKey, d.proxyHeader) return &sftp.DriverAdapter{FtpDriver: ftp.NewAferoAdapter(ctx)}, nil } diff --git a/server/webdav.go b/server/webdav.go index 8d62c84f..b2afe581 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -8,6 +8,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/middlewares" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -44,7 +45,7 @@ func WebDav(dav *gin.RouterGroup) { } func ServeWebDAV(c *gin.Context) { - handler.ServeHTTP(c.Writer, c.Request.WithContext(c)) + handler.ServeHTTP(c.Writer, c.Request) } func WebDAVAuth(c *gin.Context) { @@ -54,7 +55,7 @@ func WebDAVAuth(c *gin.Context) { count, cok := model.LoginCache.Get(ip) if cok && count >= model.DefaultMaxAuthRetries { if c.Request.Method == "OPTIONS" { - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) c.Next() return } @@ -78,13 +79,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } - c.Set("user", admin) + common.GinWithValue(c, conf.UserKey, admin) c.Next() return } } if c.Request.Method == "OPTIONS" { - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) c.Next() return } @@ -96,7 +97,7 @@ func WebDAVAuth(c *gin.Context) { user, err := op.GetUserByName(username) if err != nil || user.ValidateRawPassword(password) != nil { if c.Request.Method == "OPTIONS" { - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) c.Next() return } @@ -109,7 +110,7 @@ func WebDAVAuth(c *gin.Context) { model.LoginCache.Del(ip) if user.Disabled || !user.CanWebdavRead() { if c.Request.Method == "OPTIONS" { - c.Set("user", guest) + common.GinWithValue(c, conf.UserKey, guest) c.Next() return } @@ -142,6 +143,6 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } - c.Set("user", user) + common.GinWithValue(c, conf.UserKey, user) c.Next() } diff --git a/server/webdav/file.go b/server/webdav/file.go index 2e9f57e5..ab341152 100644 --- a/server/webdav/file.go +++ b/server/webdav/file.go @@ -33,7 +33,7 @@ func moveFiles(ctx context.Context, src, dst string, overwrite bool) (status int dstDir := path.Dir(dst) srcName := path.Base(src) dstName := path.Base(dst) - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) if srcDir != dstDir && !user.CanMove() { return http.StatusForbidden, nil } @@ -93,7 +93,7 @@ func walkFS(ctx context.Context, depth int, name string, info model.Obj, walkFn } meta, _ := op.GetNearestMeta(name) // Read directory names. - objs, err := fs.List(context.WithValue(ctx, "meta", meta), name, &fs.ListArgs{}) + objs, err := fs.List(context.WithValue(ctx, conf.MetaKey, meta), name, &fs.ListArgs{}) //f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0) //if err != nil { // return walkFn(name, info, err) diff --git a/server/webdav/prop.go b/server/webdav/prop.go index cda1410f..79d3e0ca 100644 --- a/server/webdav/prop.go +++ b/server/webdav/prop.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/server/common" ) @@ -391,7 +392,7 @@ func findLastModified(ctx context.Context, ls LockSystem, name string, fi model. return fi.ModTime().UTC().Format(http.TimeFormat), nil } func findCreationDate(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) { - userAgent := ctx.Value("userAgent").(string) + userAgent := ctx.Value(conf.UserAgentKey).(string) if strings.Contains(strings.ToLower(userAgent), "microsoft-webdav") { return fi.CreateTime().UTC().Format(http.TimeFormat), nil } diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 15e1ccee..a8e15328 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/internal/stream" @@ -195,7 +196,7 @@ func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request) (status return status, err } ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return 403, err @@ -223,7 +224,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta } // TODO: check locks for read-only access?? ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return http.StatusForbidden, err @@ -288,7 +289,7 @@ func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status i defer release() ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return 403, err @@ -333,7 +334,7 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, // TODO(rost): Support the If-Match, If-None-Match headers? See bradfitz' // comments in http.checkEtag. ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return http.StatusForbidden, err @@ -385,7 +386,7 @@ func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request) (status in defer release() ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return 403, err @@ -449,7 +450,7 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status } ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) src, err = user.JoinPath(src) if err != nil { return 403, err @@ -513,7 +514,7 @@ func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus } ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) token, ld, now, created := "", LockDetails{}, time.Now(), false if li == (lockInfo{}) { // An empty lockInfo means to refresh the lock. @@ -632,8 +633,8 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) (status } ctx := r.Context() userAgent := r.Header.Get("User-Agent") - ctx = context.WithValue(ctx, "userAgent", userAgent) - user := ctx.Value("user").(*model.User) + ctx = context.WithValue(ctx, conf.UserAgentKey, userAgent) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return 403, err @@ -712,7 +713,7 @@ func (h *Handler) handleProppatch(w http.ResponseWriter, r *http.Request) (statu defer release() ctx := r.Context() - user := ctx.Value("user").(*model.User) + user := ctx.Value(conf.UserKey).(*model.User) reqPath, err = user.JoinPath(reqPath) if err != nil { return 403, err