refactor: optimize stream, link, and resource management (#486)

* refactor: optimize stream, link, and resource management

* Link.MFile改为io.ReadSeeker类型

* fix (crypt): read on closed response body

* chore

* chore

* chore
This commit is contained in:
j2rong4cn
2025-07-03 10:39:34 +08:00
committed by GitHub
parent 8e19a0fb07
commit ffb6c2a180
26 changed files with 180 additions and 275 deletions

View File

@ -113,6 +113,7 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
for _, dst := range dsts { for _, dst := range dsts {
link, err := d.link(ctx, dst, sub, args) link, err := d.link(ctx, dst, sub, args)
if err == nil { if err == nil {
link.Expiration = nil // 去除非必要缓存d.link里op.Lin有缓存
if !args.Redirect && len(link.URL) > 0 { if !args.Redirect && len(link.URL) > 0 {
// 正常情况下 多并发 仅支持返回URL的驱动 // 正常情况下 多并发 仅支持返回URL的驱动
// alias套娃alias 可以让crypt、mega等驱动(不返回URL的) 支持并发 // alias套娃alias 可以让crypt、mega等驱动(不返回URL的) 支持并发

View File

@ -254,43 +254,46 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
if remoteLink.RangeReadCloser == nil && remoteLink.MFile == nil && len(remoteLink.URL) == 0 { if remoteLink.RangeReadCloser == nil && remoteLink.MFile == nil && len(remoteLink.URL) == 0 {
return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion") return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion")
} }
resultRangeReadCloser := &model.RangeReadCloser{}
resultRangeReadCloser.TryAdd(remoteLink.MFile)
if remoteLink.RangeReadCloser != nil {
resultRangeReadCloser.AddClosers(remoteLink.RangeReadCloser.GetClosers())
}
remoteFileSize := remoteFile.GetSize() remoteFileSize := remoteFile.GetSize()
remoteClosers := utils.EmptyClosers()
rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) { rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) {
length := underlyingLength length := underlyingLength
if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize { if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize {
length = -1 length = -1
} }
rrc := remoteLink.RangeReadCloser
if len(remoteLink.URL) > 0 {
var converted, err = stream.GetRangeReadCloserFromLink(remoteFileSize, remoteLink)
if err != nil {
return nil, err
}
rrc = converted
}
if rrc != nil {
remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length})
remoteClosers.AddClosers(rrc.GetClosers())
if err != nil {
return nil, err
}
return remoteReader, nil
}
if remoteLink.MFile != nil { if remoteLink.MFile != nil {
_, err := remoteLink.MFile.Seek(underlyingOffset, io.SeekStart) _, err := remoteLink.MFile.Seek(underlyingOffset, io.SeekStart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
//keep reuse same MFile and close at last. //keep reuse same MFile and close at last.
remoteClosers.Add(remoteLink.MFile)
return io.NopCloser(remoteLink.MFile), nil return io.NopCloser(remoteLink.MFile), nil
} }
rrc := remoteLink.RangeReadCloser
if rrc == nil && len(remoteLink.URL) > 0 {
var err error
rrc, err = stream.GetRangeReadCloserFromLink(remoteFileSize, remoteLink)
if err != nil {
return nil, err
}
resultRangeReadCloser.AddClosers(rrc.GetClosers())
remoteLink.RangeReadCloser = rrc
}
if rrc != nil {
remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length})
if err != nil {
return nil, err
}
return remoteReader, nil
}
return nil, errs.NotSupport return nil, errs.NotSupport
} }
resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { resultRangeReadCloser.RangeReader = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length) readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length)
if err != nil { if err != nil {
return nil, err return nil, err
@ -298,14 +301,9 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
return readSeeker, nil return readSeeker, nil
} }
resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers} return &model.Link{
resultLink := &model.Link{
RangeReadCloser: resultRangeReadCloser, RangeReadCloser: resultRangeReadCloser,
Expiration: remoteLink.Expiration, }, nil
}
return resultLink, nil
} }
func (d *Crypt) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { func (d *Crypt) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {

View File

@ -524,7 +524,6 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to cache file: %w", err) return nil, fmt.Errorf("failed to cache file: %w", err)
} }
defer tempFile.Close()
up(10.0) // 更新进度 up(10.0) // 更新进度
// 设置并行上传 // 设置并行上传
threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,

