perf(link): optimize concurrent response (#641)

* fix(crypt): bug caused by link cache

* perf(crypt,mega,halalcloud,quark,uc): optimize concurrent response link

* chore: remove unused code

* ftp

* fix bugs; release resources properly

* add SyncClosers

* local,sftp,smb

* refactor, optimize, enhance

* Update internal/stream/util.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com>

* chore

* chore

* optimize; fix bugs

* .

---------

Signed-off-by: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
j2rong4cn
2025-07-12 17:57:54 +08:00
committed by GitHub
parent e5fbe72581
commit cc01b410a4
83 changed files with 796 additions and 751 deletions

View File

@@ -18,7 +18,6 @@ var config = driver.Config{
     Name:        "115 Cloud",
     DefaultRoot: "0",
     // OnlyProxy:   true,
-    // OnlyLocal:   true,
     // NoOverwriteUpload: true,
 }

View File

@@ -18,16 +18,7 @@ type Addition struct {
 var config = driver.Config{
     Name:        "115 Open",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "0",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -19,11 +19,6 @@ type Addition struct {
 var config = driver.Config{
     Name:        "115 Share",
     DefaultRoot: "0",
-    // OnlyProxy:   true,
-    // OnlyLocal:   true,
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: true,
     NoUpload:    true,
 }

View File

@@ -17,15 +17,8 @@ type Addition struct {
 var config = driver.Config{
     Name:      "123PanShare",
     LocalSort: true,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
     NoUpload:  true,
-    NeedMs:    false,
     DefaultRoot: "0",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -3,6 +3,7 @@ package alias
 import (
     "context"
     "errors"
+    "fmt"
     "io"
     stdpath "path"
     "strings"
@@ -11,8 +12,10 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/errs"
     "github.com/OpenListTeam/OpenList/v4/internal/fs"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
+    "github.com/OpenListTeam/OpenList/v4/internal/sign"
     "github.com/OpenListTeam/OpenList/v4/internal/stream"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"
+    "github.com/OpenListTeam/OpenList/v4/server/common"
 )

 type Alias struct {
@@ -111,21 +114,43 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
         return nil, errs.ObjectNotFound
     }
     for _, dst := range dsts {
-        link, err := d.link(ctx, dst, sub, args)
-        if err == nil {
-            link.Expiration = nil // 去除非必要缓存d.link里op.Link有缓存
-            if !args.Redirect && len(link.URL) > 0 {
-                // 正常情况下 多并发 仅支持返回URL的驱动
-                // alias套娃alias 可以让crypt、mega等驱动(不返回URL的) 支持并发
-                if d.DownloadConcurrency > 0 {
-                    link.Concurrency = d.DownloadConcurrency
-                }
-                if d.DownloadPartSize > 0 {
-                    link.PartSize = d.DownloadPartSize * utils.KB
-                }
-            }
-            return link, nil
-        }
+        reqPath := stdpath.Join(dst, sub)
+        link, file, err := d.link(ctx, reqPath, args)
+        if err != nil {
+            continue
+        }
+        var resultLink *model.Link
+        if link != nil {
+            resultLink = &model.Link{
+                URL:         link.URL,
+                Header:      link.Header,
+                RangeReader: link.RangeReader,
+                SyncClosers: utils.NewSyncClosers(link),
+            }
+            if link.MFile != nil {
+                resultLink.RangeReader = &model.FileRangeReader{
+                    RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile),
+                }
+            }
+        } else {
+            resultLink = &model.Link{
+                URL: fmt.Sprintf("%s/p%s?sign=%s",
+                    common.GetApiUrl(ctx),
+                    utils.EncodePath(reqPath, true),
+                    sign.Sign(reqPath)),
+            }
+        }
+        if !args.Redirect {
+            if d.DownloadConcurrency > 0 {
+                resultLink.Concurrency = d.DownloadConcurrency
+            }
+            if d.DownloadPartSize > 0 {
+                resultLink.PartSize = d.DownloadPartSize * utils.KB
+            }
+        }
+        return resultLink, nil
     }
     return nil, errs.ObjectNotFound
 }
@@ -251,9 +276,13 @@ func (d *Alias) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer,
     reqPath, err := d.getReqPath(ctx, dstDir, true)
     if err == nil {
         if len(reqPath) == 1 {
-            return fs.PutDirectly(ctx, *reqPath[0], s)
+            return fs.PutDirectly(ctx, *reqPath[0], &stream.FileStream{
+                Obj:          s,
+                Mimetype:     s.GetMimetype(),
+                WebPutAsTask: s.NeedStore(),
+                Reader:       s,
+            })
         } else {
-            defer s.Close()
             file, err := s.CacheFullInTempFile()
             if err != nil {
                 return err
@@ -338,14 +367,6 @@ func (d *Alias) Extract(ctx context.Context, obj model.Obj, args model.ArchiveIn
     for _, dst := range dsts {
         link, err := d.extract(ctx, dst, sub, args)
         if err == nil {
-            if !args.Redirect && len(link.URL) > 0 {
-                if d.DownloadConcurrency > 0 {
-                    link.Concurrency = d.DownloadConcurrency
-                }
-                if d.DownloadPartSize > 0 {
-                    link.PartSize = d.DownloadPartSize * utils.KB
-                }
-            }
             return link, nil
         }
     }

View File

@@ -96,37 +96,23 @@ func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([
     })
 }

-func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) (*model.Link, error) {
-    reqPath := stdpath.Join(dst, sub)
-    // 参考 crypt 驱动
+func (d *Alias) link(ctx context.Context, reqPath string, args model.LinkArgs) (*model.Link, model.Obj, error) {
     storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath)
     if err != nil {
-        return nil, err
+        return nil, nil, err
     }
-    useRawLink := len(common.GetApiUrl(ctx)) == 0 // ftp、s3
-    if !useRawLink {
-        _, ok := storage.(*Alias)
-        useRawLink = !ok && !args.Redirect
-    }
-    if useRawLink {
-        link, _, err := op.Link(ctx, storage, reqActualPath, args)
-        return link, err
-    }
-    _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true})
+    // proxy || ftp,s3
+    if !args.Redirect || len(common.GetApiUrl(ctx)) == 0 {
+        return op.Link(ctx, storage, reqActualPath, args)
+    }
+    obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true})
     if err != nil {
-        return nil, err
+        return nil, nil, err
     }
-    if common.ShouldProxy(storage, stdpath.Base(sub)) {
-        link := &model.Link{
-            URL: fmt.Sprintf("%s/p%s?sign=%s",
-                common.GetApiUrl(ctx),
-                utils.EncodePath(reqPath, true),
-                sign.Sign(reqPath)),
-        }
-        return link, nil
+    if common.ShouldProxy(storage, stdpath.Base(reqPath)) {
+        return nil, obj, nil
     }
-    link, _, err := op.Link(ctx, storage, reqActualPath, args)
-    return link, err
+    return op.Link(ctx, storage, reqActualPath, args)
 }

 func (d *Alias) getReqPath(ctx context.Context, obj model.Obj, isParent bool) ([]*string, error) {
func (d *Alias) getReqPath(ctx context.Context, obj model.Obj, isParent bool) ([]*string, error) { func (d *Alias) getReqPath(ctx context.Context, obj model.Obj, isParent bool) ([]*string, error) {

View File

@@ -165,7 +165,7 @@ func (d *AliDrive) Remove(ctx context.Context, obj model.Obj) error {
 }

 func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.FileStreamer, up driver.UpdateProgress) error {
-    file := stream.FileStream{
+    file := &stream.FileStream{
         Obj:      streamer,
         Reader:   streamer,
         Mimetype: streamer.GetMimetype(),
@@ -209,7 +209,7 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil
             io.Closer
         }{
             Reader: io.MultiReader(buf, file),
-            Closer: &file,
+            Closer: file,
         }
     }
 } else {

View File

@@ -25,12 +25,6 @@ type Addition struct {
 var config = driver.Config{
     Name:        "AliyundriveOpen",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "root",
     NoOverwriteUpload: true,
 }

View File

@@ -32,7 +32,6 @@ func init() {
         config: driver.Config{
             Name:        "ChaoXingGroupDrive",
             OnlyProxy:   true,
-            OnlyLocal:   false,
             DefaultRoot: "-1",
             NoOverwriteUpload: true,
         },

View File

@@ -26,15 +26,8 @@ type Addition struct {
 var config = driver.Config{
     Name:        "Cloudreve V4",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "cloudreve://my",
     CheckStatus: true,
-    Alert:       "",
     NoOverwriteUpload: true,
 }

View File

@@ -1,12 +1,14 @@
 package crypt

 import (
+    "bytes"
     "context"
     "fmt"
     "io"
     stdpath "path"
     "regexp"
     "strings"
+    "sync"

     "github.com/OpenListTeam/OpenList/v4/internal/driver"
     "github.com/OpenListTeam/OpenList/v4/internal/errs"
@@ -241,6 +243,9 @@ func (d *Crypt) Get(ctx context.Context, path string) (model.Obj, error) {
     //return nil, errs.ObjectNotFound
 }

+// https://github.com/rclone/rclone/blob/v1.67.0/backend/crypt/cipher.go#L37
+const fileHeaderSize = 32
+
 func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
     dstDirActualPath, err := d.getActualPathForRemote(file.GetPath(), false)
     if err != nil {
@@ -251,58 +256,64 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
         return nil, err
     }

-    if remoteLink.RangeReadCloser == nil && remoteLink.MFile == nil && len(remoteLink.URL) == 0 {
+    rrf, err := stream.GetRangeReaderFromLink(remoteFile.GetSize(), remoteLink)
+    if err != nil {
+        _ = remoteLink.Close()
         return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion")
     }
-    resultRangeReadCloser := &model.RangeReadCloser{}
-    resultRangeReadCloser.TryAdd(remoteLink.MFile)
-    if remoteLink.RangeReadCloser != nil {
-        resultRangeReadCloser.AddClosers(remoteLink.RangeReadCloser.GetClosers())
-    }
-    remoteFileSize := remoteFile.GetSize()
-    rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) {
-        length := underlyingLength
-        if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize {
-            length = -1
-        }
-        if remoteLink.MFile != nil {
-            _, err := remoteLink.MFile.Seek(underlyingOffset, io.SeekStart)
-            if err != nil {
-                return nil, err
-            }
-            //keep reuse same MFile and close at last.
-            return io.NopCloser(remoteLink.MFile), nil
-        }
-        rrc := remoteLink.RangeReadCloser
-        if rrc == nil && len(remoteLink.URL) > 0 {
-            var err error
-            rrc, err = stream.GetRangeReadCloserFromLink(remoteFileSize, remoteLink)
-            if err != nil {
-                return nil, err
-            }
-            resultRangeReadCloser.AddClosers(rrc.GetClosers())
-            remoteLink.RangeReadCloser = rrc
-        }
-        if rrc != nil {
-            remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length})
-            if err != nil {
-                return nil, err
-            }
-            return remoteReader, nil
-        }
-        return nil, errs.NotSupport
-    }
-    resultRangeReadCloser.RangeReader = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
-        readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length)
-        if err != nil {
-            return nil, err
-        }
-        return readSeeker, nil
-    }
-    return &model.Link{
-        RangeReadCloser: resultRangeReadCloser,
-    }, nil
-}
+
+    mu := &sync.Mutex{}
+    var fileHeader []byte
+    rangeReaderFunc := func(ctx context.Context, offset, limit int64) (io.ReadCloser, error) {
+        length := limit
+        if offset == 0 && limit > 0 {
+            mu.Lock()
+            if limit <= fileHeaderSize {
+                defer mu.Unlock()
+                if fileHeader != nil {
+                    return io.NopCloser(bytes.NewReader(fileHeader[:limit])), nil
+                }
+                length = fileHeaderSize
+            } else if fileHeader == nil {
+                defer mu.Unlock()
+            } else {
+                mu.Unlock()
+            }
+        }
+
+        remoteReader, err := rrf.RangeRead(ctx, http_range.Range{Start: offset, Length: length})
+        if err != nil {
+            return nil, err
+        }
+
+        if offset == 0 && limit > 0 {
+            fileHeader = make([]byte, fileHeaderSize)
+            n, _ := io.ReadFull(remoteReader, fileHeader)
+            if n != fileHeaderSize {
+                fileHeader = nil
+                return nil, fmt.Errorf("can't read data, expected=%d, got=%d", fileHeaderSize, n)
+            }
+            if limit <= fileHeaderSize {
+                remoteReader.Close()
+                return io.NopCloser(bytes.NewReader(fileHeader[:limit])), nil
+            } else {
+                remoteReader = utils.ReadCloser{
+                    Reader: io.MultiReader(bytes.NewReader(fileHeader), remoteReader),
+                    Closer: remoteReader,
+                }
+            }
+        }
+        return remoteReader, nil
+    }
+    return &model.Link{
+        RangeReader: stream.RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+            readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length)
+            if err != nil {
+                return nil, err
+            }
+            return readSeeker, nil
+        }),
+        SyncClosers: utils.NewSyncClosers(remoteLink),
+    }, nil
+}

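The rewritten crypt Link above caches the 32-byte rclone file header behind a mutex, so concurrent decrypting range reads that start at offset 0 reuse it instead of re-fetching it from the remote storage. A minimal, self-contained sketch of that caching idea (the fetch helper is hypothetical and stands in for the remote range request; this is not the driver's actual code):

package main

import (
    "bytes"
    "fmt"
    "io"
    "sync"
)

const fileHeaderSize = 32

type headerCachingReader struct {
    mu     sync.Mutex
    header []byte // filled on the first read that starts at offset 0
    fetch  func(offset, length int64) (io.ReadCloser, error) // hypothetical remote fetch
}

func (h *headerCachingReader) RangeRead(offset, length int64) (io.ReadCloser, error) {
    if offset == 0 && length > 0 && length <= fileHeaderSize {
        h.mu.Lock()
        defer h.mu.Unlock()
        if h.header == nil {
            rc, err := h.fetch(0, fileHeaderSize)
            if err != nil {
                return nil, err
            }
            defer rc.Close()
            buf := make([]byte, fileHeaderSize)
            if _, err := io.ReadFull(rc, buf); err != nil {
                return nil, fmt.Errorf("short header read: %w", err)
            }
            h.header = buf
        }
        // serve small reads at offset 0 straight from the cached header
        return io.NopCloser(bytes.NewReader(h.header[:length])), nil
    }
    return h.fetch(offset, length)
}

func main() {
    data := []byte("0123456789abcdefghijklmnopqrstuvwxyz0123456789")
    r := &headerCachingReader{fetch: func(off, l int64) (io.ReadCloser, error) {
        return io.NopCloser(bytes.NewReader(data[off : off+l])), nil
    }}
    rc, _ := r.RangeRead(0, 8) // later calls like this reuse the cached header
    b, _ := io.ReadAll(rc)
    fmt.Println(string(b))
}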
View File

@@ -28,15 +28,10 @@ type Addition struct {
 var config = driver.Config{
     Name:        "Crypt",
     LocalSort:   true,
-    OnlyLocal:   true,
     OnlyProxy:   true,
     NoCache:     true,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "/",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -18,15 +18,7 @@ type Addition struct {
 var config = driver.Config{
     Name:      "Doubao",
     LocalSort: true,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
-    NoUpload:  false,
-    NeedMs:    false,
     DefaultRoot: "0",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -14,15 +14,8 @@ type Addition struct {
 var config = driver.Config{
     Name:      "DoubaoShare",
     LocalSort: true,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
     NoUpload:  true,
-    NeedMs:    false,
     DefaultRoot: "/",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -18,13 +18,6 @@ type Addition struct {
 var config = driver.Config{
     Name:        "Dropbox",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
-    DefaultRoot: "",
     NoOverwriteUpload: true,
 }

View File

@@ -17,16 +17,8 @@ type Addition struct {
 var config = driver.Config{
     Name:      "FebBox",
-    LocalSort: false,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
     NoUpload:  true,
-    NeedMs:    false,
     DefaultRoot: "0",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -8,6 +8,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/errs"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/stream"
+    "github.com/OpenListTeam/OpenList/v4/pkg/utils"

     "github.com/jlaffaye/ftp"
 )
@@ -26,7 +27,7 @@ func (d *FTP) GetAddition() driver.Additional {
 }

 func (d *FTP) Init(ctx context.Context) error {
-    return d.login()
+    return d._login()
 }

 func (d *FTP) Drop(ctx context.Context) error {
@@ -65,15 +66,22 @@ func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
         return nil, err
     }

-    r := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize())
-    link := &model.Link{
+    remoteFile := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize())
+    if remoteFile != nil && !d.Config().OnlyLinkMFile {
+        return &model.Link{
+            RangeReader: &model.FileRangeReader{
+                RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
+            },
+            SyncClosers: utils.NewSyncClosers(remoteFile),
+        }, nil
+    }
+    return &model.Link{
         MFile: &stream.RateLimitFile{
-            File:    r,
+            File:    remoteFile,
             Limiter: stream.ServerDownloadLimit,
             Ctx:     ctx,
         },
-    }
-    return link, nil
+    }, nil
 }

 func (d *FTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {

View File

@@ -33,8 +33,9 @@ type Addition struct {
 var config = driver.Config{
     Name:          "FTP",
     LocalSort:     true,
-    OnlyLocal:     true,
+    OnlyLinkMFile: true,
     DefaultRoot:   "/",
+    NoLinkURL:     true,
 }

 func init() {

View File

@@ -1,18 +1,28 @@
 package ftp

 import (
+    "fmt"
     "io"
     "os"
     "sync"
     "sync/atomic"
     "time"

+    "github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
     "github.com/jlaffaye/ftp"
 )

 // do others that not defined in Driver interface

 func (d *FTP) login() error {
+    err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("FTP.login:%p", d), func() (error, error) {
+        return d._login(), nil
+    })
+    return err
+}
+
+func (d *FTP) _login() error {
     if d.conn != nil {
         _, err := d.conn.CurrentDir()
         if err == nil {

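The login/_login split above keys a singleflight group on the driver instance so that concurrent requests hitting a dropped connection trigger only one re-login instead of a thundering herd (the same pattern is applied to SFTP and SMB below). An illustrative, self-contained sketch of the de-duplication idea using the public golang.org/x/sync/singleflight package (the project uses its own pkg/singleflight.ErrorGroup, as shown in the hunk above):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"

    "golang.org/x/sync/singleflight"
)

var (
    sf    singleflight.Group
    dials atomic.Int64
)

// _login stands in for the real dial/authenticate step.
func _login() error {
    dials.Add(1)
    return nil
}

// login de-duplicates concurrent re-login attempts: callers that arrive while
// a login is in flight share its result instead of dialing again.
func login(key string) error {
    _, err, _ := sf.Do(key, func() (any, error) {
        return nil, _login()
    })
    return err
}

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 8; i++ {
        wg.Add(1)
        go func() { defer wg.Done(); _ = login("FTP.login:demo") }()
    }
    wg.Wait()
    fmt.Println("dial attempts:", dials.Load()) // typically 1 thanks to de-duplication
}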
View File

@@ -16,16 +16,7 @@ type Addition struct {
 var config = driver.Config{
     Name:      "GitHub Releases",
-    LocalSort: false,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
-    NoUpload:  false,
-    NeedMs:    false,
-    DefaultRoot: "",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
+    NoUpload:  true,
 }

 func init() {

View File

@@ -14,6 +14,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/driver"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/op"
+    "github.com/OpenListTeam/OpenList/v4/internal/stream"
     "github.com/OpenListTeam/OpenList/v4/pkg/http_range"
     "github.com/aws/aws-sdk-go/aws"
     "github.com/aws/aws-sdk-go/aws/credentials"
@@ -253,8 +254,8 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
     chunks := getChunkSizes(result.Sizes)
     resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
         length := httpRange.Length
-        if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
-            length = -1
+        if httpRange.Length < 0 || httpRange.Start+httpRange.Length >= size {
+            length = size - httpRange.Start
         }
         oo := &openObject{
             ctx: ctx,
@@ -276,9 +277,8 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
         duration = time.Until(time.Now().Add(time.Hour))
     }

-    resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader}
     return &model.Link{
-        RangeReadCloser: resultRangeReadCloser,
+        RangeReader: stream.RateLimitRangeReaderFunc(resultRangeReader),
         Expiration: &duration,
     }, nil
 }

View File

@@ -19,16 +19,9 @@ type Addition struct {
 var config = driver.Config{
     Name:        "HalalCloud",
-    LocalSort:   false,
-    OnlyLocal:   true,
     OnlyProxy:   true,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "/",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -30,16 +30,7 @@ func init() {
         return &ILanZou{
             config: driver.Config{
                 Name:        "ILanZou",
-                LocalSort:   false,
-                OnlyLocal:   false,
-                OnlyProxy:   false,
-                NoCache:     false,
-                NoUpload:    false,
-                NeedMs:      false,
                 DefaultRoot: "0",
-                CheckStatus: false,
-                Alert:       "",
-                NoOverwriteUpload: false,
             },
             conf: Conf{
                 base: "https://api.ilanzou.com",
@@ -56,16 +47,7 @@ func init() {
         return &ILanZou{
             config: driver.Config{
                 Name:        "FeijiPan",
-                LocalSort:   false,
-                OnlyLocal:   false,
-                OnlyProxy:   false,
-                NoCache:     false,
-                NoUpload:    false,
-                NeedMs:      false,
                 DefaultRoot: "0",
-                CheckStatus: false,
-                Alert:       "",
-                NoOverwriteUpload: false,
             },
             conf: Conf{
                 base: "https://api.feijipan.com",

View File

@@ -17,7 +17,6 @@ var config = driver.Config{
     Name:        "IPFS API",
     DefaultRoot: "/",
     LocalSort:   true,
-    OnlyProxy:   false,
 }

 func init() {

View File

@@ -15,7 +15,6 @@ type Addition struct {
 var config = driver.Config{
     Name: "KodBox",
-    DefaultRoot: "",
 }

 func init() {

View File

@@ -15,15 +15,7 @@ type Addition struct {
 var config = driver.Config{
     Name:      "LenovoNasShare",
     LocalSort: true,
-    OnlyLocal: false,
-    OnlyProxy: false,
-    NoCache:   false,
     NoUpload:  true,
-    NeedMs:    false,
-    DefaultRoot: "",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -19,6 +19,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/errs"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/sign"
+    "github.com/OpenListTeam/OpenList/v4/internal/stream"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"
     "github.com/OpenListTeam/OpenList/v4/server/common"
     "github.com/OpenListTeam/times"
@@ -220,7 +221,7 @@ func (d *Local) Get(ctx context.Context, path string) (model.Obj, error) {

 func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
     fullPath := file.GetPath()
-    var link model.Link
+    link := &model.Link{}
     if args.Type == "thumb" && utils.Ext(file.GetName()) != "svg" {
         var buf *bytes.Buffer
         var thumbPath *string
@@ -252,7 +253,14 @@ func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
         }
         link.MFile = open
     }
-    return &link, nil
+    if link.MFile != nil && !d.Config().OnlyLinkMFile {
+        link.AddIfCloser(link.MFile)
+        link.RangeReader = &model.FileRangeReader{
+            RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile),
+        }
+        link.MFile = nil
+    }
+    return link, nil
 }

 func (d *Local) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {

View File

@@ -18,10 +18,11 @@ type Addition struct {
 var config = driver.Config{
     Name:        "Local",
-    OnlyLocal:   true,
+    OnlyLinkMFile: false,
     LocalSort:   true,
     NoCache:     true,
     DefaultRoot: "/",
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -14,6 +14,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/driver"
     "github.com/OpenListTeam/OpenList/v4/internal/errs"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
+    "github.com/OpenListTeam/OpenList/v4/internal/stream"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"
     log "github.com/sirupsen/logrus"
     "github.com/t3rm1n4l/go-mega"
@@ -95,8 +96,8 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
         size := file.GetSize()
         resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
             length := httpRange.Length
-            if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
-                length = -1
+            if httpRange.Length < 0 || httpRange.Start+httpRange.Length >= size {
+                length = size - httpRange.Start
             }
             var down *mega.Download
             err := utils.Retry(3, time.Second, func() (err error) {
@@ -114,11 +115,9 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
             return readers.NewLimitedReadCloser(oo, length), nil
         }

-        resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader}
-        resultLink := &model.Link{
-            RangeReadCloser: resultRangeReadCloser,
-        }
-        return resultLink, nil
+        return &model.Link{
+            RangeReader: stream.RateLimitRangeReaderFunc(resultRangeReader),
+        }, nil
     }
     return nil, fmt.Errorf("unable to convert dir to mega n")
 }

View File

@@ -18,7 +18,8 @@ type Addition struct {
 var config = driver.Config{
     Name:      "Mega_nz",
     LocalSort: true,
-    OnlyLocal: true,
+    OnlyProxy: true,
+    NoLinkURL: true,
 }

 func init() {

View File

@@ -16,16 +16,7 @@ type Addition struct {
 var config = driver.Config{
     Name:        "Misskey",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "/",
-    CheckStatus: false,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -28,7 +28,6 @@ func (a *Addition) GetRootId() string {
 var config = driver.Config{
     Name: "MoPan",
-    // DefaultRoot: "root, / or other",
     CheckStatus: true,
     Alert:       "warning|This network disk may store your password in clear text. Please set your password carefully",
 }

View File

@@ -73,7 +73,7 @@ func (d *NeteaseMusic) List(ctx context.Context, dir model.Obj, args model.ListA

 func (d *NeteaseMusic) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
     if lrc, ok := file.(*LyricObj); ok {
-        if args.Type == "parsed" {
+        if args.Type == "parsed" && !args.Redirect {
             return lrc.getLyricLink(), nil
         } else {
             return lrc.getProxyLink(ctx), nil

View File

@@ -10,6 +10,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/driver"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/sign"
+    "github.com/OpenListTeam/OpenList/v4/internal/stream"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils/random"
     "github.com/OpenListTeam/OpenList/v4/server/common"
@@ -54,7 +55,9 @@ func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link {

 func (lrc *LyricObj) getLyricLink() *model.Link {
     return &model.Link{
-        MFile: strings.NewReader(lrc.lyric),
+        RangeReader: &model.FileRangeReader{
+            RangeReaderIF: stream.GetRangeReaderFromMFile(int64(len(lrc.lyric)), strings.NewReader(lrc.lyric)),
+        },
     }
 }

View File

@@ -22,7 +22,6 @@ var config = driver.Config{
     OnlyProxy:   true,
     NoUpload:    true,
     DefaultRoot: "/",
-    CheckStatus: false,
 }

 func init() {

View File

@@ -19,7 +19,6 @@ type Addition struct {
 var config = driver.Config{
     Name:      "PikPak",
     LocalSort: true,
-    DefaultRoot: "",
 }

 func init() {

View File

@@ -18,7 +18,6 @@ var config = driver.Config{
     Name:      "PikPakShare",
     LocalSort: true,
     NoUpload:  true,
-    DefaultRoot: "",
 }

 func init() {

View File

@@ -28,7 +28,7 @@ func init() {
         return &QuarkOpen{
             config: driver.Config{
                 Name:        "QuarkOpen",
-                OnlyLocal:   true,
+                OnlyProxy:   true,
                 DefaultRoot: "0",
                 NoOverwriteUpload: true,
             },

View File

@@ -27,7 +27,6 @@ func init() {
         return &QuarkOrUC{
             config: driver.Config{
                 Name:        "Quark",
-                OnlyLocal:   false,
                 DefaultRoot: "0",
                 NoOverwriteUpload: true,
             },
@@ -43,7 +42,7 @@ func init() {
         return &QuarkOrUC{
             config: driver.Config{
                 Name:        "UC",
-                OnlyLocal:   true,
+                OnlyProxy:   true,
                 DefaultRoot: "0",
                 NoOverwriteUpload: true,
             },

View File

@@ -30,7 +30,6 @@ func init() {
         return &QuarkUCTV{
             config: driver.Config{
                 Name:        "QuarkTV",
-                OnlyLocal:   false,
                 DefaultRoot: "0",
                 NoOverwriteUpload: true,
                 NoUpload:    true,
@@ -49,7 +48,6 @@ func init() {
         return &QuarkUCTV{
             config: driver.Config{
                 Name:        "UCTV",
-                OnlyLocal:   false,
                 DefaultRoot: "0",
                 NoOverwriteUpload: true,
                 NoUpload:    true,

View File

@@ -4,7 +4,6 @@ import (
     "bytes"
     "context"
     "fmt"
-    "io"
     "net/url"
     stdpath "path"
     "strings"
@@ -158,7 +157,7 @@ func (d *S3) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) e
             Name:     getPlaceholderName(d.Placeholder),
             Modified: time.Now(),
         },
-        Reader:   io.NopCloser(bytes.NewReader([]byte{})),
+        Reader:   bytes.NewReader([]byte{}),
         Mimetype: "application/octet-stream",
     }, func(float64) {})
 }

View File

@@ -30,7 +30,7 @@ func (d *SFTP) GetAddition() driver.Additional {
 }

 func (d *SFTP) Init(ctx context.Context) error {
-    return d.initClient()
+    return d._initClient()
 }

 func (d *SFTP) Drop(ctx context.Context) error {
@@ -63,6 +63,14 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
     if err != nil {
         return nil, err
     }
+    if remoteFile != nil && !d.Config().OnlyLinkMFile {
+        return &model.Link{
+            RangeReader: &model.FileRangeReader{
+                RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
+            },
+            SyncClosers: utils.NewSyncClosers(remoteFile),
+        }, nil
+    }
     return &model.Link{
         MFile: &stream.RateLimitFile{
             File: remoteFile,

View File

@@ -18,9 +18,10 @@ type Addition struct {
 var config = driver.Config{
     Name:        "SFTP",
     LocalSort:   true,
-    OnlyLocal:   true,
+    OnlyLinkMFile: false,
     DefaultRoot: "/",
     CheckStatus: true,
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -1,8 +1,10 @@
 package sftp

 import (
+    "fmt"
     "path"

+    "github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
     "github.com/pkg/sftp"
     log "github.com/sirupsen/logrus"
     "golang.org/x/crypto/ssh"
@@ -11,6 +13,12 @@ import (
 // do others that not defined in Driver interface

 func (d *SFTP) initClient() error {
+    err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("SFTP.initClient:%p", d), func() (error, error) {
+        return d._initClient(), nil
+    })
+    return err
+}
+
+func (d *SFTP) _initClient() error {
     var auth ssh.AuthMethod
     if len(d.PrivateKey) > 0 {
         var err error
@@ -52,7 +60,9 @@ func (d *SFTP) clientReconnectOnConnectionError() error {
         return nil
     }
     log.Debugf("[sftp] discarding closed sftp connection: %v", err)
+    if d.client != nil {
         _ = d.client.Close()
+    }
     err = d.initClient()
     return err
 }

View File

@@ -30,10 +30,10 @@ func (d *SMB) GetAddition() driver.Additional {
 }

 func (d *SMB) Init(ctx context.Context) error {
-    if strings.Index(d.Addition.Address, ":") < 0 {
+    if !strings.Contains(d.Addition.Address, ":") {
         d.Addition.Address = d.Addition.Address + ":445"
     }
-    return d.initFS()
+    return d._initFS()
 }

 func (d *SMB) Drop(ctx context.Context) error {
@@ -81,6 +81,13 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
         return nil, err
     }
     d.updateLastConnTime()
+    if remoteFile != nil && !d.Config().OnlyLinkMFile {
+        return &model.Link{
+            RangeReader: &model.FileRangeReader{
+                RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
+            },
+        }, nil
+    }
     return &model.Link{
         MFile: &stream.RateLimitFile{
             File: remoteFile,

View File

@@ -16,9 +16,10 @@ type Addition struct {
 var config = driver.Config{
     Name:        "SMB",
     LocalSort:   true,
-    OnlyLocal:   true,
+    OnlyLinkMFile: false,
     DefaultRoot: ".",
     NoCache:     true,
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -1,6 +1,7 @@
 package smb

 import (
+    "fmt"
     "io/fs"
     "net"
     "os"
@@ -8,6 +9,7 @@ import (
     "sync/atomic"
     "time"

+    "github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"
     "github.com/hirochachacha/go-smb2"
@@ -26,6 +28,12 @@ func (d *SMB) getLastConnTime() time.Time {
 }

 func (d *SMB) initFS() error {
+    err, _, _ := singleflight.ErrorGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (error, error) {
+        return d._initFS(), nil
+    })
+    return err
+}
+
+func (d *SMB) _initFS() error {
     conn, err := net.Dial("tcp", d.Address)
     if err != nil {
         return err

View File

@@ -18,8 +18,9 @@ var config = driver.Config{
     NoCache:     true,
     NoUpload:    true,
     DefaultRoot: "/",
-    OnlyLocal:   true,
+    OnlyLinkMFile: true,
     OnlyProxy:   true,
+    NoLinkURL:   true,
 }

 func init() {

View File

@@ -16,7 +16,7 @@ type Addition struct {
 var config = driver.Config{
     Name:      "Template",
     LocalSort: false,
-    OnlyLocal: false,
+    OnlyLinkMFile: false,
     OnlyProxy: false,
     NoCache:   false,
     NoUpload:  false,
@@ -25,6 +25,7 @@ var config = driver.Config{
     CheckStatus: false,
     Alert:       "",
     NoOverwriteUpload: false,
+    NoLinkURL:         false,
 }

 func init() {

View File

@@ -85,7 +85,6 @@ func (i *Addition) GetIdentity() string {
 var config = driver.Config{
     Name:      "ThunderX",
     LocalSort: true,
-    OnlyProxy: false,
 }

 var configExpert = driver.Config{

View File

@@ -18,15 +18,8 @@ type Addition struct {
 var config = driver.Config{
     Name:      "UrlTree",
     LocalSort: true,
-    OnlyLocal: false,
-    OnlyProxy: false,
     NoCache:   true,
-    NoUpload:  false,
-    NeedMs:    false,
-    DefaultRoot: "",
     CheckStatus: true,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -15,10 +15,10 @@ type Addition struct {
 var config = driver.Config{
     Name:      "Virtual",
-    OnlyLocal: true,
+    OnlyLinkMFile: true,
     LocalSort: true,
     NeedMs:    true,
-    //NoCache:   true,
+    NoLinkURL: true,
 }

 func init() {

View File

@@ -15,11 +15,8 @@ type Addition struct {
 var config = driver.Config{
     Name:        "WeiYun",
-    LocalSort:   false,
     OnlyProxy:   true,
     CheckStatus: true,
-    Alert:       "",
-    NoOverwriteUpload: false,
 }

 func init() {

View File

@@ -18,15 +18,7 @@ type Addition struct {
 var config = driver.Config{
     Name:        "WoPan",
-    LocalSort:   false,
-    OnlyLocal:   false,
-    OnlyProxy:   false,
-    NoCache:     false,
-    NoUpload:    false,
-    NeedMs:      false,
     DefaultRoot: "0",
-    CheckStatus: false,
-    Alert:       "",
     NoOverwriteUpload: true,
 }

View File

@@ -3,18 +3,24 @@ package driver

 type Config struct {
     Name      string `json:"name"`
     LocalSort bool   `json:"local_sort"`
-    OnlyLocal bool   `json:"only_local"`
+    // if the driver returns Link with MFile, this should be set to true
+    OnlyLinkMFile bool `json:"only_local"`
     OnlyProxy bool   `json:"only_proxy"`
     NoCache   bool   `json:"no_cache"`
     NoUpload  bool   `json:"no_upload"`
-    NeedMs      bool   `json:"need_ms"` // if need get message from user, such as validate code
+    // if need get message from user, such as validate code
+    NeedMs      bool   `json:"need_ms"`
     DefaultRoot string `json:"default_root"`
     CheckStatus bool   `json:"-"`
-    Alert       string `json:"alert"` //info,success,warning,danger
-    NoOverwriteUpload bool `json:"-"` // whether to support overwrite upload
+    //info,success,warning,danger
+    Alert string `json:"alert"`
+    // whether to support overwrite upload
+    NoOverwriteUpload bool `json:"-"`
     ProxyRangeOption  bool `json:"-"`
+    // if the driver returns Link without URL, this should be set to true
+    NoLinkURL bool `json:"-"`
 }

 func (c Config) MustProxy() bool {
-    return c.OnlyProxy || c.OnlyLocal
+    return c.OnlyProxy || c.OnlyLinkMFile || c.NoLinkURL
 }

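The Config hunk above renames OnlyLocal to OnlyLinkMFile and introduces NoLinkURL, and MustProxy() now also forces proxying for drivers whose links carry no URL. A hedged illustration of how a driver that serves data through a RangeReader rather than a direct URL would declare itself under the new fields, modeled on the FTP/SFTP configs in this commit (the name and values here are illustrative only, not part of the diff):

package mydriver

import "github.com/OpenListTeam/OpenList/v4/internal/driver"

// Illustrative config, not a real driver in this commit.
var config = driver.Config{
    Name:          "MyDriver", // hypothetical driver name
    LocalSort:     true,
    OnlyLinkMFile: false, // Link may return a RangeReader instead of an MFile
    NoLinkURL:     true,  // Link never carries a direct URL, so MustProxy() reports true
    DefaultRoot:   "/",
}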
View File

@@ -3,7 +3,6 @@ package fs
 import (
     "context"
     "fmt"
-    "net/http"
     stdpath "path"
     "time"

@@ -86,19 +85,17 @@ func _copy(ctx context.Context, srcObjPath, dstDirPath string, lazyCache ...bool
     }
     if !srcObj.IsDir() {
         // copy file directly
-        link, _, err := op.Link(ctx, srcStorage, srcObjActualPath, model.LinkArgs{
-            Header: http.Header{},
-        })
+        link, _, err := op.Link(ctx, srcStorage, srcObjActualPath, model.LinkArgs{})
         if err != nil {
             return nil, errors.WithMessagef(err, "failed get [%s] link", srcObjPath)
         }
-        fs := stream.FileStream{
+        // any link provided is seekable
+        ss, err := stream.NewSeekableStream(&stream.FileStream{
             Obj: srcObj,
             Ctx: ctx,
-        }
-        // any link provided is seekable
-        ss, err := stream.NewSeekableStream(fs, link)
+        }, link)
         if err != nil {
+            _ = link.Close()
             return nil, errors.WithMessagef(err, "failed get [%s] stream", srcObjPath)
         }
         return nil, op.Put(ctx, dstStorage, dstDirActualPath, ss, nil, false)
@@ -165,19 +162,17 @@ func copyFileBetween2Storages(tsk *CopyTask, srcStorage, dstStorage driver.Drive
         return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath)
     }
     tsk.SetTotalBytes(srcFile.GetSize())
-    link, _, err := op.Link(tsk.Ctx(), srcStorage, srcFilePath, model.LinkArgs{
-        Header: http.Header{},
-    })
+    link, _, err := op.Link(tsk.Ctx(), srcStorage, srcFilePath, model.LinkArgs{})
     if err != nil {
         return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
     }
-    fs := stream.FileStream{
+    // any link provided is seekable
+    ss, err := stream.NewSeekableStream(&stream.FileStream{
         Obj: srcFile,
         Ctx: tsk.Ctx(),
-    }
-    // any link provided is seekable
-    ss, err := stream.NewSeekableStream(fs, link)
+    }, link)
     if err != nil {
+        _ = link.Close()
         return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath)
     }
     return op.Put(tsk.Ctx(), dstStorage, dstDirPath, ss, tsk.SetProgress, true)

View File

@@ -3,7 +3,6 @@ package fs
 import (
     "context"
     "fmt"
-    "net/http"
     stdpath "path"
     "sync"
     "time"

@@ -346,23 +345,18 @@ func (t *MoveTask) copyFile(srcStorage, dstStorage driver.Driver, srcFilePath, d
         return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath)
     }

-    link, _, err := op.Link(t.Ctx(), srcStorage, srcFilePath, model.LinkArgs{
-        Header: http.Header{},
-    })
+    link, _, err := op.Link(t.Ctx(), srcStorage, srcFilePath, model.LinkArgs{})
     if err != nil {
         return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
     }
-
-    fs := stream.FileStream{
+    ss, err := stream.NewSeekableStream(&stream.FileStream{
         Obj: srcFile,
         Ctx: t.Ctx(),
-    }
-    ss, err := stream.NewSeekableStream(fs, link)
+    }, link)
     if err != nil {
+        _ = link.Close()
         return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath)
     }
     return op.Put(t.Ctx(), dstStorage, dstDirPath, ss, nil, true)
 }

View File

@@ -10,8 +10,8 @@ import (
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/op"
     "github.com/OpenListTeam/OpenList/v4/internal/task"
-    "github.com/pkg/errors"
     "github.com/OpenListTeam/tache"
+    "github.com/pkg/errors"
 )

 type UploadTask struct {
@@ -73,9 +73,11 @@ func putAsTask(ctx context.Context, dstDirPath string, file model.FileStreamer)
 func putDirectly(ctx context.Context, dstDirPath string, file model.FileStreamer, lazyCache ...bool) error {
     storage, dstDirActualPath, err := op.GetStorageAndActualPath(dstDirPath)
     if err != nil {
+        _ = file.Close()
         return errors.WithMessage(err, "failed get storage")
     }
     if storage.Config().NoUpload {
+        _ = file.Close()
         return errors.WithStack(errs.UploadNotSupported)
     }
     return op.Put(ctx, storage, dstDirActualPath, file, nil, lazyCache...)

View File

@@ -2,6 +2,7 @@ package model
 import (
     "context"
+    "errors"
     "io"
     "net/http"
     "time"
@@ -26,14 +27,23 @@ type LinkArgs struct {
 type Link struct {
     URL    string      `json:"url"`    // most common way
     Header http.Header `json:"header"` // needed header (for url)
-    RangeReadCloser RangeReadCloserIF `json:"-"` // recommended way if can't use URL
-    MFile io.ReadSeeker `json:"-"` // best for local,smb... file system, which exposes MFile
+    RangeReader RangeReaderIF `json:"-"` // recommended way if can't use URL
+    MFile File `json:"-"` // best for local,smb... file system, which exposes MFile

     Expiration *time.Duration // local cache expire Duration

     //for accelerating request, use multi-thread downloading
     Concurrency int `json:"concurrency"`
     PartSize    int `json:"part_size"`
+    utils.SyncClosers `json:"-"`
 }
+
+func (l *Link) Close() error {
+    if clr, ok := l.MFile.(io.Closer); ok {
+        return errors.Join(clr.Close(), l.SyncClosers.Close())
+    }
+    return l.SyncClosers.Close()
+}

 type OtherArgs struct {
@@ -74,23 +84,24 @@ type ArchiveDecompressArgs struct {
     PutIntoNewDir bool
 }

-type RangeReadCloserIF interface {
+type RangeReaderIF interface {
     RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error)
+}
+
+type RangeReadCloserIF interface {
+    RangeReaderIF
     utils.ClosersIF
 }

 var _ RangeReadCloserIF = (*RangeReadCloser)(nil)

 type RangeReadCloser struct {
-    RangeReader RangeReaderFunc
+    RangeReader RangeReaderIF
     utils.Closers
 }

 func (r *RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
-    rc, err := r.RangeReader(ctx, httpRange)
-    r.Closers.Add(rc)
+    rc, err := r.RangeReader.RangeRead(ctx, httpRange)
+    r.Add(rc)
     return rc, err
 }
-
-// type WriterFunc func(w io.Writer) error
-type RangeReaderFunc func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error)

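With the reshaped Link above, drivers hand back a RangeReader (plus SyncClosers for anything that must be released when the link is dropped) instead of the old RangeReadCloser. A hedged sketch of what a driver's Link method looks like in this style, modeled on the mega/halalcloud hunks in this commit; MyDriver and readRange are hypothetical names, not part of the diff:

package mydriver

import (
    "context"
    "io"

    "github.com/OpenListTeam/OpenList/v4/internal/model"
    "github.com/OpenListTeam/OpenList/v4/internal/stream"
    "github.com/OpenListTeam/OpenList/v4/pkg/http_range"
)

type MyDriver struct{}

// readRange is a hypothetical helper standing in for the real remote download call.
func (d *MyDriver) readRange(ctx context.Context, file model.Obj, start, length int64) (io.ReadCloser, error) {
    panic("illustrative stub")
}

func (d *MyDriver) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
    size := file.GetSize()
    rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
        length := httpRange.Length
        if length < 0 || httpRange.Start+length >= size {
            length = size - httpRange.Start // clamp open-ended ranges to EOF
        }
        return d.readRange(ctx, file, httpRange.Start, length)
    }
    return &model.Link{
        // RangeReader (plus SyncClosers when something must be released later)
        // replaces the old RangeReadCloser field on model.Link.
        RangeReader: stream.RateLimitRangeReaderFunc(rangeReader),
    }, nil
}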
View File

@@ -1,6 +1,9 @@
 package model

-import "io"
+import (
+    "errors"
+    "io"
+)

 // File is basic file level accessing interface
 type File interface {
@@ -8,3 +11,22 @@ type File interface {
     io.ReaderAt
     io.Seeker
 }
+
+type FileCloser struct {
+    File
+    io.Closer
+}
+
+func (f *FileCloser) Close() error {
+    var errs []error
+    if clr, ok := f.File.(io.Closer); ok {
+        errs = append(errs, clr.Close())
+    }
+    if f.Closer != nil {
+        errs = append(errs, f.Closer.Close())
+    }
+    return errors.Join(errs...)
+}
+
+type FileRangeReader struct {
+    RangeReaderIF
+}

View File

@@ -37,7 +37,7 @@ type Obj interface {
 // FileStreamer ->check FileStream for more comments
 type FileStreamer interface {
     io.Reader
-    io.Closer
+    utils.ClosersIF
     Obj
     GetMimetype() string
     //SetReader(io.Reader)

View File

@@ -12,6 +12,7 @@ import (
     "sync"
     "time"

+    "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/pkg/utils"

     "github.com/OpenListTeam/OpenList/v4/pkg/http_range"
@@ -70,7 +71,7 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo
     var finalP HttpRequestParams
     awsutil.Copy(&finalP, p)
-    if finalP.Range.Length == -1 {
+    if finalP.Range.Length < 0 || finalP.Range.Start+finalP.Range.Length > finalP.Size {
         finalP.Range.Length = finalP.Size - finalP.Range.Start
     }
     impl := downloader{params: &finalP, cfg: d, ctx: ctx}
@@ -120,7 +121,7 @@ type ConcurrencyLimit struct {
     Limit int // 需要大于0
 }

-var ErrExceedMaxConcurrency = errors.New("ExceedMaxConcurrency")
+var ErrExceedMaxConcurrency = ErrorHttpStatusCode(http.StatusTooManyRequests)

 func (l *ConcurrencyLimit) sub() error {
     l._m.Lock()
@@ -181,6 +182,7 @@ func (d *downloader) download() (io.ReadCloser, error) {
         resp.Body = utils.NewReadCloser(resp.Body, func() error {
             d.m.Lock()
             defer d.m.Unlock()
+            var err error
             if closeFunc != nil {
                 d.concurrencyFinish()
                 err = closeFunc()
@@ -199,7 +201,7 @@ func (d *downloader) download() (io.ReadCloser, error) {
     d.pos = d.params.Range.Start
     d.maxPos = d.params.Range.Start + d.params.Range.Length
     d.concurrency = d.cfg.Concurrency
-    d.sendChunkTask(true)
+    _ = d.sendChunkTask(true)

     var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf)
@@ -303,7 +305,7 @@ func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) {
         return true, nil
     }

-    d.sendChunkTask(false)
+    _ = d.sendChunkTask(false)

     d.readingID = id
     return false, d.getBuf(id)
@@ -398,14 +400,15 @@ var errInfiniteRetry = errors.New("infinite retry")
 func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) {
     resp, err := d.cfg.HttpClient(d.ctx, params)
     if err != nil {
-        if resp == nil {
+        statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode)
+        if !ok {
             return 0, err
         }
-        if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
+        if statusCode == http.StatusRequestedRangeNotSatisfiable {
             return 0, err
         }
         if ch.id == 0 { //第1个任务 有限的重试,超过重试就会结束请求
-            switch resp.StatusCode {
+            switch statusCode {
             default:
                 return 0, err
             case http.StatusTooManyRequests:
@@ -414,7 +417,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
             case http.StatusGatewayTimeout:
             }
             <-time.After(time.Millisecond * 200)
-            return 0, &errNeedRetry{err: fmt.Errorf("http request failure,status: %d", resp.StatusCode)}
+            return 0, &errNeedRetry{err: err}
         }

         // 来到这 说明第1个分片下载 连接成功了
@@ -450,7 +453,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
             return 0, err
         }
     }
-    d.sendChunkTask(true)
+    _ = d.sendChunkTask(true)

     n, err := utils.CopyWithBuffer(ch.buf, resp.Body)
     if err != nil {
@@ -552,12 +555,26 @@ type chunk struct {

 func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
     header := http_range.ApplyRangeToHttpHeader(params.Range, params.HeaderRef)
-
-    res, err := RequestHttp(ctx, "GET", header, params.URL)
-    if err != nil {
-        return res, err
-    }
-    return res, nil
+    return RequestHttp(ctx, "GET", header, params.URL)
+}
+
+func GetRangeReaderHttpRequestFunc(rangeReader model.RangeReaderIF) HttpRequestFunc {
+    return func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
+        rc, err := rangeReader.RangeRead(ctx, params.Range)
+        if err != nil {
+            return nil, err
+        }
+        return &http.Response{
+            StatusCode: http.StatusPartialContent,
+            Status:     http.StatusText(http.StatusPartialContent),
+            Body:       rc,
+            Header: http.Header{
+                "Content-Range": {params.Range.ContentRange(params.Size)},
+            },
+            ContentLength: params.Range.Length,
+        }, nil
+    }
 }

 type HttpRequestParams struct {

View File

@ -114,14 +114,14 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
// 使用请求的Context // 使用请求的Context
// 不然从sendContent读不到数据即使请求断开CopyBuffer也会一直堵塞 // 不然从sendContent读不到数据即使请求断开CopyBuffer也会一直堵塞
ctx := context.WithValue(r.Context(), "request_header", r.Header) ctx := r.Context()
switch { switch {
case len(ranges) == 0: case len(ranges) == 0:
reader, err := RangeReadCloser.RangeRead(ctx, http_range.Range{Length: -1}) reader, err := RangeReadCloser.RangeRead(ctx, http_range.Range{Length: -1})
if err != nil { if err != nil {
code = http.StatusRequestedRangeNotSatisfiable code = http.StatusRequestedRangeNotSatisfiable
if errors.Is(err, ErrExceedMaxConcurrency) { if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok {
code = http.StatusTooManyRequests code = int(statusCode)
} }
http.Error(w, err.Error(), code) http.Error(w, err.Error(), code)
return nil return nil
@ -143,8 +143,8 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
sendContent, err = RangeReadCloser.RangeRead(ctx, ra) sendContent, err = RangeReadCloser.RangeRead(ctx, ra)
if err != nil { if err != nil {
code = http.StatusRequestedRangeNotSatisfiable code = http.StatusRequestedRangeNotSatisfiable
if errors.Is(err, ErrExceedMaxConcurrency) { if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok {
code = http.StatusTooManyRequests code = int(statusCode)
} }
http.Error(w, err.Error(), code) http.Error(w, err.Error(), code)
return nil return nil
@ -205,8 +205,8 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
log.Warnf("Maybe size incorrect or reader not giving correct/full data, or connection closed before finish. written bytes: %d ,sendSize:%d, ", written, sendSize) log.Warnf("Maybe size incorrect or reader not giving correct/full data, or connection closed before finish. written bytes: %d ,sendSize:%d, ", written, sendSize)
} }
code = http.StatusInternalServerError code = http.StatusInternalServerError
if errors.Is(err, ErrExceedMaxConcurrency) { if statusCode, ok := errors.Unwrap(err).(ErrorHttpStatusCode); ok {
code = http.StatusTooManyRequests code = int(statusCode)
} }
w.WriteHeader(code) w.WriteHeader(code)
return err return err
@ -259,11 +259,17 @@ func RequestHttp(ctx context.Context, httpMethod string, headerOverride http.Hea
_ = res.Body.Close() _ = res.Body.Close()
msg := string(all) msg := string(all)
log.Debugln(msg) log.Debugln(msg)
return res, fmt.Errorf("http request [%s] failure,status: %d response:%s", URL, res.StatusCode, msg) return nil, fmt.Errorf("http request [%s] failure,status: %w response:%s", URL, ErrorHttpStatusCode(res.StatusCode), msg)
} }
return res, nil return res, nil
} }
type ErrorHttpStatusCode int
func (e ErrorHttpStatusCode) Error() string {
return fmt.Sprintf("%d|%s", e, http.StatusText(int(e)))
}
var once sync.Once var once sync.Once
var httpClient *http.Client var httpClient *http.Client
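Upstream failures now carry their HTTP status as a typed error: RequestHttp wraps ErrorHttpStatusCode with %w, and the serving paths above unwrap it to pick a response code instead of special-casing a sentinel such as ErrExceedMaxConcurrency. A minimal sketch of the mapping as an external caller might write it (helper name hypothetical):

// httpCodeForErr picks a response code from an error chain, falling back to 500.
func httpCodeForErr(err error) int {
    if statusCode, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok {
        return int(statusCode)
    }
    return http.StatusInternalServerError
}

errors.Unwrap peels exactly one wrapping layer, which matches the single %w used in RequestHttp; an error wrapped more deeply would need errors.As with a *net.ErrorHttpStatusCode target instead.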
@ -350,3 +350,5 @@ func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.Rea
// return an io.ReadCloser that is limited to `length` bytes. // return an io.ReadCloser that is limited to `length` bytes.
return &LimitedReadCloser{readCloser, length_int}, nil return &LimitedReadCloser{readCloser, length_int}, nil
} }
type RequestHeaderKey struct{}
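RequestHeaderKey replaces the stringly-typed "request_header" context value: the proxy handler stores the client's headers under this struct key, and the range reader built from a URL link reads them back when constructing the upstream request. A minimal sketch of both sides, matching the call sites in this commit:

// Producer side: expose the incoming headers to downstream range readers.
r = r.WithContext(context.WithValue(r.Context(), net.RequestHeaderKey{}, r.Header))

// Consumer side: merge them with the link's own headers if present.
requestHeader, _ := ctx.Value(net.RequestHeaderKey{}).(http.Header)
header := net.ProcessHeader(requestHeader, link.Header)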
@ -3,7 +3,6 @@ package tool
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"os" "os"
stdpath "path" stdpath "path"
"path/filepath" "path/filepath"
@ -43,11 +42,11 @@ func (t *TransferTask) Run() error {
defer func() { t.SetEndTime(time.Now()) }() defer func() { t.SetEndTime(time.Now()) }()
if t.SrcStorage == nil { if t.SrcStorage == nil {
if t.DeletePolicy == UploadDownloadStream { if t.DeletePolicy == UploadDownloadStream {
rrc, err := stream.GetRangeReadCloserFromLink(t.GetTotalBytes(), &model.Link{URL: t.Url}) rr, err := stream.GetRangeReaderFromLink(t.GetTotalBytes(), &model.Link{URL: t.Url})
if err != nil { if err != nil {
return err return err
} }
r, err := rrc.RangeRead(t.Ctx(), http_range.Range{Length: t.GetTotalBytes()}) r, err := rr.RangeRead(t.Ctx(), http_range.Range{Length: t.GetTotalBytes()})
if err != nil { if err != nil {
return err return err
} }
@ -63,9 +62,8 @@ func (t *TransferTask) Run() error {
}, },
Reader: r, Reader: r,
Mimetype: mimetype, Mimetype: mimetype,
Closers: utils.NewClosers(rrc), Closers: utils.NewClosers(r),
} }
defer s.Close()
return op.Put(t.Ctx(), t.DstStorage, t.DstDirPath, s, t.SetProgress) return op.Put(t.Ctx(), t.DstStorage, t.DstDirPath, s, t.SetProgress)
} }
return transferStdPath(t) return transferStdPath(t)
@ -279,19 +277,17 @@ func transferObjFile(t *TransferTask) error {
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", t.SrcObjPath) return errors.WithMessagef(err, "failed get src [%s] file", t.SrcObjPath)
} }
link, _, err := op.Link(t.Ctx(), t.SrcStorage, t.SrcObjPath, model.LinkArgs{ link, _, err := op.Link(t.Ctx(), t.SrcStorage, t.SrcObjPath, model.LinkArgs{})
Header: http.Header{},
})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] link", t.SrcObjPath) return errors.WithMessagef(err, "failed get [%s] link", t.SrcObjPath)
} }
fs := stream.FileStream{ // any link provided is seekable
ss, err := stream.NewSeekableStream(&stream.FileStream{
Obj: srcFile, Obj: srcFile,
Ctx: t.Ctx(), Ctx: t.Ctx(),
} }, link)
// any link provided is seekable
ss, err := stream.NewSeekableStream(fs, link)
if err != nil { if err != nil {
_ = link.Close()
return errors.WithMessagef(err, "failed get [%s] stream", t.SrcObjPath) return errors.WithMessagef(err, "failed get [%s] stream", t.SrcObjPath)
} }
t.SetTotalBytes(srcFile.GetSize()) t.SetTotalBytes(srcFile.GetSize())
@ -31,12 +31,6 @@ func GetArchiveMeta(ctx context.Context, storage driver.Driver, path string, arg
} }
path = utils.FixAndCleanPath(path) path = utils.FixAndCleanPath(path)
key := Key(storage, path) key := Key(storage, path)
if !args.Refresh {
if meta, ok := archiveMetaCache.Get(key); ok {
log.Debugf("use cache when get %s archive meta", path)
return meta, nil
}
}
fn := func() (*model.ArchiveMetaProvider, error) { fn := func() (*model.ArchiveMetaProvider, error) {
_, m, err := getArchiveMeta(ctx, storage, path, args) _, m, err := getArchiveMeta(ctx, storage, path, args)
if err != nil { if err != nil {
@ -47,10 +41,16 @@ func GetArchiveMeta(ctx context.Context, storage driver.Driver, path string, arg
} }
return m, nil return m, nil
} }
if storage.Config().OnlyLocal { if storage.Config().OnlyLinkMFile {
meta, err := fn() meta, err := fn()
return meta, err return meta, err
} }
if !args.Refresh {
if meta, ok := archiveMetaCache.Get(key); ok {
log.Debugf("use cache when get %s archive meta", path)
return meta, nil
}
}
meta, err, _ := archiveMetaG.Do(key, fn) meta, err, _ := archiveMetaG.Do(key, fn)
return meta, err return meta, err
} }
@ -62,12 +62,7 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
} }
baseName, ext, found := strings.Cut(obj.GetName(), ".") baseName, ext, found := strings.Cut(obj.GetName(), ".")
if !found { if !found {
if clr, ok := l.MFile.(io.Closer); ok { _ = l.Close()
_ = clr.Close()
}
if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close()
}
return nil, nil, nil, errors.Errorf("failed get archive tool: the obj does not have an extension.") return nil, nil, nil, errors.Errorf("failed get archive tool: the obj does not have an extension.")
} }
partExt, t, err := tool.GetArchiveTool("." + ext) partExt, t, err := tool.GetArchiveTool("." + ext)
@ -75,23 +70,13 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
var e error var e error
partExt, t, e = tool.GetArchiveTool(stdpath.Ext(obj.GetName())) partExt, t, e = tool.GetArchiveTool(stdpath.Ext(obj.GetName()))
if e != nil { if e != nil {
if clr, ok := l.MFile.(io.Closer); ok { _ = l.Close()
_ = clr.Close()
}
if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close()
}
return nil, nil, nil, errors.WithMessagef(stderrors.Join(err, e), "failed get archive tool: %s", ext) return nil, nil, nil, errors.WithMessagef(stderrors.Join(err, e), "failed get archive tool: %s", ext)
} }
} }
ss, err := stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: obj}, l) ss, err := stream.NewSeekableStream(&stream.FileStream{Ctx: ctx, Obj: obj}, l)
if err != nil { if err != nil {
if clr, ok := l.MFile.(io.Closer); ok { _ = l.Close()
_ = clr.Close()
}
if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close()
}
return nil, nil, nil, errors.WithMessagef(err, "failed get [%s] stream", path) return nil, nil, nil, errors.WithMessagef(err, "failed get [%s] stream", path)
} }
ret := []*stream.SeekableStream{ss} ret := []*stream.SeekableStream{ss}
@ -107,14 +92,9 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
if err != nil { if err != nil {
break break
} }
ss, err = stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: o}, l) ss, err = stream.NewSeekableStream(&stream.FileStream{Ctx: ctx, Obj: o}, l)
if err != nil { if err != nil {
if clr, ok := l.MFile.(io.Closer); ok { _ = l.Close()
_ = clr.Close()
}
if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close()
}
for _, s := range ret { for _, s := range ret {
_ = s.Close() _ = s.Close()
} }
@ -375,12 +355,12 @@ func ArchiveGet(ctx context.Context, storage driver.Driver, path string, args mo
} }
type extractLink struct { type extractLink struct {
Link *model.Link *model.Link
Obj model.Obj Obj model.Obj
} }
var extractCache = cache.NewMemCache(cache.WithShards[*extractLink](16)) var extractCache = cache.NewMemCache(cache.WithShards[*extractLink](16))
var extractG singleflight.Group[*extractLink] var extractG = singleflight.Group[*extractLink]{Remember: true}
func DriverExtract(ctx context.Context, storage driver.Driver, path string, args model.ArchiveInnerArgs) (*model.Link, model.Obj, error) { func DriverExtract(ctx context.Context, storage driver.Driver, path string, args model.ArchiveInnerArgs) (*model.Link, model.Obj, error) {
if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { if storage.Config().CheckStatus && storage.GetStorage().Status != WORK {
@ -389,9 +369,9 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
key := stdpath.Join(Key(storage, path), args.InnerPath) key := stdpath.Join(Key(storage, path), args.InnerPath)
if link, ok := extractCache.Get(key); ok { if link, ok := extractCache.Get(key); ok {
return link.Link, link.Obj, nil return link.Link, link.Obj, nil
} else if link, ok := extractCache.Get(key + ":" + args.IP); ok {
return link.Link, link.Obj, nil
} }
var forget utils.CloseFunc
fn := func() (*extractLink, error) { fn := func() (*extractLink, error) {
link, err := driverExtract(ctx, storage, path, args) link, err := driverExtract(ctx, storage, path, args)
if err != nil { if err != nil {
@ -400,16 +380,33 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
if link.Link.Expiration != nil { if link.Link.Expiration != nil {
extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration)) extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration))
} }
link.Add(forget)
return link, nil return link, nil
} }
if storage.Config().OnlyLocal {
if storage.Config().OnlyLinkMFile {
link, err := fn() link, err := fn()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return link.Link, link.Obj, nil return link.Link, link.Obj, nil
} }
forget = func() error {
if forget != nil {
forget = nil
linkG.Forget(key)
}
return nil
}
link, err, _ := extractG.Do(key, fn) link, err, _ := extractG.Do(key, fn)
if err == nil && !link.AcquireReference() {
link, err, _ = extractG.Do(key, fn)
if err == nil {
link.AcquireReference()
}
}
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -81,7 +81,15 @@ func getMainItems(config driver.Config) []driver.Item {
Help: "The cache expiration time for this storage", Help: "The cache expiration time for this storage",
}) })
} }
if !config.OnlyProxy && !config.OnlyLocal { if config.MustProxy() {
items = append(items, driver.Item{
Name: "webdav_policy",
Type: conf.TypeSelect,
Default: "native_proxy",
Options: "use_proxy_url,native_proxy",
Required: true,
})
} else {
items = append(items, []driver.Item{{ items = append(items, []driver.Item{{
Name: "web_proxy", Name: "web_proxy",
Type: conf.TypeBool, Type: conf.TypeBool,
@ -104,14 +112,6 @@ func getMainItems(config driver.Config) []driver.Item {
} }
items = append(items, item) items = append(items, item)
} }
} else {
items = append(items, driver.Item{
Name: "webdav_policy",
Type: conf.TypeSelect,
Default: "native_proxy",
Options: "use_proxy_url,native_proxy",
Required: true,
})
} }
items = append(items, driver.Item{ items = append(items, driver.Item{
Name: "down_proxy_url", Name: "down_proxy_url",
@ -244,7 +244,7 @@ func GetUnwrap(ctx context.Context, storage driver.Driver, path string) (model.O
} }
var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16)) var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
var linkG singleflight.Group[*model.Link] var linkG = singleflight.Group[*model.Link]{Remember: true}
// Link get link, if is an url. should have an expiry time // Link get link, if is an url. should have an expiry time
func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
@ -262,6 +262,8 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
if link, ok := linkCache.Get(key); ok { if link, ok := linkCache.Get(key); ok {
return link, file, nil return link, file, nil
} }
var forget utils.CloseFunc
fn := func() (*model.Link, error) { fn := func() (*model.Link, error) {
link, err := storage.Link(ctx, file, args) link, err := storage.Link(ctx, file, args)
if err != nil { if err != nil {
@ -270,15 +272,29 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
if link.Expiration != nil { if link.Expiration != nil {
linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration)) linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration))
} }
link.Add(forget)
return link, nil return link, nil
} }
if storage.Config().OnlyLocal { if storage.Config().OnlyLinkMFile {
link, err := fn() link, err := fn()
return link, file, err return link, file, err
} }
forget = func() error {
if forget != nil {
forget = nil
linkG.Forget(key)
}
return nil
}
link, err, _ := linkG.Do(key, fn) link, err, _ := linkG.Do(key, fn)
if err == nil && !link.AcquireReference() {
link, err, _ = linkG.Do(key, fn)
if err == nil {
link.AcquireReference()
}
}
return link, file, err return link, file, err
} }
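op.Link now shares results across concurrent callers: the singleflight group remembers the completed call, the produced *model.Link carries a reference count via SyncClosers, and a forget closure added to the link drops the singleflight entry once the last holder closes it. op.Link itself re-acquires a reference after Do (retrying once if the cached link had already been fully closed), so a caller's only obligation is to Close the link when finished. A minimal sketch of the consuming side (serveDownload is hypothetical):

link, obj, err := op.Link(ctx, storage, path, args)
if err != nil {
    return err
}
// Releases one reference; the wrapped resources and the remembered
// singleflight entry are only dropped when the last holder closes.
defer link.Close()
return serveDownload(obj, link)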
@ -507,14 +523,15 @@ func Remove(ctx context.Context, storage driver.Driver, path string) error {
} }
func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file model.FileStreamer, up driver.UpdateProgress, lazyCache ...bool) error { func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file model.FileStreamer, up driver.UpdateProgress, lazyCache ...bool) error {
if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { close := file.Close
return errors.Errorf("storage not init: %s", storage.GetStorage().Status)
}
defer func() { defer func() {
if err := file.Close(); err != nil { if err := close(); err != nil {
log.Errorf("failed to close file streamer, %v", err) log.Errorf("failed to close file streamer, %v", err)
} }
}() }()
if storage.Config().CheckStatus && storage.GetStorage().Status != WORK {
return errors.Errorf("storage not init: %s", storage.GetStorage().Status)
}
// UrlTree PUT // UrlTree PUT
if storage.GetStorage().Driver == "UrlTree" { if storage.GetStorage().Driver == "UrlTree" {
var link string var link string
@ -142,19 +142,19 @@ func (r *RateLimitFile) Close() error {
return nil return nil
} }
type RateLimitRangeReadCloser struct { type RateLimitRangeReaderFunc RangeReaderFunc
model.RangeReadCloserIF
Limiter Limiter
}
func (rrc *RateLimitRangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { func (f RateLimitRangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
rc, err := rrc.RangeReadCloserIF.RangeRead(ctx, httpRange) rc, err := f(ctx, httpRange)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &RateLimitReader{ if ServerDownloadLimit != nil {
Reader: rc, rc = &RateLimitReader{
Limiter: rrc.Limiter,
Ctx: ctx, Ctx: ctx,
}, nil Reader: rc,
Limiter: ServerDownloadLimit,
}
}
return rc, nil
} }
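RateLimitRangeReadCloser gives way to RateLimitRangeReaderFunc, a plain function type whose RangeRead only wraps the returned body in a RateLimitReader when ServerDownloadLimit is configured, so unthrottled deployments read the source directly. A minimal sketch of throttling an arbitrary range-read function (openPart is a hypothetical data source):

func throttledRead(ctx context.Context) (io.ReadCloser, error) {
    var rr stream.RateLimitRangeReaderFunc = func(ctx context.Context, hr http_range.Range) (io.ReadCloser, error) {
        return openPart(ctx, hr.Start, hr.Length) // hypothetical source of the requested window
    }
    // The returned body is rate limited only when stream.ServerDownloadLimit is set.
    return rr.RangeRead(ctx, http_range.Range{Start: 0, Length: 1 << 20})
}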
@ -110,8 +110,7 @@ const InMemoryBufMaxSizeBytes = InMemoryBufMaxSize * 1024 * 1024
// RangeRead have to cache all data first since only Reader is provided. // RangeRead have to cache all data first since only Reader is provided.
// also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory // also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory
func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
if httpRange.Length == -1 { if httpRange.Length < 0 || httpRange.Start+httpRange.Length > f.GetSize() {
// see internal/net/request.go
httpRange.Length = f.GetSize() - httpRange.Start httpRange.Length = f.GetSize() - httpRange.Start
} }
var cache io.ReaderAt = f.GetFile() var cache io.ReaderAt = f.GetFile()
@ -159,47 +158,40 @@ var _ model.FileStreamer = (*FileStream)(nil)
// additional resources that need to be closed, they should be added to the Closer property of // additional resources that need to be closed, they should be added to the Closer property of
// the SeekableStream object and be closed together when the SeekableStream object is closed. // the SeekableStream object and be closed together when the SeekableStream object is closed.
type SeekableStream struct { type SeekableStream struct {
FileStream *FileStream
// should have one of belows to support rangeRead // should have one of belows to support rangeRead
rangeReadCloser model.RangeReadCloserIF rangeReadCloser model.RangeReadCloserIF
} }
func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) { func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error) {
if len(fs.Mimetype) == 0 { if len(fs.Mimetype) == 0 {
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
} }
ss := &SeekableStream{FileStream: fs}
if ss.Reader != nil { if fs.Reader != nil {
ss.TryAdd(ss.Reader) fs.Add(link)
return ss, nil return &SeekableStream{FileStream: fs}, nil
} }
if link != nil { if link != nil {
if link.MFile != nil { rr, err := GetRangeReaderFromLink(fs.GetSize(), link)
ss.Closers.TryAdd(link.MFile)
ss.Reader = link.MFile
return ss, nil
}
if link.RangeReadCloser != nil {
ss.rangeReadCloser = &RateLimitRangeReadCloser{
RangeReadCloserIF: link.RangeReadCloser,
Limiter: ServerDownloadLimit,
}
ss.Add(ss.rangeReadCloser)
return ss, nil
}
if len(link.URL) > 0 {
rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rrc = &RateLimitRangeReadCloser{ if _, ok := rr.(*model.FileRangeReader); ok {
RangeReadCloserIF: rrc, fs.Reader, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1})
Limiter: ServerDownloadLimit, if err != nil {
return nil, err
} }
ss.rangeReadCloser = rrc fs.Add(link)
ss.Add(rrc) return &SeekableStream{FileStream: fs}, nil
return ss, nil
} }
rrc := &model.RangeReadCloser{
RangeReader: rr,
}
fs.Add(link)
fs.Add(rrc)
return &SeekableStream{FileStream: fs, rangeReadCloser: rrc}, nil
} }
return nil, fmt.Errorf("illegal seekableStream") return nil, fmt.Errorf("illegal seekableStream")
} }
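NewSeekableStream now takes a *FileStream and takes ownership of the link: the link (and any RangeReadCloser derived from it) is registered on the stream's closers, so on success callers close only the stream, and close the link themselves only when construction fails. A minimal sketch of the pattern used at the call sites in this commit:

ss, err := stream.NewSeekableStream(&stream.FileStream{Ctx: ctx, Obj: obj}, link)
if err != nil {
    _ = link.Close() // the stream did not take ownership
    return err
}
defer ss.Close() // closes the link together with the stream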
@ -211,9 +203,6 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
// RangeRead is not thread-safe, pls use it in single thread only. // RangeRead is not thread-safe, pls use it in single thread only.
func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
if ss.tmpFile == nil && ss.rangeReadCloser != nil { if ss.tmpFile == nil && ss.rangeReadCloser != nil {
if httpRange.Length == -1 {
httpRange.Length = ss.GetSize() - httpRange.Start
}
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange) rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange)
if err != nil { if err != nil {
return nil, err return nil, err
@ -229,10 +218,6 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro
// only provide Reader as full stream when it's demanded. in rapid-upload, we can skip this to save memory // only provide Reader as full stream when it's demanded. in rapid-upload, we can skip this to save memory
func (ss *SeekableStream) Read(p []byte) (n int, err error) { func (ss *SeekableStream) Read(p []byte) (n int, err error) {
//f.mu.Lock()
//f.peekedOnce = true
//defer f.mu.Unlock()
if ss.Reader == nil { if ss.Reader == nil {
if ss.rangeReadCloser == nil { if ss.rangeReadCloser == nil {
return 0, fmt.Errorf("illegal seekableStream") return 0, fmt.Errorf("illegal seekableStream")
@ -241,7 +226,7 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
if err != nil { if err != nil {
return 0, nil return 0, nil
} }
ss.Reader = io.NopCloser(rc) ss.Reader = rc
} }
return ss.Reader.Read(p) return ss.Reader.Read(p)
} }
@ -496,7 +481,7 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
return r.masterOff, errors.New("invalid seek: negative position") return r.masterOff, errors.New("invalid seek: negative position")
} }
if offset > r.ss.GetSize() { if offset > r.ss.GetSize() {
return r.masterOff, io.EOF offset = r.ss.GetSize()
} }
r.masterOff = offset r.masterOff = offset
return offset, nil return offset, nil
@ -3,6 +3,7 @@ package stream
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -14,57 +15,93 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCloserIF, error) { type RangeReaderFunc func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error)
if len(link.URL) == 0 {
return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link") func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
return f(ctx, httpRange)
}
func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
if link.MFile != nil {
return &model.FileRangeReader{RangeReaderIF: GetRangeReaderFromMFile(size, link.MFile)}, nil
} }
rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) {
if link.Concurrency > 0 || link.PartSize > 0 { if link.Concurrency > 0 || link.PartSize > 0 {
header := net.ProcessHeader(nil, link.Header)
down := net.NewDownloader(func(d *net.Downloader) { down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency d.Concurrency = link.Concurrency
d.PartSize = link.PartSize d.PartSize = link.PartSize
}) })
req := &net.HttpRequestParams{ var rangeReader RangeReaderFunc = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
URL: link.URL, var req *net.HttpRequestParams
Range: r, if link.RangeReader != nil {
req = &net.HttpRequestParams{
Range: httpRange,
Size: size, Size: size,
}
} else {
requestHeader, _ := ctx.Value(net.RequestHeaderKey{}).(http.Header)
header := net.ProcessHeader(requestHeader, link.Header)
req = &net.HttpRequestParams{
Range: httpRange,
Size: size,
URL: link.URL,
HeaderRef: header, HeaderRef: header,
} }
rc, err := down.Download(ctx, req) }
return rc, err return down.Download(ctx, req)
}
if link.RangeReader != nil {
down.HttpClient = net.GetRangeReaderHttpRequestFunc(link.RangeReader)
return rangeReader, nil
}
return RateLimitRangeReaderFunc(rangeReader), nil
}
if link.RangeReader != nil {
return link.RangeReader, nil
} }
response, err := RequestRangedHttp(ctx, link, r.Start, r.Length)
if len(link.URL) == 0 {
return nil, errors.New("invalid link: must have at least one of MFile, URL, or RangeReader")
}
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
if httpRange.Length < 0 || httpRange.Start+httpRange.Length > size {
httpRange.Length = size - httpRange.Start
}
requestHeader, _ := ctx.Value(net.RequestHeaderKey{}).(http.Header)
header := net.ProcessHeader(requestHeader, link.Header)
header = http_range.ApplyRangeToHttpHeader(httpRange, header)
response, err := net.RequestHttp(ctx, "GET", header, link.URL)
if err != nil { if err != nil {
if response == nil { if _, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok {
return nil, fmt.Errorf("http request failure, err:%s", err)
}
return nil, err return nil, err
} }
if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent || return nil, fmt.Errorf("http request failure, err:%w", err)
checkContentRange(&response.Header, r.Start) { }
if httpRange.Start == 0 && (httpRange.Length == -1 || httpRange.Length == size) || response.StatusCode == http.StatusPartialContent ||
checkContentRange(&response.Header, httpRange.Start) {
return response.Body, nil return response.Body, nil
} else if response.StatusCode == http.StatusOK { } else if response.StatusCode == http.StatusOK {
log.Warnf("remote http server not supporting range request, expect low perfromace!") log.Warnf("remote http server not supporting range request, expect low perfromace!")
readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length) readCloser, err := net.GetRangedHttpReader(response.Body, httpRange.Start, httpRange.Length)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return readCloser, nil return readCloser, nil
} }
return response.Body, nil return response.Body, nil
} }
resultRangeReadCloser := model.RangeReadCloser{RangeReader: rangeReaderFunc} return RateLimitRangeReaderFunc(rangeReader), nil
return &resultRangeReadCloser, nil
} }
func RequestRangedHttp(ctx context.Context, link *model.Link, offset, length int64) (*http.Response, error) { func GetRangeReaderFromMFile(size int64, file model.File) RangeReaderFunc {
header := net.ProcessHeader(nil, link.Header) return func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
header = http_range.ApplyRangeToHttpHeader(http_range.Range{Start: offset, Length: length}, header) length := httpRange.Length
if length < 0 || httpRange.Start+length > size {
return net.RequestHttp(ctx, "GET", header, link.URL) length = size - httpRange.Start
}
return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil
}
} }
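GetRangeReaderFromLink replaces GetRangeReadCloserFromLink and covers every link shape: MFile is wrapped by GetRangeReaderFromMFile (an io.NewSectionReader per request), an explicit RangeReader is passed through (with the concurrent downloader in front when Concurrency or PartSize is set), and a bare URL falls back to a rate-limited HTTP range request. A minimal sketch of consuming it, as the offline-download and S3 paths do (dst is a hypothetical writer):

rr, err := stream.GetRangeReaderFromLink(file.GetSize(), link)
if err != nil {
    return err
}
// Length: -1 means "until the end of the file"; a concrete length reads just that window.
rc, err := rr.RangeRead(ctx, http_range.Range{Start: 0, Length: 4 << 20})
if err != nil {
    return err
}
defer rc.Close()
_, err = io.Copy(dst, rc)
return err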
// 139 cloud does not properly return 206 http status code, add a hack here // 139 cloud does not properly return 206 http status code, add a hack here
@ -73,6 +73,8 @@ type call[T any] struct {
type Group[T any] struct { type Group[T any] struct {
mu sync.Mutex // protects m mu sync.Mutex // protects m
m map[string]*call[T] // lazily initialized m map[string]*call[T] // lazily initialized
Remember bool
} }
// Result holds the results of Do, so they can be passed // Result holds the results of Do, so they can be passed
@ -156,7 +158,7 @@ func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
c.wg.Done() c.wg.Done()
if g.m[key] == c { if !g.Remember && g.m[key] == c {
delete(g.m, key) delete(g.m, key)
} }
pkg/singleflight/var.go Normal file
@ -0,0 +1,3 @@
package singleflight
var ErrorGroup Group[error]
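The generic singleflight Group gains a Remember flag: when it is set, doCall leaves the finished call in the map, so later Do calls for the same key reuse the stored result until Forget is invoked, which is what the forget closure attached to links and extract results does once their last reference is closed. The new var.go also exposes a package-level ErrorGroup (a Group[error]) for callers that only deduplicate error results. A minimal sketch of the Remember behaviour (loadValue is a hypothetical func() (string, error)):

g := singleflight.Group[string]{Remember: true}

v, err, _ := g.Do("cfg", loadValue) // runs loadValue once
fmt.Println(v, err)
v, err, _ = g.Do("cfg", loadValue) // reuses the remembered result; loadValue is not called again
fmt.Println(v, err)
g.Forget("cfg") // drops the entry so the next Do runs loadValue again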
@ -153,46 +153,99 @@ func Retry(attempts int, sleep time.Duration, f func() error) (err error) {
type ClosersIF interface { type ClosersIF interface {
io.Closer io.Closer
Add(closer io.Closer) Add(closer io.Closer)
TryAdd(reader io.Reader) AddIfCloser(a any)
AddClosers(closers Closers)
GetClosers() Closers
} }
type Closers []io.Closer
type Closers struct { func (c *Closers) Close() error {
closers []io.Closer var errs []error
for _, closer := range *c {
if closer != nil {
errs = append(errs, closer.Close())
}
}
*c = (*c)[:0]
return errors.Join(errs...)
} }
func (c *Closers) Add(closer io.Closer) {
func (c *Closers) GetClosers() Closers { if closer != nil {
return *c *c = append(*c, closer)
}
}
func (c *Closers) AddIfCloser(a any) {
if closer, ok := a.(io.Closer); ok {
*c = append(*c, closer)
}
} }
var _ ClosersIF = (*Closers)(nil) var _ ClosersIF = (*Closers)(nil)
func (c *Closers) Close() error { func NewClosers(c ...io.Closer) Closers {
return Closers(c)
}
type SyncClosersIF interface {
ClosersIF
AcquireReference() bool
}
type SyncClosers struct {
closers []io.Closer
mu sync.Mutex
ref int
}
var _ SyncClosersIF = (*SyncClosers)(nil)
func (c *SyncClosers) AcquireReference() bool {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.closers) == 0 {
return false
}
c.ref++
log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, c.ref)
return true
}
func (c *SyncClosers) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
defer log.Debugf("SyncClosers.Close %p,ref=%d\n", c, c.ref)
if c.ref > 1 {
c.ref--
return nil
}
c.ref = 0
var errs []error var errs []error
for _, closer := range c.closers { for _, closer := range c.closers {
if closer != nil { if closer != nil {
errs = append(errs, closer.Close()) errs = append(errs, closer.Close())
} }
} }
c.closers = c.closers[:0]
return errors.Join(errs...) return errors.Join(errs...)
} }
func (c *Closers) Add(closer io.Closer) {
func (c *SyncClosers) Add(closer io.Closer) {
if closer != nil { if closer != nil {
c.mu.Lock()
c.closers = append(c.closers, closer) c.closers = append(c.closers, closer)
} c.mu.Unlock()
}
func (c *Closers) AddClosers(closers Closers) {
c.closers = append(c.closers, closers.closers...)
}
func (c *Closers) TryAdd(reader io.Reader) {
if closer, ok := reader.(io.Closer); ok {
c.closers = append(c.closers, closer)
} }
} }
func NewClosers(c ...io.Closer) Closers { func (c *SyncClosers) AddIfCloser(a any) {
return Closers{c} if closer, ok := a.(io.Closer); ok {
c.mu.Lock()
c.closers = append(c.closers, closer)
c.mu.Unlock()
}
}
func NewSyncClosers(c ...io.Closer) SyncClosers {
return SyncClosers{closers: c}
} }
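SyncClosers is the mechanism that makes shared links safe: Close decrements a reference count and only runs the wrapped closers when the count reaches zero, while AcquireReference refuses to hand out a new reference once the closers have been released, which is exactly why op.Link and DriverExtract retry Do when it returns false. A minimal sketch, assuming a type that embeds SyncClosers the way model.Link appears to here, and a hypothetical io.Closer named resource:

link := &model.Link{URL: "https://example.com/file"}
link.Add(resource) // register something that must be cleaned up

for i := 0; i < 2; i++ {
    if !link.AcquireReference() { // false once the closers were already released
        break
    }
    go func() {
        defer link.Close() // drops one reference; only the last Close runs resource.Close
        // ... serve one download from link ...
    }()
}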
type Ordered interface { type Ordered interface {
@ -12,55 +12,34 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/net" "github.com/OpenListTeam/OpenList/v4/internal/net"
"github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils"
) )
func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error {
if link.MFile != nil { if link.MFile != nil {
if clr, ok := link.MFile.(io.Closer); ok { attachHeader(w, file, link.Header)
defer clr.Close()
}
attachHeader(w, file)
contentType := link.Header.Get("Content-Type")
if contentType != "" {
w.Header().Set("Content-Type", contentType)
}
http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile) http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile)
return nil return nil
} else if link.RangeReadCloser != nil {
attachHeader(w, file)
return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{
RangeReadCloserIF: link.RangeReadCloser,
Limiter: stream.ServerDownloadLimit,
})
} else if link.Concurrency > 0 || link.PartSize > 0 {
attachHeader(w, file)
size := file.GetSize()
rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
requestHeader := ctx.Value("request_header")
if requestHeader == nil {
requestHeader = http.Header{}
} }
header := net.ProcessHeader(requestHeader.(http.Header), link.Header)
down := net.NewDownloader(func(d *net.Downloader) { if link.Concurrency > 0 || link.PartSize > 0 {
d.Concurrency = link.Concurrency attachHeader(w, file, link.Header)
d.PartSize = link.PartSize rrf, _ := stream.GetRangeReaderFromLink(file.GetSize(), link)
}) if link.RangeReader == nil {
req := &net.HttpRequestParams{ r = r.WithContext(context.WithValue(r.Context(), net.RequestHeaderKey{}, r.Header))
URL: link.URL,
Range: httpRange,
Size: size,
HeaderRef: header,
} }
rc, err := down.Download(ctx, req) return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{
return rc, err RangeReader: rrf,
}
return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{
RangeReadCloserIF: &model.RangeReadCloser{RangeReader: rangeReader},
Limiter: stream.ServerDownloadLimit,
}) })
} else { }
if link.RangeReader != nil {
attachHeader(w, file, link.Header)
return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{
RangeReader: link.RangeReader,
})
}
//transparent proxy //transparent proxy
header := net.ProcessHeader(r.Header, link.Header) header := net.ProcessHeader(r.Header, link.Header)
res, err := net.RequestHttp(r.Context(), r.Method, header, link.URL) res, err := net.RequestHttp(r.Context(), r.Method, header, link.URL)
@ -80,13 +59,16 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
Ctx: r.Context(), Ctx: r.Context(),
}) })
return err return err
}
} }
func attachHeader(w http.ResponseWriter, file model.Obj) { func attachHeader(w http.ResponseWriter, file model.Obj, header http.Header) {
fileName := file.GetName() fileName := file.GetName()
w.Header().Set("Content-Disposition", utils.GenerateContentDisposition(fileName)) w.Header().Set("Content-Disposition", utils.GenerateContentDisposition(fileName))
w.Header().Set("Content-Type", utils.GetMimeType(fileName)) w.Header().Set("Content-Type", utils.GetMimeType(fileName))
w.Header().Set("Etag", GetEtag(file)) w.Header().Set("Etag", GetEtag(file))
contentType := header.Get("Content-Type")
if len(contentType) > 0 {
w.Header().Set("Content-Type", contentType)
}
} }
func GetEtag(file model.Obj) string { func GetEtag(file model.Obj) string {
hash := "" hash := ""
@ -106,12 +88,12 @@ func ProxyRange(ctx context.Context, link *model.Link, size int64) {
if link.MFile != nil { if link.MFile != nil {
return return
} }
if link.RangeReadCloser == nil && !strings.HasPrefix(link.URL, GetApiUrl(ctx)+"/") { if link.RangeReader == nil && !strings.HasPrefix(link.URL, GetApiUrl(ctx)+"/") {
var rrc, err = stream.GetRangeReadCloserFromLink(size, link) rrf, err := stream.GetRangeReaderFromLink(size, link)
if err != nil { if err != nil {
return return
} }
link.RangeReadCloser = rrc link.RangeReader = rrf
} }

@ -45,12 +45,12 @@ func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownl
if err != nil { if err != nil {
return nil, err return nil, err
} }
fileStream := stream.FileStream{ ss, err := stream.NewSeekableStream(&stream.FileStream{
Obj: obj, Obj: obj,
Ctx: ctx, Ctx: ctx,
} }, link)
ss, err := stream.NewSeekableStream(fileStream, link)
if err != nil { if err != nil {
_ = link.Close()
return nil, err return nil, err
} }
reader, err := stream.NewReadAtSeeker(ss, offset) reader, err := stream.NewReadAtSeeker(ss, offset)
@ -321,7 +321,7 @@ func ArchiveDown(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
down(c, link) redirect(c, link)
} }
} }
@ -351,7 +351,7 @@ func ArchiveProxy(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
localProxy(c, link, file, storage.GetStorage().ProxyRange) proxy(c, link, file, storage.GetStorage().ProxyRange)
} else { } else {
common.ErrorStrResp(c, "proxy not allowed", 403) common.ErrorStrResp(c, "proxy not allowed", 403)
return return
@ -2,8 +2,8 @@ package handles
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io"
stdpath "path" stdpath "path"
"strconv" "strconv"
"strings" "strings"
@ -12,6 +12,7 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/fs"
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/net"
"github.com/OpenListTeam/OpenList/v4/internal/setting" "github.com/OpenListTeam/OpenList/v4/internal/setting"
"github.com/OpenListTeam/OpenList/v4/internal/sign" "github.com/OpenListTeam/OpenList/v4/internal/sign"
"github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils"
@ -44,7 +45,7 @@ func Down(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
down(c, link) redirect(c, link)
} }
} }
@ -77,22 +78,15 @@ func Proxy(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
localProxy(c, link, file, storage.GetStorage().ProxyRange) proxy(c, link, file, storage.GetStorage().ProxyRange)
} else { } else {
common.ErrorStrResp(c, "proxy not allowed", 403) common.ErrorStrResp(c, "proxy not allowed", 403)
return return
} }
} }
func down(c *gin.Context, link *model.Link) { func redirect(c *gin.Context, link *model.Link) {
if clr, ok := link.MFile.(io.Closer); ok { defer link.Close()
defer func(clr io.Closer) {
err := clr.Close()
if err != nil {
log.Errorf("close link data error: %v", err)
}
}(clr)
}
var err error var err error
c.Header("Referrer-Policy", "no-referrer") c.Header("Referrer-Policy", "no-referrer")
c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
@ -110,7 +104,8 @@ func down(c *gin.Context, link *model.Link) {
c.Redirect(302, link.URL) c.Redirect(302, link.URL)
} }
func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange bool) { func proxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange bool) {
defer link.Close()
var err error var err error
if link.URL != "" && setting.GetBool(conf.ForwardDirectLinkParams) { if link.URL != "" && setting.GetBool(conf.ForwardDirectLinkParams) {
query := c.Request.URL.Query() query := c.Request.URL.Query()
@ -160,9 +155,13 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo
} }
if Writer.IsWritten() { if Writer.IsWritten() {
log.Errorf("%s %s local proxy error: %+v", c.Request.Method, c.Request.URL.Path, err) log.Errorf("%s %s local proxy error: %+v", c.Request.Method, c.Request.URL.Path, err)
} else {
if statusCode, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok {
common.ErrorResp(c, err, int(statusCode), true)
} else { } else {
common.ErrorResp(c, err, 500, true) common.ErrorResp(c, err, 500, true)
} }
}
} }
// TODO need optimize // TODO need optimize
@ -2,7 +2,6 @@ package handles
import ( import (
"fmt" "fmt"
"io"
stdpath "path" stdpath "path"
"github.com/OpenListTeam/OpenList/v4/internal/task" "github.com/OpenListTeam/OpenList/v4/internal/task"
@ -17,7 +16,6 @@ import (
"github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus"
) )
type MkdirOrLinkReq struct { type MkdirOrLinkReq struct {
@ -376,7 +374,7 @@ func Link(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
if storage.Config().OnlyLocal { if storage.Config().NoLinkURL || storage.Config().OnlyLinkMFile {
common.SuccessResp(c, model.Link{ common.SuccessResp(c, model.Link{
URL: fmt.Sprintf("%s/p%s?d&sign=%s", URL: fmt.Sprintf("%s/p%s?d&sign=%s",
common.GetApiUrl(c), common.GetApiUrl(c),
@ -385,18 +383,11 @@ func Link(c *gin.Context) {
}) })
return return
} }
link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header}) link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, Redirect: true})
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
if clr, ok := link.MFile.(io.Closer); ok { defer link.Close()
defer func(clr io.Closer) {
err := clr.Close()
if err != nil {
log.Errorf("close link data error: %v", err)
}
}(clr)
}
common.SuccessResp(c, link) common.SuccessResp(c, link)
} }
@ -315,6 +315,7 @@ func FsGet(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
defer link.Close()
rawURL = link.URL rawURL = link.URL
} }

@ -28,6 +28,12 @@ func getLastModified(c *gin.Context) time.Time {
} }
func FsStream(c *gin.Context) { func FsStream(c *gin.Context) {
defer func() {
if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
}
_ = c.Request.Body.Close()
}()
path := c.GetHeader("File-Path") path := c.GetHeader("File-Path")
path, err := url.PathUnescape(path) path, err := url.PathUnescape(path)
if err != nil { if err != nil {
@ -44,7 +50,6 @@ func FsStream(c *gin.Context) {
} }
if !overwrite { if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403) common.ErrorStrResp(c, "file exists", 403)
return return
} }
@ -90,15 +95,11 @@ func FsStream(c *gin.Context) {
} else { } else {
err = fs.PutDirectly(c, dir, s, true) err = fs.PutDirectly(c, dir, s, true)
} }
defer c.Request.Body.Close()
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
if t == nil { if t == nil {
if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
}
common.SuccessResp(c) common.SuccessResp(c)
return return
} }
@ -108,6 +109,12 @@ func FsStream(c *gin.Context) {
} }
func FsForm(c *gin.Context) { func FsForm(c *gin.Context) {
defer func() {
if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
}
_ = c.Request.Body.Close()
}()
path := c.GetHeader("File-Path") path := c.GetHeader("File-Path")
path, err := url.PathUnescape(path) path, err := url.PathUnescape(path)
if err != nil { if err != nil {
@ -124,7 +131,6 @@ func FsForm(c *gin.Context) {
} }
if !overwrite { if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403) common.ErrorStrResp(c, "file exists", 403)
return return
} }
@ -164,7 +170,7 @@ func FsForm(c *gin.Context) {
if len(mimetype) == 0 { if len(mimetype) == 0 {
mimetype = utils.GetMimeType(name) mimetype = utils.GetMimeType(name)
} }
s := stream.FileStream{ s := &stream.FileStream{
Obj: &model.Object{ Obj: &model.Object{
Name: name, Name: name,
Size: file.Size, Size: file.Size,
@ -180,9 +186,9 @@ func FsForm(c *gin.Context) {
s.Reader = struct { s.Reader = struct {
io.Reader io.Reader
}{f} }{f}
t, err = fs.PutAsTask(c, dir, &s) t, err = fs.PutAsTask(c, dir, s)
} else { } else {
err = fs.PutDirectly(c, dir, &s, true) err = fs.PutDirectly(c, dir, s, true)
} }
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)

} }
// GetObject fetchs the object from the filesystem. // GetObject fetchs the object from the filesystem.
func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string, rangeRequest *gofakes3.ObjectRangeRequest) (obj *gofakes3.Object, err error) { func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string, rangeRequest *gofakes3.ObjectRangeRequest) (s3Obj *gofakes3.Object, err error) {
bucket, err := getBucketByName(bucketName) bucket, err := getBucketByName(bucketName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -164,6 +164,11 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
if s3Obj == nil {
_ = link.Close()
}
}()
size := file.GetSize() size := file.GetSize()
rnge, err := rangeRequest.Range(size) rnge, err := rangeRequest.Range(size)
@ -171,50 +176,20 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
return nil, err return nil, err
} }
if link.RangeReadCloser == nil && link.MFile == nil && len(link.URL) == 0 { rrf, err := stream.GetRangeReaderFromLink(size, link)
if err != nil {
return nil, fmt.Errorf("the remote storage driver need to be enhanced to support s3") return nil, fmt.Errorf("the remote storage driver need to be enhanced to support s3")
} }
var rdr io.ReadCloser var rd io.Reader
length := int64(-1)
start := int64(0)
if rnge != nil { if rnge != nil {
start, length = rnge.Start, rnge.Length rd, err = rrf.RangeRead(ctx, http_range.Range(*rnge))
} else {
rd, err = rrf.RangeRead(ctx, http_range.Range{Length: -1})
} }
// see server/common/proxy.go
if link.MFile != nil {
_, err := link.MFile.Seek(start, io.SeekStart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rdr2, ok := link.MFile.(io.ReadCloser); ok {
rdr = rdr2
} else {
rdr = io.NopCloser(link.MFile)
}
} else {
remoteFileSize := file.GetSize()
if length >= 0 && start+length >= remoteFileSize {
length = -1
}
rrc := link.RangeReadCloser
if len(link.URL) > 0 {
var converted, err = stream.GetRangeReadCloserFromLink(remoteFileSize, link)
if err != nil {
return nil, err
}
rrc = converted
}
if rrc != nil {
remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: start, Length: length})
if err != nil {
return nil, err
}
rdr = utils.ReadCloser{Reader: remoteReader, Closer: rrc}
} else {
return nil, errs.NotSupport
}
}
meta := map[string]string{ meta := map[string]string{
"Last-Modified": node.ModTime().Format(timeFormat), "Last-Modified": node.ModTime().Format(timeFormat),
@ -236,7 +211,7 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
Metadata: meta, Metadata: meta,
Size: size, Size: size,
Range: rnge, Range: rnge,
Contents: rdr, Contents: utils.ReadCloser{Reader: rd, Closer: link},
}, nil }, nil
} }
@ -318,11 +293,11 @@ func (b *s3Backend) PutObject(
return result, err return result, err
} }
if err := stream.Close(); err != nil { // if err := stream.Close(); err != nil {
// remove file when close error occurred (FsPutErr) // // remove file when close error occurred (FsPutErr)
_ = fs.Remove(ctx, fp) // _ = fs.Remove(ctx, fp)
return result, err // return result, err
} // }
b.meta.Store(fp, meta) b.meta.Store(fp, meta)

"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -16,6 +17,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/OpenListTeam/OpenList/v4/internal/net"
"github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/errs"
@ -245,11 +247,15 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
defer link.Close()
if storage.GetStorage().ProxyRange { if storage.GetStorage().ProxyRange {
common.ProxyRange(ctx, link, fi.GetSize()) common.ProxyRange(ctx, link, fi.GetSize())
} }
err = common.Proxy(w, r, link, fi) err = common.Proxy(w, r, link, fi)
if err != nil { if err != nil {
if statusCode, ok := errors.Unwrap(err).(net.ErrorHttpStatusCode); ok {
return int(statusCode), err
}
return http.StatusInternalServerError, fmt.Errorf("webdav proxy error: %+v", err) return http.StatusInternalServerError, fmt.Errorf("webdav proxy error: %+v", err)
} }
} else if storage.GetStorage().WebdavProxy() && downProxyUrl != "" { } else if storage.GetStorage().WebdavProxy() && downProxyUrl != "" {
@ -264,6 +270,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
defer link.Close()
http.Redirect(w, r, link.URL, http.StatusFound) http.Redirect(w, r, link.URL, http.StatusFound)
} }
return 0, nil return 0, nil
@ -305,6 +312,12 @@ func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status i
} }
func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, err error) { func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, err error) {
defer func() {
if n, _ := io.ReadFull(r.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, r.Body)
}
_ = r.Body.Close()
}()
reqPath, status, err := h.stripPrefix(r.URL.Path) reqPath, status, err := h.stripPrefix(r.URL.Path)
if err != nil { if err != nil {
return status, err return status, err
@ -344,8 +357,6 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int,
return http.StatusNotFound, err return http.StatusNotFound, err
} }
_ = r.Body.Close()
_ = fsStream.Close()
// TODO(rost): Returning 405 Method Not Allowed might not be appropriate. // TODO(rost): Returning 405 Method Not Allowed might not be appropriate.
if err != nil { if err != nil {
return http.StatusMethodNotAllowed, err return http.StatusMethodNotAllowed, err