mirror of
https://github.com/OpenListTeam/OpenList.git
synced 2025-09-19 12:16:24 +08:00
refactor: pass api_url
through context (#457)
* refactor: pass `api_url` through context * 移除 LinkArgs.HttpReq * pref(alias): 减少不必要下载代理 * 修复bug * net: 支持1并发 分片下载
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
stdpath "path"
|
||||
@ -9,7 +10,7 @@ import (
|
||||
"github.com/OpenListTeam/OpenList/internal/conf"
|
||||
)
|
||||
|
||||
func GetApiUrl(r *http.Request) string {
|
||||
func GetApiUrlFormRequest(r *http.Request) string {
|
||||
api := conf.Conf.SiteURL
|
||||
if strings.HasPrefix(api, "http") {
|
||||
return strings.TrimSuffix(api, "/")
|
||||
@ -28,3 +29,11 @@ func GetApiUrl(r *http.Request) string {
|
||||
api = strings.TrimSuffix(api, "/")
|
||||
return api
|
||||
}
|
||||
|
||||
func GetApiUrl(ctx context.Context) string {
|
||||
val := ctx.Value(conf.ApiUrlKey)
|
||||
if api, ok := val.(string); ok {
|
||||
return api
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/OpenListTeam/OpenList/cmd/flags"
|
||||
@ -90,10 +88,3 @@ func Pluralize(count int, singular, plural string) string {
|
||||
}
|
||||
return plural
|
||||
}
|
||||
|
||||
func GetHttpReq(ctx context.Context) *http.Request {
|
||||
if c, ok := ctx.(*gin.Context); ok {
|
||||
return c.Request
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"github.com/OpenListTeam/OpenList/internal/stream"
|
||||
"github.com/OpenListTeam/OpenList/pkg/http_range"
|
||||
"github.com/OpenListTeam/OpenList/pkg/utils"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error {
|
||||
@ -42,7 +41,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
|
||||
RangeReadCloserIF: link.RangeReadCloser,
|
||||
Limiter: stream.ServerDownloadLimit,
|
||||
})
|
||||
} else if link.Concurrency != 0 || link.PartSize != 0 {
|
||||
} else if link.Concurrency > 0 || link.PartSize > 0 {
|
||||
attachHeader(w, file)
|
||||
size := file.GetSize()
|
||||
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
|
||||
@ -110,21 +109,16 @@ func GetEtag(file model.Obj) string {
|
||||
return fmt.Sprintf(`"%x-%x"`, file.ModTime().Unix(), file.GetSize())
|
||||
}
|
||||
|
||||
var NoProxyRange = &model.RangeReadCloser{}
|
||||
|
||||
func ProxyRange(link *model.Link, size int64) {
|
||||
func ProxyRange(ctx context.Context, link *model.Link, size int64) {
|
||||
if link.MFile != nil {
|
||||
return
|
||||
}
|
||||
if link.RangeReadCloser == nil {
|
||||
if link.RangeReadCloser == nil && !strings.HasPrefix(link.URL, GetApiUrl(ctx)+"/") {
|
||||
var rrc, err = stream.GetRangeReadCloserFromLink(size, link)
|
||||
if err != nil {
|
||||
log.Warnf("ProxyRange error: %s", err)
|
||||
return
|
||||
}
|
||||
link.RangeReadCloser = rrc
|
||||
} else if link.RangeReadCloser == NoProxyRange {
|
||||
link.RangeReadCloser = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,9 +101,8 @@ func FsArchiveMeta(c *gin.Context) {
|
||||
}
|
||||
archiveArgs := model.ArchiveArgs{
|
||||
LinkArgs: model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
},
|
||||
Password: req.ArchivePass,
|
||||
}
|
||||
@ -132,7 +131,7 @@ func FsArchiveMeta(c *gin.Context) {
|
||||
IsEncrypted: ret.IsEncrypted(),
|
||||
Content: toContentResp(ret.GetTree()),
|
||||
Sort: ret.Sort,
|
||||
RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c.Request), api, utils.EncodePath(reqPath, true)),
|
||||
RawURL: fmt.Sprintf("%s%s%s", common.GetApiUrl(c), api, utils.EncodePath(reqPath, true)),
|
||||
Sign: s,
|
||||
})
|
||||
}
|
||||
@ -181,9 +180,8 @@ func FsArchiveList(c *gin.Context) {
|
||||
ArchiveInnerArgs: model.ArchiveInnerArgs{
|
||||
ArchiveArgs: model.ArchiveArgs{
|
||||
LinkArgs: model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
},
|
||||
Password: req.ArchivePass,
|
||||
},
|
||||
@ -266,9 +264,8 @@ func FsArchiveDecompress(c *gin.Context) {
|
||||
ArchiveInnerArgs: model.ArchiveInnerArgs{
|
||||
ArchiveArgs: model.ArchiveArgs{
|
||||
LinkArgs: model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
},
|
||||
Password: req.ArchivePass,
|
||||
},
|
||||
@ -314,7 +311,6 @@ func ArchiveDown(c *gin.Context) {
|
||||
IP: c.ClientIP(),
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Redirect: true,
|
||||
},
|
||||
Password: password,
|
||||
@ -344,9 +340,8 @@ func ArchiveProxy(c *gin.Context) {
|
||||
link, file, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{
|
||||
ArchiveArgs: model.ArchiveArgs{
|
||||
LinkArgs: model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
},
|
||||
Password: password,
|
||||
},
|
||||
@ -370,9 +365,8 @@ func ArchiveInternalExtract(c *gin.Context) {
|
||||
rc, size, err := fs.ArchiveInternalExtract(c, archiveRawPath, model.ArchiveInnerArgs{
|
||||
ArchiveArgs: model.ArchiveArgs{
|
||||
LinkArgs: model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
},
|
||||
Password: password,
|
||||
},
|
||||
|
@ -38,7 +38,6 @@ func Down(c *gin.Context) {
|
||||
IP: c.ClientIP(),
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Redirect: true,
|
||||
})
|
||||
if err != nil {
|
||||
@ -71,9 +70,8 @@ func Proxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
link, file, err := fs.Link(c, rawPath, model.LinkArgs{
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
HttpReq: c.Request,
|
||||
Header: c.Request.Header,
|
||||
Type: c.Query("type"),
|
||||
})
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 500)
|
||||
@ -126,7 +124,7 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo
|
||||
}
|
||||
}
|
||||
if proxyRange {
|
||||
common.ProxyRange(link, file.GetSize())
|
||||
common.ProxyRange(c, link, file.GetSize())
|
||||
}
|
||||
Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer}
|
||||
|
||||
|
@ -97,7 +97,7 @@ func FsMove(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Create all tasks immediately without any synchronous validation
|
||||
// All validation will be done asynchronously in the background
|
||||
var addedTasks []task.TaskExtensionInfo
|
||||
@ -111,12 +111,12 @@ func FsMove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Return immediately with task information
|
||||
if len(addedTasks) > 0 {
|
||||
common.SuccessResp(c, gin.H{
|
||||
"message": fmt.Sprintf("Successfully created %d move task(s)", len(addedTasks)),
|
||||
"tasks": getTaskInfos(addedTasks),
|
||||
"tasks": getTaskInfos(addedTasks),
|
||||
})
|
||||
} else {
|
||||
common.SuccessResp(c, gin.H{
|
||||
@ -159,7 +159,7 @@ func FsCopy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Create all tasks immediately without any synchronous validation
|
||||
// All validation will be done asynchronously in the background
|
||||
var addedTasks []task.TaskExtensionInfo
|
||||
@ -173,12 +173,12 @@ func FsCopy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Return immediately with task information
|
||||
if len(addedTasks) > 0 {
|
||||
common.SuccessResp(c, gin.H{
|
||||
"message": fmt.Sprintf("Successfully created %d copy task(s)", len(addedTasks)),
|
||||
"tasks": getTaskInfos(addedTasks),
|
||||
"tasks": getTaskInfos(addedTasks),
|
||||
})
|
||||
} else {
|
||||
common.SuccessResp(c, gin.H{
|
||||
@ -379,13 +379,13 @@ func Link(c *gin.Context) {
|
||||
if storage.Config().OnlyLocal {
|
||||
common.SuccessResp(c, model.Link{
|
||||
URL: fmt.Sprintf("%s/p%s?d&sign=%s",
|
||||
common.GetApiUrl(c.Request),
|
||||
common.GetApiUrl(c),
|
||||
utils.EncodePath(rawPath, true),
|
||||
sign.Sign(rawPath)),
|
||||
})
|
||||
return
|
||||
}
|
||||
link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, HttpReq: c.Request})
|
||||
link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header})
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 500)
|
||||
return
|
||||
|
@ -296,7 +296,7 @@ func FsGet(c *gin.Context) {
|
||||
sign.Sign(reqPath))
|
||||
} else {
|
||||
rawURL = fmt.Sprintf("%s/p%s%s",
|
||||
common.GetApiUrl(c.Request),
|
||||
common.GetApiUrl(c),
|
||||
utils.EncodePath(reqPath, true),
|
||||
query)
|
||||
}
|
||||
@ -309,7 +309,6 @@ func FsGet(c *gin.Context) {
|
||||
link, _, err := fs.Link(c, reqPath, model.LinkArgs{
|
||||
IP: c.ClientIP(),
|
||||
Header: c.Request.Header,
|
||||
HttpReq: c.Request,
|
||||
Redirect: true,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -48,9 +48,9 @@ func verifyState(clientID, ip, state string) bool {
|
||||
|
||||
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
|
||||
if useCompatibility {
|
||||
return common.GetApiUrl(c.Request) + "/api/auth/" + method
|
||||
return common.GetApiUrl(c) + "/api/auth/" + method
|
||||
} else {
|
||||
return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
|
||||
return common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + method
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,7 +236,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||
}
|
||||
if method == "get_sso_id" {
|
||||
if useCompatibility {
|
||||
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
|
||||
c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID)
|
||||
return
|
||||
}
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
@ -263,7 +263,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||
common.ErrorResp(c, err, 400)
|
||||
}
|
||||
if useCompatibility {
|
||||
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
|
||||
c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token)
|
||||
return
|
||||
}
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
@ -364,9 +364,9 @@ func SSOLoginCallback(c *gin.Context) {
|
||||
} else {
|
||||
var redirect_uri string
|
||||
if usecompatibility {
|
||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument
|
||||
redirect_uri = common.GetApiUrl(c) + "/api/auth/" + argument
|
||||
} else {
|
||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
|
||||
redirect_uri = common.GetApiUrl(c) + "/api/auth/sso_callback" + "?method=" + argument
|
||||
}
|
||||
resp, err = ssoClient.R().SetHeader("Accept", "application/json").
|
||||
SetFormData(map[string]string{
|
||||
@ -401,7 +401,7 @@ func SSOLoginCallback(c *gin.Context) {
|
||||
}
|
||||
if argument == "get_sso_id" {
|
||||
if usecompatibility {
|
||||
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
|
||||
c.Redirect(302, common.GetApiUrl(c)+"/@manage?sso_id="+userID)
|
||||
return
|
||||
}
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
@ -429,7 +429,7 @@ func SSOLoginCallback(c *gin.Context) {
|
||||
common.ErrorResp(c, err, 400)
|
||||
}
|
||||
if usecompatibility {
|
||||
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
|
||||
c.Redirect(302, common.GetApiUrl(c)+"/@login?token="+token)
|
||||
return
|
||||
}
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
|
@ -24,7 +24,7 @@ func BeginAuthnLogin(c *gin.Context) {
|
||||
common.ErrorStrResp(c, "WebAuthn is not enabled", 403)
|
||||
return
|
||||
}
|
||||
authnInstance, err := authn.NewAuthnInstance(c.Request)
|
||||
authnInstance, err := authn.NewAuthnInstance(c)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
@ -65,7 +65,7 @@ func FinishAuthnLogin(c *gin.Context) {
|
||||
common.ErrorStrResp(c, "WebAuthn is not enabled", 403)
|
||||
return
|
||||
}
|
||||
authnInstance, err := authn.NewAuthnInstance(c.Request)
|
||||
authnInstance, err := authn.NewAuthnInstance(c)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
@ -127,7 +127,7 @@ func BeginAuthnRegistration(c *gin.Context) {
|
||||
}
|
||||
user := c.MustGet("user").(*model.User)
|
||||
|
||||
authnInstance, err := authn.NewAuthnInstance(c.Request)
|
||||
authnInstance, err := authn.NewAuthnInstance(c)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
}
|
||||
@ -158,7 +158,7 @@ func FinishAuthnRegistration(c *gin.Context) {
|
||||
user := c.MustGet("user").(*model.User)
|
||||
sessionDataString := c.GetHeader("Session")
|
||||
|
||||
authnInstance, err := authn.NewAuthnInstance(c.Request)
|
||||
authnInstance, err := authn.NewAuthnInstance(c)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
|
@ -10,9 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func StoragesLoaded(c *gin.Context) {
|
||||
if conf.StoragesLoaded {
|
||||
c.Next()
|
||||
} else {
|
||||
if !conf.StoragesLoaded {
|
||||
if utils.SliceContains([]string{"", "/", "/favicon.ico"}, c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
@ -26,5 +24,8 @@ func StoragesLoaded(c *gin.Context) {
|
||||
}
|
||||
common.ErrorStrResp(c, "Loading storage, please wait", 500)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(conf.ApiUrlKey, common.GetApiUrlFormRequest(c.Request))
|
||||
c.Next()
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
"path"
|
||||
@ -11,7 +10,6 @@ import (
|
||||
"github.com/OpenListTeam/OpenList/server/middlewares"
|
||||
|
||||
"github.com/OpenListTeam/OpenList/internal/conf"
|
||||
"github.com/OpenListTeam/OpenList/internal/model"
|
||||
"github.com/OpenListTeam/OpenList/internal/op"
|
||||
"github.com/OpenListTeam/OpenList/internal/setting"
|
||||
"github.com/OpenListTeam/OpenList/server/webdav"
|
||||
@ -45,9 +43,7 @@ func WebDav(dav *gin.RouterGroup) {
|
||||
}
|
||||
|
||||
func ServeWebDAV(c *gin.Context) {
|
||||
user := c.MustGet("user").(*model.User)
|
||||
ctx := context.WithValue(c.Request.Context(), "user", user)
|
||||
handler.ServeHTTP(c.Writer, c.Request.WithContext(ctx))
|
||||
handler.ServeHTTP(c.Writer, c.Request.WithContext(c))
|
||||
}
|
||||
|
||||
func WebDAVAuth(c *gin.Context) {
|
||||
|
@ -241,12 +241,12 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
|
||||
storage, _ := fs.GetStorage(reqPath, &fs.GetStoragesArgs{})
|
||||
downProxyUrl := storage.GetStorage().DownProxyUrl
|
||||
if storage.GetStorage().WebdavNative() || (storage.GetStorage().WebdavProxy() && downProxyUrl == "") {
|
||||
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header, HttpReq: r})
|
||||
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header})
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
if storage.GetStorage().ProxyRange {
|
||||
common.ProxyRange(link, fi.GetSize())
|
||||
common.ProxyRange(ctx, link, fi.GetSize())
|
||||
}
|
||||
err = common.Proxy(w, r, link, fi)
|
||||
if err != nil {
|
||||
@ -260,7 +260,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
|
||||
w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
} else {
|
||||
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r, Redirect: true})
|
||||
link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, Redirect: true})
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user