View File

@ -7,6 +7,7 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/jlaffaye/ftp" "github.com/jlaffaye/ftp"
) )
@ -66,7 +67,11 @@ func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
r := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize()) r := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize())
link := &model.Link{ link := &model.Link{
MFile: r, MFile: &stream.RateLimitFile{
File: r,
Limiter: stream.ServerDownloadLimit,
Ctx: ctx,
},
} }
return link, nil return link, nil
} }

View File

@ -256,9 +256,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
length = -1 length = -1
} }
if err != nil {
return nil, fmt.Errorf("open download file failed: %w", err)
}
oo := &openObject{ oo := &openObject{
ctx: ctx, ctx: ctx,
d: fileAddrs, d: fileAddrs,

View File

@ -96,7 +96,3 @@ type SteamFile struct {
func (s *SteamFile) Read(p []byte) (n int, err error) { func (s *SteamFile) Read(p []byte) (n int, err error) {
return s.file.Read(p) return s.file.Read(p)
} }
func (s *SteamFile) Close() error {
return s.file.Close()
}

View File

@ -242,7 +242,7 @@ func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
} }
link.MFile = open link.MFile = open
} else { } else {
link.MFile = model.NewNopMFile(bytes.NewReader(buf.Bytes())) link.MFile = bytes.NewReader(buf.Bytes())
//link.Header.Set("Content-Length", strconv.Itoa(buf.Len())) //link.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
} }
} else { } else {

View File

@ -184,9 +184,6 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, file model.FileS
if err != nil { if err != nil {
return err return err
} }
defer func() {
_ = tempFile.Close()
}()
uploader := s3manager.NewUploader(s) uploader := s3manager.NewUploader(s)
if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize {
uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1)

View File

@ -2,7 +2,6 @@ package netease_music
import ( import (
"context" "context"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -11,7 +10,6 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/sign" "github.com/OpenListTeam/OpenList/v4/internal/sign"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/OpenListTeam/OpenList/v4/pkg/utils/random" "github.com/OpenListTeam/OpenList/v4/pkg/utils/random"
"github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/common"
@ -55,17 +53,8 @@ func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link {
} }
func (lrc *LyricObj) getLyricLink() *model.Link { func (lrc *LyricObj) getLyricLink() *model.Link {
reader := strings.NewReader(lrc.lyric)
return &model.Link{ return &model.Link{
RangeReadCloser: &model.RangeReadCloser{ MFile: strings.NewReader(lrc.lyric),
RangeReader: func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
if httpRange.Length < 0 {
return io.NopCloser(reader), nil
}
sr := io.NewSectionReader(reader, httpRange.Start, httpRange.Length)
return io.NopCloser(sr), nil
},
},
} }
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/pkg/sftp" "github.com/pkg/sftp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -62,10 +63,13 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
link := &model.Link{ return &model.Link{
MFile: remoteFile, MFile: &stream.RateLimitFile{
} File: remoteFile,
return link, nil Limiter: stream.ServerDownloadLimit,
Ctx: ctx,
},
}, nil
} }
func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {

View File

@ -8,6 +8,7 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/hirochachacha/go-smb2" "github.com/hirochachacha/go-smb2"
@ -79,11 +80,14 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
d.cleanLastConnTime() d.cleanLastConnTime()
return nil, err return nil, err
} }
link := &model.Link{
MFile: remoteFile,
}
d.updateLastConnTime() d.updateLastConnTime()
return link, nil return &model.Link{
MFile: &stream.RateLimitFile{
File: remoteFile,
Limiter: stream.ServerDownloadLimit,
Ctx: ctx,
},
}, nil
} }
func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {

View File

@ -114,7 +114,7 @@ func (d *Strm) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]
func (d *Strm) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { func (d *Strm) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
link := d.getLink(ctx, file.GetPath()) link := d.getLink(ctx, file.GetPath())
return &model.Link{ return &model.Link{
MFile: model.NewNopMFile(strings.NewReader(link)), MFile: strings.NewReader(link),
}, nil }, nil
} }

View File

