From ffb6c2a18066e4d13c5e61a147fd009019a86b55 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Thu, 3 Jul 2025 10:39:34 +0800 Subject: [PATCH] refactor: optimize stream, link, and resource management (#486) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: optimize stream, link, and resource management * Link.MFile改为io.ReadSeeker类型 * fix (crypt): read on closed response body * chore * chore * chore --- drivers/alias/driver.go | 1 + drivers/crypt/driver.go | 52 ++++----- drivers/doubao/util.go | 1 - drivers/ftp/driver.go | 7 +- drivers/halalcloud/driver.go | 3 - drivers/halalcloud/types.go | 4 - drivers/local/driver.go | 2 +- drivers/mediatrack/driver.go | 3 - drivers/netease_music/types.go | 13 +-- drivers/sftp/driver.go | 12 +- drivers/smb/driver.go | 12 +- drivers/strm/driver.go | 2 +- internal/archive/rardecode/utils.go | 11 +- internal/model/args.go | 3 +- internal/model/file.go | 15 --- internal/net/request.go | 27 +++-- internal/op/archive.go | 22 ++-- internal/op/fs.go | 3 - internal/stream/limit.go | 7 ++ internal/stream/stream.go | 169 ++++++++-------------------- pkg/utils/io.go | 14 ++- server/common/proxy.go | 15 +-- server/ftp/fsread.go | 30 ++--- server/handles/down.go | 12 +- server/handles/fsmanage.go | 9 +- server/s3/backend.go | 6 +- 26 files changed, 180 insertions(+), 275 deletions(-) diff --git a/drivers/alias/driver.go b/drivers/alias/driver.go index a5569b68..c2dcfb95 100644 --- a/drivers/alias/driver.go +++ b/drivers/alias/driver.go @@ -113,6 +113,7 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( for _, dst := range dsts { link, err := d.link(ctx, dst, sub, args) if err == nil { + link.Expiration = nil // 去除非必要缓存,d.link里op.Lin有缓存 if !args.Redirect && len(link.URL) > 0 { // 正常情况下 多并发 仅支持返回URL的驱动 // alias套娃alias 可以让crypt、mega等驱动(不返回URL的) 支持并发 diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index 4f185ff3..def55535 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -254,43 +254,46 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if remoteLink.RangeReadCloser == nil && remoteLink.MFile == nil && len(remoteLink.URL) == 0 { return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion") } + resultRangeReadCloser := &model.RangeReadCloser{} + resultRangeReadCloser.TryAdd(remoteLink.MFile) + if remoteLink.RangeReadCloser != nil { + resultRangeReadCloser.AddClosers(remoteLink.RangeReadCloser.GetClosers()) + } remoteFileSize := remoteFile.GetSize() - remoteClosers := utils.EmptyClosers() rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) { length := underlyingLength if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize { length = -1 } - rrc := remoteLink.RangeReadCloser - if len(remoteLink.URL) > 0 { - var converted, err = stream.GetRangeReadCloserFromLink(remoteFileSize, remoteLink) - if err != nil { - return nil, err - } - rrc = converted - } - if rrc != nil { - remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length}) - remoteClosers.AddClosers(rrc.GetClosers()) - if err != nil { - return nil, err - } - return remoteReader, nil - } if remoteLink.MFile != nil { _, err := remoteLink.MFile.Seek(underlyingOffset, io.SeekStart) if err != nil { return nil, err } //keep reuse same MFile and close at last. - remoteClosers.Add(remoteLink.MFile) return io.NopCloser(remoteLink.MFile), nil } - + rrc := remoteLink.RangeReadCloser + if rrc == nil && len(remoteLink.URL) > 0 { + var err error + rrc, err = stream.GetRangeReadCloserFromLink(remoteFileSize, remoteLink) + if err != nil { + return nil, err + } + resultRangeReadCloser.AddClosers(rrc.GetClosers()) + remoteLink.RangeReadCloser = rrc + } + if rrc != nil { + remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length}) + if err != nil { + return nil, err + } + return remoteReader, nil + } return nil, errs.NotSupport } - resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + resultRangeReadCloser.RangeReader = func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length) if err != nil { return nil, err @@ -298,14 +301,9 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( return readSeeker, nil } - resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers} - resultLink := &model.Link{ + return &model.Link{ RangeReadCloser: resultRangeReadCloser, - Expiration: remoteLink.Expiration, - } - - return resultLink, nil - + }, nil } func (d *Crypt) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/drivers/doubao/util.go b/drivers/doubao/util.go index bc71599b..70b4231c 100644 --- a/drivers/doubao/util.go +++ b/drivers/doubao/util.go @@ -524,7 +524,6 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi if err != nil { return nil, fmt.Errorf("failed to cache file: %w", err) } - defer tempFile.Close() up(10.0) // 更新进度 // 设置并行上传 threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go index a2eac114..2ffc7153 100644 --- a/drivers/ftp/driver.go +++ b/drivers/ftp/driver.go @@ -7,6 +7,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/jlaffaye/ftp" ) @@ -66,7 +67,11 @@ func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m r := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize()) link := &model.Link{ - MFile: r, + MFile: &stream.RateLimitFile{ + File: r, + Limiter: stream.ServerDownloadLimit, + Ctx: ctx, + }, } return link, nil } diff --git a/drivers/halalcloud/driver.go b/drivers/halalcloud/driver.go index 87aa23c2..04f7fcdd 100644 --- a/drivers/halalcloud/driver.go +++ b/drivers/halalcloud/driver.go @@ -256,9 +256,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { length = -1 } - if err != nil { - return nil, fmt.Errorf("open download file failed: %w", err) - } oo := &openObject{ ctx: ctx, d: fileAddrs, diff --git a/drivers/halalcloud/types.go b/drivers/halalcloud/types.go index 01c06ac5..39adc5d2 100644 --- a/drivers/halalcloud/types.go +++ b/drivers/halalcloud/types.go @@ -96,7 +96,3 @@ type SteamFile struct { func (s *SteamFile) Read(p []byte) (n int, err error) { return s.file.Read(p) } - -func (s *SteamFile) Close() error { - return s.file.Close() -} diff --git a/drivers/local/driver.go b/drivers/local/driver.go index ad0682e0..64cdbb24 100644 --- a/drivers/local/driver.go +++ b/drivers/local/driver.go @@ -242,7 +242,7 @@ func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( } link.MFile = open } else { - link.MFile = model.NewNopMFile(bytes.NewReader(buf.Bytes())) + link.MFile = bytes.NewReader(buf.Bytes()) //link.Header.Set("Content-Length", strconv.Itoa(buf.Len())) } } else { diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index 76d02bd6..15d84f31 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -184,9 +184,6 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, file model.FileS if err != nil { return err } - defer func() { - _ = tempFile.Close() - }() uploader := s3manager.NewUploader(s) if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index 1a9ebef0..c0a2d476 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -2,7 +2,6 @@ package netease_music import ( "context" - "io" "net/http" "strconv" "strings" @@ -11,7 +10,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/sign" - "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" "github.com/OpenListTeam/OpenList/v4/server/common" @@ -55,17 +53,8 @@ func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link { } func (lrc *LyricObj) getLyricLink() *model.Link { - reader := strings.NewReader(lrc.lyric) return &model.Link{ - RangeReadCloser: &model.RangeReadCloser{ - RangeReader: func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { - if httpRange.Length < 0 { - return io.NopCloser(reader), nil - } - sr := io.NewSectionReader(reader, httpRange.Start, httpRange.Length) - return io.NopCloser(sr), nil - }, - }, + MFile: strings.NewReader(lrc.lyric), } } diff --git a/drivers/sftp/driver.go b/drivers/sftp/driver.go index 2babbe9b..4e540621 100644 --- a/drivers/sftp/driver.go +++ b/drivers/sftp/driver.go @@ -8,6 +8,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/pkg/sftp" log "github.com/sirupsen/logrus" @@ -62,10 +63,13 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* if err != nil { return nil, err } - link := &model.Link{ - MFile: remoteFile, - } - return link, nil + return &model.Link{ + MFile: &stream.RateLimitFile{ + File: remoteFile, + Limiter: stream.ServerDownloadLimit, + Ctx: ctx, + }, + }, nil } func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index a020483b..3cdfbbe4 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -8,6 +8,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/hirochachacha/go-smb2" @@ -79,11 +80,14 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m d.cleanLastConnTime() return nil, err } - link := &model.Link{ - MFile: remoteFile, - } d.updateLastConnTime() - return link, nil + return &model.Link{ + MFile: &stream.RateLimitFile{ + File: remoteFile, + Limiter: stream.ServerDownloadLimit, + Ctx: ctx, + }, + }, nil } func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/drivers/strm/driver.go b/drivers/strm/driver.go index 6985ad75..c100183a 100644 --- a/drivers/strm/driver.go +++ b/drivers/strm/driver.go @@ -114,7 +114,7 @@ func (d *Strm) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([] func (d *Strm) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { link := d.getLink(ctx, file.GetPath()) return &model.Link{ - MFile: model.NewNopMFile(strings.NewReader(link)), + MFile: strings.NewReader(link), }, nil } diff --git a/internal/archive/rardecode/utils.go b/internal/archive/rardecode/utils.go index 7d8a1c37..93a71da9 100644 --- a/internal/archive/rardecode/utils.go +++ b/internal/archive/rardecode/utils.go @@ -18,8 +18,9 @@ import ( ) type VolumeFile struct { - stream.SStreamReadAtSeeker + model.File name string + ss model.FileStreamer } func (v *VolumeFile) Name() string { @@ -27,7 +28,7 @@ func (v *VolumeFile) Name() string { } func (v *VolumeFile) Size() int64 { - return v.SStreamReadAtSeeker.GetRawStream().GetSize() + return v.ss.GetSize() } func (v *VolumeFile) Mode() fs.FileMode { @@ -35,7 +36,7 @@ func (v *VolumeFile) Mode() fs.FileMode { } func (v *VolumeFile) ModTime() time.Time { - return v.SStreamReadAtSeeker.GetRawStream().ModTime() + return v.ss.ModTime() } func (v *VolumeFile) IsDir() bool { @@ -74,7 +75,7 @@ func makeOpts(ss []*stream.SeekableStream) (string, rardecode.Option, error) { } fileName := "file.rar" fsys := &VolumeFs{parts: map[string]*VolumeFile{ - fileName: {SStreamReadAtSeeker: reader, name: fileName}, + fileName: {File: reader, name: fileName}, }} return fileName, rardecode.FileSystem(fsys), nil } else { @@ -85,7 +86,7 @@ func makeOpts(ss []*stream.SeekableStream) (string, rardecode.Option, error) { return "", nil, err } fileName := fmt.Sprintf("file.part%d.rar", i+1) - parts[fileName] = &VolumeFile{SStreamReadAtSeeker: reader, name: fileName} + parts[fileName] = &VolumeFile{File: reader, name: fileName, ss: s} } return "file.part1.rar", rardecode.FileSystem(&VolumeFs{parts: parts}), nil } diff --git a/internal/model/args.go b/internal/model/args.go index 7a1b0649..2477adc0 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -27,10 +27,9 @@ type Link struct { URL string `json:"url"` // most common way Header http.Header `json:"header"` // needed header (for url) RangeReadCloser RangeReadCloserIF `json:"-"` // recommended way if can't use URL - MFile File `json:"-"` // best for local,smb... file system, which exposes MFile + MFile io.ReadSeeker `json:"-"` // best for local,smb... file system, which exposes MFile Expiration *time.Duration // local cache expire Duration - IPCacheKey bool `json:"-"` // add ip to cache key //for accelerating request, use multi-thread downloading Concurrency int `json:"concurrency"` diff --git a/internal/model/file.go b/internal/model/file.go index ba65ef93..d3a1fa6a 100644 --- a/internal/model/file.go +++ b/internal/model/file.go @@ -7,19 +7,4 @@ type File interface { io.Reader io.ReaderAt io.Seeker - io.Closer -} - -type NopMFileIF interface { - io.Reader - io.ReaderAt - io.Seeker -} -type NopMFile struct { - NopMFileIF -} - -func (NopMFile) Close() error { return nil } -func NewNopMFile(r NopMFileIF) File { - return NopMFile{r} } diff --git a/internal/net/request.go b/internal/net/request.go index 9610c309..41640972 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -182,11 +182,10 @@ func (d *downloader) download() (io.ReadCloser, error) { defer d.m.Unlock() if closeFunc != nil { d.concurrencyFinish() - err := closeFunc() + err = closeFunc() closeFunc = nil - return err } - return nil + return err }) return resp.Body, nil } @@ -272,24 +271,30 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { // when the final reader Close, we interrupt func (d *downloader) interrupt() error { - if d.written != d.params.Range.Length { + d.m.Lock() + defer d.m.Unlock() + err := d.err + if err == nil && d.written != d.params.Range.Length { log.Debugf("Downloader interrupt before finish") - if d.getErr() == nil { - d.setErr(fmt.Errorf("interrupted")) - } + err := fmt.Errorf("interrupted") + d.err = err } - d.cancel(d.err) - defer func() { + if d.chunkChannel != nil { + d.cancel(err) close(d.chunkChannel) + d.chunkChannel = nil for _, buf := range d.bufs { buf.Close() } + d.bufs = nil if d.concurrency > 0 { d.concurrency = -d.concurrency } log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency) - }() - return d.err + } else { + log.Debug("close of closed channel") + } + return err } func (d *downloader) getBuf(id int) (b *Buf) { return d.bufs[id%len(d.bufs)] diff --git a/internal/op/archive.go b/internal/op/archive.go index 541d8970..16911f6a 100644 --- a/internal/op/archive.go +++ b/internal/op/archive.go @@ -62,8 +62,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st } baseName, ext, found := strings.Cut(obj.GetName(), ".") if !found { - if l.MFile != nil { - _ = l.MFile.Close() + if clr, ok := l.MFile.(io.Closer); ok { + _ = clr.Close() } if l.RangeReadCloser != nil { _ = l.RangeReadCloser.Close() @@ -75,8 +75,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st var e error partExt, t, e = tool.GetArchiveTool(stdpath.Ext(obj.GetName())) if e != nil { - if l.MFile != nil { - _ = l.MFile.Close() + if clr, ok := l.MFile.(io.Closer); ok { + _ = clr.Close() } if l.RangeReadCloser != nil { _ = l.RangeReadCloser.Close() @@ -86,8 +86,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st } ss, err := stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: obj}, l) if err != nil { - if l.MFile != nil { - _ = l.MFile.Close() + if clr, ok := l.MFile.(io.Closer); ok { + _ = clr.Close() } if l.RangeReadCloser != nil { _ = l.RangeReadCloser.Close() @@ -109,8 +109,8 @@ func GetArchiveToolAndStream(ctx context.Context, storage driver.Driver, path st } ss, err = stream.NewSeekableStream(stream.FileStream{Ctx: ctx, Obj: o}, l) if err != nil { - if l.MFile != nil { - _ = l.MFile.Close() + if clr, ok := l.MFile.(io.Closer); ok { + _ = clr.Close() } if l.RangeReadCloser != nil { _ = l.RangeReadCloser.Close() @@ -174,9 +174,6 @@ func getArchiveMeta(ctx context.Context, storage driver.Driver, path string, arg if !storage.Config().NoCache { Expiration := time.Minute * time.Duration(storage.GetStorage().CacheExpiration) archiveMetaProvider.Expiration = &Expiration - } else if ss[0].Link.MFile == nil { - // alias、crypt 驱动 - archiveMetaProvider.Expiration = ss[0].Link.Expiration } return obj, archiveMetaProvider, err } @@ -401,9 +398,6 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args return nil, errors.Wrapf(err, "failed extract archive") } if link.Link.Expiration != nil { - if link.Link.IPCacheKey { - key = key + ":" + args.IP - } extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration)) } return link, nil diff --git a/internal/op/fs.go b/internal/op/fs.go index adc08229..cae7d95d 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -268,9 +268,6 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li return nil, errors.Wrapf(err, "failed get link") } if link.Expiration != nil { - if link.IPCacheKey { - key = key + ":" + args.IP - } linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration)) } return link, nil diff --git a/internal/stream/limit.go b/internal/stream/limit.go index 0b049a93..db5da20d 100644 --- a/internal/stream/limit.go +++ b/internal/stream/limit.go @@ -135,6 +135,13 @@ func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) { return } +func (r *RateLimitFile) Close() error { + if c, ok := r.File.(io.Closer); ok { + return c.Close() + } + return nil +} + type RateLimitRangeReadCloser struct { model.RangeReadCloserIF Limiter Limiter diff --git a/internal/stream/stream.go b/internal/stream/stream.go index d55a751b..1164f80d 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -81,10 +81,7 @@ func (f *FileStream) SetExist(obj model.Obj) { // CacheFullInTempFile save all data into tmpFile. Not recommended since it wears disk, // and can't start upload until the file is written. It's not thread-safe! func (f *FileStream) CacheFullInTempFile() (model.File, error) { - if f.tmpFile != nil { - return f.tmpFile, nil - } - if file, ok := f.Reader.(model.File); ok { + if file := f.GetFile(); file != nil { return file, nil } tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize()) @@ -117,33 +114,35 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { // 参考 internal/net/request.go httpRange.Length = f.GetSize() - httpRange.Start } + var cache io.ReaderAt = f.GetFile() + if cache != nil { + return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil + } + size := httpRange.Start + httpRange.Length if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) { return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil } - var cache io.ReaderAt = f.GetFile() - if cache == nil { - if size <= InMemoryBufMaxSizeBytes { - bufSize := min(size, f.GetSize()) - // 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom - // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 - buf := make([]byte, bufSize) - n, err := io.ReadFull(f.Reader, buf) - if err != nil { - return nil, err - } - if n != int(bufSize) { - return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) - } - f.peekBuff = bytes.NewReader(buf) - f.Reader = io.MultiReader(f.peekBuff, f.Reader) - cache = f.peekBuff - } else { - var err error - cache, err = f.CacheFullInTempFile() - if err != nil { - return nil, err - } + if size <= InMemoryBufMaxSizeBytes { + bufSize := min(size, f.GetSize()) + // 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom + // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 + buf := make([]byte, bufSize) + n, err := io.ReadFull(f.Reader, buf) + if err != nil { + return nil, err + } + if n != int(bufSize) { + return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) + } + f.peekBuff = bytes.NewReader(buf) + f.Reader = io.MultiReader(f.peekBuff, f.Reader) + cache = f.peekBuff + } else { + var err error + cache, err = f.CacheFullInTempFile() + if err != nil { + return nil, err } } return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil @@ -161,49 +160,34 @@ var _ model.FileStreamer = (*FileStream)(nil) // the SeekableStream object and be closed together when the SeekableStream object is closed. type SeekableStream struct { FileStream - Link *model.Link // should have one of belows to support rangeRead rangeReadCloser model.RangeReadCloserIF - mFile model.File } func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) { if len(fs.Mimetype) == 0 { fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) } - ss := &SeekableStream{FileStream: fs, Link: link} + ss := &SeekableStream{FileStream: fs} if ss.Reader != nil { - result, ok := ss.Reader.(model.File) - if ok { - ss.mFile = result - ss.Closers.Add(result) - return ss, nil - } + ss.TryAdd(ss.Reader) + return ss, nil } - if ss.Link != nil { - if ss.Link.MFile != nil { - mFile := ss.Link.MFile - if _, ok := mFile.(*os.File); !ok { - mFile = &RateLimitFile{ - File: mFile, - Limiter: ServerDownloadLimit, - Ctx: fs.Ctx, - } - } - ss.mFile = mFile - ss.Reader = mFile - ss.Closers.Add(mFile) + if link != nil { + if link.MFile != nil { + ss.Closers.TryAdd(link.MFile) + ss.Reader = link.MFile return ss, nil } - if ss.Link.RangeReadCloser != nil { + if link.RangeReadCloser != nil { ss.rangeReadCloser = &RateLimitRangeReadCloser{ - RangeReadCloserIF: ss.Link.RangeReadCloser, + RangeReadCloserIF: link.RangeReadCloser, Limiter: ServerDownloadLimit, } ss.Add(ss.rangeReadCloser) return ss, nil } - if len(ss.Link.URL) > 0 { + if len(link.URL) > 0 { rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link) if err != nil { return nil, err @@ -217,9 +201,6 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) return ss, nil } } - if fs.Reader != nil { - return ss, nil - } return nil, fmt.Errorf("illegal seekableStream") } @@ -229,16 +210,10 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) // RangeRead is not thread-safe, pls use it in single thread only. func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { - if httpRange.Length == -1 { - httpRange.Length = ss.GetSize() - httpRange.Start - } - if ss.mFile != nil { - return io.NewSectionReader(ss.mFile, httpRange.Start, httpRange.Length), nil - } - if ss.tmpFile != nil { - return io.NewSectionReader(ss.tmpFile, httpRange.Start, httpRange.Length), nil - } - if ss.rangeReadCloser != nil { + if ss.tmpFile == nil && ss.rangeReadCloser != nil { + if httpRange.Length == -1 { + httpRange.Length = ss.GetSize() - httpRange.Start + } rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange) if err != nil { return nil, err @@ -272,11 +247,8 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) { } func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { - if ss.tmpFile != nil { - return ss.tmpFile, nil - } - if ss.mFile != nil { - return ss.mFile, nil + if file := ss.GetFile(); file != nil { + return file, nil } tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) if err != nil { @@ -288,16 +260,6 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { return tmpF, nil } -func (ss *SeekableStream) GetFile() model.File { - if ss.tmpFile != nil { - return ss.tmpFile - } - if ss.mFile != nil { - return ss.mFile - } - return nil -} - func (f *FileStream) SetTmpFile(r *os.File) { f.Add(r) f.tmpFile = r @@ -342,11 +304,6 @@ func (r *ReaderUpdatingProgress) Close() error { return r.Reader.Close() } -type SStreamReadAtSeeker interface { - model.File - GetRawStream() *SeekableStream -} - type readerCur struct { reader io.Reader cur int64 @@ -407,7 +364,7 @@ func (r *headCache) Close() error { } func (r *RangeReadReadAtSeeker) InitHeadCache() { - if r.ss.Link.MFile == nil && r.masterOff == 0 { + if r.ss.GetFile() == nil && r.masterOff == 0 { reader := r.readers[0] r.readers = r.readers[1:] r.headCache = &headCache{readerCur: reader} @@ -415,13 +372,13 @@ func (r *RangeReadReadAtSeeker) InitHeadCache() { } } -func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (SStreamReadAtSeeker, error) { - if ss.mFile != nil { - _, err := ss.mFile.Seek(offset, io.SeekStart) +func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (model.File, error) { + if ss.GetFile() != nil { + _, err := ss.GetFile().Seek(offset, io.SeekStart) if err != nil { return nil, err } - return &FileReadAtSeeker{ss: ss}, nil + return ss.GetFile(), nil } r := &RangeReadReadAtSeeker{ ss: ss, @@ -454,10 +411,6 @@ func NewMultiReaderAt(ss []*SeekableStream) (readerutil.SizeReaderAt, error) { return readerutil.NewMultiReaderAt(readers...), nil } -func (r *RangeReadReadAtSeeker) GetRawStream() *SeekableStream { - return r.ss -} - func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) { var rc *readerCur for _, reader := range r.readers { @@ -562,31 +515,3 @@ func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) { r.masterOff += int64(n) return n, err } - -func (r *RangeReadReadAtSeeker) Close() error { - return r.ss.Close() -} - -type FileReadAtSeeker struct { - ss *SeekableStream -} - -func (f *FileReadAtSeeker) GetRawStream() *SeekableStream { - return f.ss -} - -func (f *FileReadAtSeeker) Read(p []byte) (n int, err error) { - return f.ss.mFile.Read(p) -} - -func (f *FileReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) { - return f.ss.mFile.ReadAt(p, off) -} - -func (f *FileReadAtSeeker) Seek(offset int64, whence int) (int64, error) { - return f.ss.mFile.Seek(offset, whence) -} - -func (f *FileReadAtSeeker) Close() error { - return f.ss.Close() -} diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 71adc289..dd158adb 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -153,6 +153,7 @@ func Retry(attempts int, sleep time.Duration, f func() error) (err error) { type ClosersIF interface { io.Closer Add(closer io.Closer) + TryAdd(reader io.Reader) AddClosers(closers Closers) GetClosers() Closers } @@ -177,16 +178,19 @@ func (c *Closers) Close() error { return errors.Join(errs...) } func (c *Closers) Add(closer io.Closer) { - c.closers = append(c.closers, closer) - + if closer != nil { + c.closers = append(c.closers, closer) + } } func (c *Closers) AddClosers(closers Closers) { c.closers = append(c.closers, closers.closers...) } - -func EmptyClosers() Closers { - return Closers{[]io.Closer{}} +func (c *Closers) TryAdd(reader io.Reader) { + if closer, ok := reader.(io.Closer); ok { + c.closers = append(c.closers, closer) + } } + func NewClosers(c ...io.Closer) Closers { return Closers{c} } diff --git a/server/common/proxy.go b/server/common/proxy.go index e749fc32..77ac0973 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "os" "strings" "maps" @@ -19,21 +18,15 @@ import ( func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { if link.MFile != nil { - defer link.MFile.Close() + if clr, ok := link.MFile.(io.Closer); ok { + defer clr.Close() + } attachHeader(w, file) contentType := link.Header.Get("Content-Type") if contentType != "" { w.Header().Set("Content-Type", contentType) } - mFile := link.MFile - if _, ok := mFile.(*os.File); !ok { - mFile = &stream.RateLimitFile{ - File: mFile, - Limiter: stream.ServerDownloadLimit, - Ctx: r.Context(), - } - } - http.ServeContent(w, r, file.GetName(), file.ModTime(), mFile) + http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile) return nil } else if link.RangeReadCloser != nil { attachHeader(w, file) diff --git a/server/ftp/fsread.go b/server/ftp/fsread.go index 3bf495f0..e00f3c76 100644 --- a/server/ftp/fsread.go +++ b/server/ftp/fsread.go @@ -2,6 +2,7 @@ package ftp import ( "context" + "io" fs2 "io/fs" "net/http" "os" @@ -13,13 +14,13 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/server/common" - ftpserver "github.com/fclairamb/ftpserverlib" "github.com/pkg/errors" ) type FileDownloadProxy struct { - ftpserver.FileTransfer - reader stream.SStreamReadAtSeeker + model.File + io.Closer + ctx context.Context } func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownloadProxy, error) { @@ -57,15 +58,24 @@ func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownl _ = ss.Close() return nil, err } - return &FileDownloadProxy{reader: reader}, nil + return &FileDownloadProxy{File: reader, Closer: ss, ctx: ctx}, nil } func (f *FileDownloadProxy) Read(p []byte) (n int, err error) { - n, err = f.reader.Read(p) + n, err = f.File.Read(p) if err != nil { return } - err = stream.ClientDownloadLimit.WaitN(f.reader.GetRawStream().Ctx, n) + err = stream.ClientDownloadLimit.WaitN(f.ctx, n) + return +} + +func (f *FileDownloadProxy) ReadAt(p []byte, off int64) (n int, err error) { + n, err = f.File.ReadAt(p, off) + if err != nil { + return + } + err = stream.ClientDownloadLimit.WaitN(f.ctx, n) return } @@ -73,14 +83,6 @@ func (f *FileDownloadProxy) Write(p []byte) (n int, err error) { return 0, errs.NotSupport } -func (f *FileDownloadProxy) Seek(offset int64, whence int) (int64, error) { - return f.reader.Seek(offset, whence) -} - -func (f *FileDownloadProxy) Close() error { - return f.reader.Close() -} - type OsFileInfoAdapter struct { obj model.Obj } diff --git a/server/handles/down.go b/server/handles/down.go index 26de024e..e6f44974 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -85,15 +85,15 @@ func Proxy(c *gin.Context) { } func down(c *gin.Context, link *model.Link) { - var err error - if link.MFile != nil { - defer func(ReadSeekCloser io.ReadCloser) { - err := ReadSeekCloser.Close() + if clr, ok := link.MFile.(io.Closer); ok { + defer func(clr io.Closer) { + err := clr.Close() if err != nil { - log.Errorf("close data error: %s", err) + log.Errorf("close link data error: %v", err) } - }(link.MFile) + }(clr) } + var err error c.Header("Referrer-Policy", "no-referrer") c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") if setting.GetBool(conf.ForwardDirectLinkParams) { diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index 9a38f742..96ababec 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -390,14 +390,13 @@ func Link(c *gin.Context) { common.ErrorResp(c, err, 500) return } - if link.MFile != nil { - defer func(ReadSeekCloser io.ReadCloser) { - err := ReadSeekCloser.Close() + if clr, ok := link.MFile.(io.Closer); ok { + defer func(clr io.Closer) { + err := clr.Close() if err != nil { log.Errorf("close link data error: %v", err) } - }(link.MFile) + }(clr) } common.SuccessResp(c, link) - return } diff --git a/server/s3/backend.go b/server/s3/backend.go index d1fe9b49..8912d03f 100644 --- a/server/s3/backend.go +++ b/server/s3/backend.go @@ -187,7 +187,11 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string if err != nil { return nil, err } - rdr = link.MFile + if rdr2, ok := link.MFile.(io.ReadCloser); ok { + rdr = rdr2 + } else { + rdr = io.NopCloser(link.MFile) + } } else { remoteFileSize := file.GetSize() if length >= 0 && start+length >= remoteFileSize {