@ -18,8 +18,9 @@ import (
) )
type VolumeFile struct { type VolumeFile struct {
stream.SStreamReadAtSeeker model.File
name string name string
ss model.FileStreamer
} }
func (v *VolumeFile) Name() string { func (v *VolumeFile) Name() string {
@ -27,7 +28,7 @@ func (v *VolumeFile) Name() string {
} }
func (v *VolumeFile) Size() int64 { func (v *VolumeFile) Size() int64 {
return v.SStreamReadAtSeeker.GetRawStream().GetSize() return v.ss.GetSize()
} }
func (v *VolumeFile) Mode() fs.FileMode { func (v *VolumeFile) Mode() fs.FileMode {
@ -35,7 +36,7 @@ func (v *VolumeFile) Mode() fs.FileMode {
} }
func (v *VolumeFile) ModTime() time.Time { func (v *VolumeFile) ModTime() time.Time {
return v.SStreamReadAtSeeker.GetRawStream().ModTime() return v.ss.ModTime()
} }
func (v *VolumeFile) IsDir() bool { func (v *VolumeFile) IsDir() bool {
@ -74,7 +75,7 @@ func makeOpts(ss []*stream.SeekableStream) (string, rardecode.Option, error) {
} }
fileName := "file.rar" fileName := "file.rar"
fsys := &VolumeFs{parts: map[string]*VolumeFile{ fsys := &VolumeFs{parts: map[string]*VolumeFile{
fileName: {SStreamReadAtSeeker: reader, name: fileName}, fileName: {File: reader, name: fileName},
}} }}
return fileName, rardecode.FileSystem(fsys), nil return fileName, rardecode.FileSystem(fsys), nil
} else { } else {
@ -85,7 +86,7 @@ func makeOpts(ss []*stream.SeekableStream) (string, rardecode.Option, error) {
return "", nil, err return "", nil, err
} }
fileName := fmt.Sprintf("file.part%d.rar", i+1) fileName := fmt.Sprintf("file.part%d.rar", i+1)
parts[fileName] = &VolumeFile{SStreamReadAtSeeker: reader, name: fileName} parts[fileName] = &VolumeFile{File: reader, name: fileName, ss: s}
} }
return "file.part1.rar", rardecode.FileSystem(&VolumeFs{parts: parts}), nil return "file.part1.rar", rardecode.FileSystem(&VolumeFs{parts: parts}), nil
} }

View File

@ -27,10 +27,9 @@ type Link struct {
URL string `json:"url"` // most common way URL string `json:"url"` // most common way
Header http.Header `json:"header"` // needed header (for url) Header http.Header `json:"header"` // needed header (for url)
RangeReadCloser RangeReadCloserIF `json:"-"` // recommended way if can't use URL RangeReadCloser RangeReadCloserIF `json:"-"` // recommended way if can't use URL
MFile File `json:"-"` // best for local,smb... file system, which exposes MFile MFile io.ReadSeeker `json:"-"` // best for local,smb... file system, which exposes MFile
Expiration *time.Duration // local cache expire Duration Expiration *time.Duration // local cache expire Duration
IPCacheKey bool `json:"-"` // add ip to cache key
//for accelerating request, use multi-thread downloading //for accelerating request, use multi-thread downloading
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`

View File

@ -7,19 +7,4 @@ type File interface {
io.Reader io.Reader
io.ReaderAt io.ReaderAt
io.Seeker io.Seeker
io.Closer
}
type NopMFileIF interface {
io.Reader
io.ReaderAt
io.Seeker
}
type NopMFile struct {
NopMFileIF
}
func (NopMFile) Close() error { return nil }
func NewNopMFile(r NopMFileIF) File {
return NopMFile{r}
} }

View File

@ -182,11 +182,10 @@ func (d *downloader) download() (io.ReadCloser, error) {
defer d.m.Unlock() defer d.m.Unlock()
if closeFunc != nil { if closeFunc != nil {
d.concurrencyFinish() d.concurrencyFinish()
err := closeFunc() err = closeFunc()
closeFunc = nil closeFunc = nil
return err
} }
return nil return err
}) })
return resp.Body, nil return resp.Body, nil
} }
@ -272,24 +271,30 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error {
// when the final reader Close, we interrupt // when the final reader Close, we interrupt
func (d *downloader) interrupt() error { func (d *downloader) interrupt() error {
if d.written != d.params.Range.Length { d.m.Lock()
defer d.m.Unlock()
err := d.err
if err == nil && d.written != d.params.Range.Length {
log.Debugf("Downloader interrupt before finish") log.Debugf("Downloader interrupt before finish")
if d.getErr() == nil { err := fmt.Errorf("interrupted")
d.setErr(fmt.Errorf("interrupted")) d.err = err
}
} }
d.cancel(d.err) if d.chunkChannel != nil {
defer func() { d.cancel(err)
close(d.chunkChannel) close(d.chunkChannel)
d.chunkChannel = nil
for _, buf := range d.bufs { for _, buf := range d.bufs {
buf.Close() buf.Close()
} }
d.bufs = nil
if d.concurrency > 0 { if d.concurrency > 0 {
d.concurrency = -d.concurrency d.concurrency = -d.concurrency
} }
log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency) log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency)
}() } else {
return d.err log.Debug("close of closed channel")
}
return err
} }
func (d *downloader) getBuf(id int) (b *Buf) { func (d *downloader) getBuf(id int) (b *Buf) {
return d.bufs[id%len(d.bufs)] return d.bufs[id%len(d.bufs)]

View File

@ -62,8 +62,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
} }
baseName, ext, found := strings.Cut(obj.GetName(), ".") baseName, ext, found := strings.Cut(obj.GetName(), ".")
if !found { if !found {
if l.MFile != nil { if clr, ok := l.MFile.(io.Closer); ok {
_ = l.MFile.Close() _ = clr.Close()
} }
if l.RangeReadCloser != nil { if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close() _ = l.RangeReadCloser.Close()
@ -75,8 +75,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
var e error var e error
partExt, t, e = tool.GetArchiveTool(stdpath.Ext(obj.GetName())) partExt, t, e = tool.GetArchiveTool(stdpath.Ext(obj.GetName()))
if e != nil { if e != nil {
if l.MFile != nil { if clr, ok := l.MFile.(io.Closer); ok {
_ = l.MFile.Close() _ = clr.Close()
} }
if l.RangeReadCloser != nil { if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close() _ = l.RangeReadCloser.Close()
@ -86,8 +86,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
} }
ss, err := stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: obj}, l) ss, err := stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: obj}, l)
if err != nil { if err != nil {
if l.MFile != nil { if clr, ok := l.MFile.(io.Closer); ok {
_ = l.MFile.Close() _ = clr.Close()
} }
if l.RangeReadCloser != nil { if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close() _ = l.RangeReadCloser.Close()
@ -109,8 +109,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st
} }
ss, err = stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: o}, l) ss, err = stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: o}, l)
if err != nil { if err != nil {
if l.MFile != nil { if clr, ok := l.MFile.(io.Closer); ok {
_ = l.MFile.Close() _ = clr.Close()
} }
if l.RangeReadCloser != nil { if l.RangeReadCloser != nil {
_ = l.RangeReadCloser.Close() _ = l.RangeReadCloser.Close()
@ -174,9 +174,6 @@ func getArchiveMeta(ctx context.Context, storage driver.Driver, path string, arg
if !storage.Config().NoCache { if !storage.Config().NoCache {
Expiration := time.Minute * time.Duration(storage.GetStorage().CacheExpiration) Expiration := time.Minute * time.Duration(storage.GetStorage().CacheExpiration)
archiveMetaProvider.Expiration = &Expiration archiveMetaProvider.Expiration = &Expiration
} else if ss[0].Link.MFile == nil {
// alias、crypt 驱动
archiveMetaProvider.Expiration = ss[0].Link.Expiration
} }
return obj, archiveMetaProvider, err return obj, archiveMetaProvider, err
} }
@ -401,9 +398,6 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
return nil, errors.Wrapf(err, "failed extract archive") return nil, errors.Wrapf(err, "failed extract archive")
} }
if link.Link.Expiration != nil { if link.Link.Expiration != nil {
if link.Link.IPCacheKey {
key = key + ":" + args.IP
}
extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration)) extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration))
} }
return link, nil return link, nil

View File

@ -268,9 +268,6 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
return nil, errors.Wrapf(err, "failed get link") return nil, errors.Wrapf(err, "failed get link")
} }
if link.Expiration != nil { if link.Expiration != nil {
if link.IPCacheKey {
key = key + ":" + args.IP
}
linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration)) linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration))
} }
return link, nil return link, nil

View File

@ -135,6 +135,13 @@ func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) {
return return
} }
func (r *RateLimitFile) Close() error {
if c, ok := r.File.(io.Closer); ok {
return c.Close()
}
return nil
}
type RateLimitRangeReadCloser struct { type RateLimitRangeReadCloser struct {
model.RangeReadCloserIF model.RangeReadCloserIF
Limiter Limiter Limiter Limiter

View File

@ -81,10 +81,7 @@ func (f *FileStream) SetExist(obj model.Obj) {
// CacheFullInTempFile save all data into tmpFile. Not recommended since it wears disk, // CacheFullInTempFile save all data into tmpFile. Not recommended since it wears disk,
// and can't start upload until the file is written. It's not thread-safe! // and can't start upload until the file is written. It's not thread-safe!
func (f *FileStream) CacheFullInTempFile() (model.File, error) { func (f *FileStream) CacheFullInTempFile() (model.File, error) {
if f.tmpFile != nil { if file := f.GetFile(); file != nil {
return f.tmpFile, nil
}
if file, ok := f.Reader.(model.File); ok {
return file, nil return file, nil
} }
tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize()) tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize())
@ -117,33 +114,35 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
// 参考 internal/net/request.go // 参考 internal/net/request.go
httpRange.Length = f.GetSize() - httpRange.Start httpRange.Length = f.GetSize() - httpRange.Start
} }
var cache io.ReaderAt = f.GetFile()
if cache != nil {
return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil
}
size := httpRange.Start + httpRange.Length size := httpRange.Start + httpRange.Length
if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) { if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) {
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
} }
var cache io.ReaderAt = f.GetFile() if size <= InMemoryBufMaxSizeBytes {
if cache == nil { bufSize := min(size, f.GetSize())
if size <= InMemoryBufMaxSizeBytes { // 使用bytes.Buffer作为io.CopyBuffer的写入对象CopyBuffer会调用Buffer.ReadFrom
bufSize := min(size, f.GetSize()) // 即使被写入的数据量与Buffer.Cap一致Buffer也会扩大
// 使用bytes.Buffer作为io.CopyBuffer的写入对象CopyBuffer会调用Buffer.ReadFrom buf := make([]byte, bufSize)
// 即使被写入的数据量与Buffer.Cap一致Buffer也会扩大 n, err := io.ReadFull(f.Reader, buf)
buf := make([]byte, bufSize) if err != nil {
n, err := io.ReadFull(f.Reader, buf) return nil, err
if err != nil { }
return nil, err if n != int(bufSize) {
} return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
if n != int(bufSize) { }
return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) f.peekBuff = bytes.NewReader(buf)
} f.Reader = io.MultiReader(f.peekBuff, f.Reader)
f.peekBuff = bytes.NewReader(buf) cache = f.peekBuff
f.Reader = io.MultiReader(f.peekBuff, f.Reader) } else {
cache = f.peekBuff var err error
} else { cache, err = f.CacheFullInTempFile()
var err error if err != nil {
cache, err = f.CacheFullInTempFile() return nil, err
if err != nil {
return nil, err
}
} }
} }
return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil
@ -161,49 +160,34 @@ var _ model.FileStreamer = (*FileStream)(nil)
// the SeekableStream object and be closed together when the SeekableStream object is closed. // the SeekableStream object and be closed together when the SeekableStream object is closed.
type SeekableStream struct { type SeekableStream struct {
FileStream FileStream
Link *model.Link
// should have one of belows to support rangeRead // should have one of belows to support rangeRead
rangeReadCloser model.RangeReadCloserIF rangeReadCloser model.RangeReadCloserIF
mFile model.File
} }
func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) { func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) {
if len(fs.Mimetype) == 0 { if len(fs.Mimetype) == 0 {
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
} }
ss := &SeekableStream{FileStream: fs, Link: link} ss := &SeekableStream{FileStream: fs}
if ss.Reader != nil { if ss.Reader != nil {
result, ok := ss.Reader.(model.File) ss.TryAdd(ss.Reader)
if ok { return ss, nil
ss.mFile = result
ss.Closers.Add(result)
return ss, nil
}
} }
if ss.Link != nil { if link != nil {
if ss.Link.MFile != nil { if link.MFile != nil {
mFile := ss.Link.MFile ss.Closers.TryAdd(link.MFile)
if _, ok := mFile.(*os.File); !ok { ss.Reader = link.MFile
mFile = &RateLimitFile{
File: mFile,
Limiter: ServerDownloadLimit,
Ctx: fs.Ctx,
}
}
ss.mFile = mFile
ss.Reader = mFile
ss.Closers.Add(mFile)
return ss, nil return ss, nil
} }
if ss.Link.RangeReadCloser != nil { if link.RangeReadCloser != nil {
ss.rangeReadCloser = &RateLimitRangeReadCloser{ ss.rangeReadCloser = &RateLimitRangeReadCloser{
RangeReadCloserIF: ss.Link.RangeReadCloser, RangeReadCloserIF: link.RangeReadCloser,
Limiter: ServerDownloadLimit, Limiter: ServerDownloadLimit,
} }
ss.Add(ss.rangeReadCloser) ss.Add(ss.rangeReadCloser)
return ss, nil return ss, nil
} }
if len(ss.Link.URL) > 0 { if len(link.URL) > 0 {
rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link) rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link)
if err != nil { if err != nil {
return nil, err return nil, err
@ -217,9 +201,6 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
return ss, nil return ss, nil
} }
} }
if fs.Reader != nil {
return ss, nil
}
return nil, fmt.Errorf("illegal seekableStream") return nil, fmt.Errorf("illegal seekableStream")
} }
@ -229,16 +210,10 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
// RangeRead is not thread-safe, pls use it in single thread only. // RangeRead is not thread-safe, pls use it in single thread only.
func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
if httpRange.Length == -1 { if ss.tmpFile == nil && ss.rangeReadCloser != nil {
httpRange.Length = ss.GetSize() - httpRange.Start if httpRange.Length == -1 {
} httpRange.Length = ss.GetSize() - httpRange.Start
if ss.mFile != nil { }
return io.NewSectionReader(ss.mFile, httpRange.Start, httpRange.Length), nil
}
if ss.tmpFile != nil {
return io.NewSectionReader(ss.tmpFile, httpRange.Start, httpRange.Length), nil
}
if ss.rangeReadCloser != nil {
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange) rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange)
if err != nil { if err != nil {
return nil, err return nil, err
@ -272,11 +247,8 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
} }
func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
if ss.tmpFile != nil { if file := ss.GetFile(); file != nil {
return ss.tmpFile, nil return file, nil
}
if ss.mFile != nil {
return ss.mFile, nil
} }
tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) tmpF, err := utils.CreateTempFile(ss, ss.GetSize())
if err != nil { if err != nil {
@ -288,16 +260,6 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
return tmpF, nil return tmpF, nil
} }
func (ss *SeekableStream) GetFile() model.File {
if ss.tmpFile != nil {
return ss.tmpFile
}
if ss.mFile != nil {
return ss.mFile
}
return nil
}
func (f *FileStream) SetTmpFile(r *os.File) { func (f *FileStream) SetTmpFile(r *os.File) {
f.Add(r) f.Add(r)
f.tmpFile = r f.tmpFile = r
@ -342,11 +304,6 @@ func (r *ReaderUpdatingProgress) Close() error {
return r.Reader.Close() return r.Reader.Close()
} }
type SStreamReadAtSeeker interface {
model.File
GetRawStream() *SeekableStream
}
type readerCur struct { type readerCur struct {
reader io.Reader reader io.Reader
cur int64 cur int64
@ -407,7 +364,7 @@ func (r *headCache) Close() error {
} }
func (r *RangeReadReadAtSeeker) InitHeadCache() { func (r *RangeReadReadAtSeeker) InitHeadCache() {
if r.ss.Link.MFile == nil && r.masterOff == 0 { if r.ss.GetFile() == nil && r.masterOff == 0 {
reader := r.readers[0] reader := r.readers[0]
r.readers = r.readers[1:] r.readers = r.readers[1:]
r.headCache = &headCache{readerCur: reader} r.headCache = &headCache{readerCur: reader}
@ -415,13 +372,13 @@ func (r *RangeReadReadAtSeeker) InitHeadCache() {
} }
} }
func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (SStreamReadAtSeeker, error) { func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (model.File, error) {
if ss.mFile != nil { if ss.GetFile() != nil {
_, err := ss.mFile.Seek(offset, io.SeekStart) _, err := ss.GetFile().Seek(offset, io.SeekStart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &FileReadAtSeeker{ss: ss}, nil return ss.GetFile(), nil
} }
r := &RangeReadReadAtSeeker{ r := &RangeReadReadAtSeeker{
ss: ss, ss: ss,
@ -454,10 +411,6 @@ func NewMultiReaderAt(ss []*SeekableStream) (readerutil.SizeReaderAt, error) {
return readerutil.NewMultiReaderAt(readers...), nil return readerutil.NewMultiReaderAt(readers...), nil
} }
func (r *RangeReadReadAtSeeker) GetRawStream() *SeekableStream {
return r.ss
}
func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) { func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) {
var rc *readerCur var rc *readerCur
for _, reader := range r.readers { for _, reader := range r.readers {
@ -562,31 +515,3 @@ func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) {
r.masterOff += int64(n) r.masterOff += int64(n)
return n, err return n, err
} }
func (r *RangeReadReadAtSeeker) Close() error {
return r.ss.Close()
}
type FileReadAtSeeker struct {
ss *SeekableStream
}
func (f *FileReadAtSeeker) GetRawStream() *SeekableStream {
return f.ss
}
func (f *FileReadAtSeeker) Read(p []byte) (n int, err error) {
return f.ss.mFile.Read(p)
}
func (f *FileReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) {
return f.ss.mFile.ReadAt(p, off)
}
func (f *FileReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
return f.ss.mFile.Seek(offset, whence)
}
func (f *FileReadAtSeeker) Close() error {
return f.ss.Close()
}

View File

@ -153,6 +153,7 @@ func Retry(attempts int, sleep time.Duration, f func() error) (err error) {
type ClosersIF interface { type ClosersIF interface {
io.Closer io.Closer
Add(closer io.Closer) Add(closer io.Closer)
TryAdd(reader io.Reader)
AddClosers(closers Closers) AddClosers(closers Closers)
GetClosers() Closers GetClosers() Closers
} }
@ -177,16 +178,19 @@ func (c *Closers) Close() error {
return errors.Join(errs...) return errors.Join(errs...)
} }
func (c *Closers) Add(closer io.Closer) { func (c *Closers) Add(closer io.Closer) {
c.closers = append(c.closers, closer) if closer != nil {
c.closers = append(c.closers, closer)
}
} }
func (c *Closers) AddClosers(closers Closers) { func (c *Closers) AddClosers(closers Closers) {
c.closers = append(c.closers, closers.closers...) c.closers = append(c.closers, closers.closers...)
} }
func (c *Closers) TryAdd(reader io.Reader) {
func EmptyClosers() Closers { if closer, ok := reader.(io.Closer); ok {
return Closers{[]io.Closer{}} c.closers = append(c.closers, closer)
}
} }
func NewClosers(c ...io.Closer) Closers { func NewClosers(c ...io.Closer) Closers {
return Closers{c} return Closers{c}
} }

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"strings" "strings"
"maps" "maps"
@ -19,21 +18,15 @@ import (
func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error {
if link.MFile != nil { if link.MFile != nil {
defer link.MFile.Close() if clr, ok := link.MFile.(io.Closer); ok {
defer clr.Close()
}
attachHeader(w, file) attachHeader(w, file)
contentType := link.Header.Get("Content-Type") contentType := link.Header.Get("Content-Type")
if contentType != "" { if contentType != "" {
w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
} }
mFile := link.MFile http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile)
if _, ok := mFile.(*os.File); !ok {
mFile = &stream.RateLimitFile{
File: mFile,
Limiter: stream.ServerDownloadLimit,
Ctx: r.Context(),
}
}
http.ServeContent(w, r, file.GetName(), file.ModTime(), mFile)
return nil return nil
} else if link.RangeReadCloser != nil { } else if link.RangeReadCloser != nil {
attachHeader(w, file) attachHeader(w, file)

View File

@ -2,6 +2,7 @@ package ftp
import ( import (
"context" "context"
"io"
fs2 "io/fs" fs2 "io/fs"
"net/http" "net/http"
"os" "os"
@ -13,13 +14,13 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/op"
"github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/internal/stream"
"github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/common"
ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type FileDownloadProxy struct { type FileDownloadProxy struct {
ftpserver.FileTransfer model.File
reader stream.SStreamReadAtSeeker io.Closer
ctx context.Context
} }
func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownloadProxy, error) { func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownloadProxy, error) {
@ -57,15 +58,24 @@ func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownl
_ = ss.Close() _ = ss.Close()
return nil, err return nil, err
} }
return &FileDownloadProxy{reader: reader}, nil return &FileDownloadProxy{File: reader, Closer: ss, ctx: ctx}, nil
} }
func (f *FileDownloadProxy) Read(p []byte) (n int, err error) { func (f *FileDownloadProxy) Read(p []byte) (n int, err error) {
n, err = f.reader.Read(p) n, err = f.File.Read(p)
if err != nil { if err != nil {
return return
} }
err = stream.ClientDownloadLimit.WaitN(f.reader.GetRawStream().Ctx, n) err = stream.ClientDownloadLimit.WaitN(f.ctx, n)
return
}
func (f *FileDownloadProxy) ReadAt(p []byte, off int64) (n int, err error) {
n, err = f.File.ReadAt(p, off)
if err != nil {
return
}
err = stream.ClientDownloadLimit.WaitN(f.ctx, n)
return return
} }
@ -73,14 +83,6 @@ func (f *FileDownloadProxy) Write(p []byte) (n int, err error) {
return 0, errs.NotSupport return 0, errs.NotSupport
} }
func (f *FileDownloadProxy) Seek(offset int64, whence int) (int64, error) {
return f.reader.Seek(offset, whence)
}
func (f *FileDownloadProxy) Close() error {
return f.reader.Close()
}
type OsFileInfoAdapter struct { type OsFileInfoAdapter struct {
obj model.Obj obj model.Obj
} }

View File

@ -85,15 +85,15 @@ func Proxy(c *gin.Context) {
} }
func down(c *gin.Context, link *model.Link) { func down(c *gin.Context, link *model.Link) {
var err error if clr, ok := link.MFile.(io.Closer); ok {
if link.MFile != nil { defer func(clr io.Closer) {
defer func(ReadSeekCloser io.ReadCloser) { err := clr.Close()
err := ReadSeekCloser.Close()
if err != nil { if err != nil {
log.Errorf("close data error: %s", err) log.Errorf("close link data error: %v", err)
} }
}(link.MFile) }(clr)
} }
var err error
c.Header("Referrer-Policy", "no-referrer") c.Header("Referrer-Policy", "no-referrer")
c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
if setting.GetBool(conf.ForwardDirectLinkParams) { if setting.GetBool(conf.ForwardDirectLinkParams) {

View File

@ -390,14 +390,13 @@ func Link(c *gin.Context) {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
if link.MFile != nil { if clr, ok := link.MFile.(io.Closer); ok {
defer func(ReadSeekCloser io.ReadCloser) { defer func(clr io.Closer) {
err := ReadSeekCloser.Close() err := clr.Close()
if err != nil { if err != nil {
log.Errorf("close link data error: %v", err) log.Errorf("close link data error: %v", err)
} }
}(link.MFile) }(clr)
} }
common.SuccessResp(c, link) common.SuccessResp(c, link)
return
} }

View File

@ -187,7 +187,11 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
if err != nil { if err != nil {
return nil, err return nil, err
} }
rdr = link.MFile if rdr2, ok := link.MFile.(io.ReadCloser); ok {
rdr = rdr2
} else {
rdr = io.NopCloser(link.MFile)
}
} else { } else {
remoteFileSize := file.GetSize() remoteFileSize := file.GetSize()
if length >= 0 && start+length >= remoteFileSize { if length >= 0 && start+length >= remoteFileSize {