From 9ac0484bc08ab93b17a15930e256b04c21e83254 Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Wed, 6 Aug 2025 13:32:37 +0800 Subject: [PATCH] perf(ftp): improve concurrent Link response; fix alias/local driver issues (#974) --- drivers/189pc/utils.go | 4 +- drivers/alias/driver.go | 97 +++++++++++----- drivers/alias/util.go | 45 +------- drivers/aliyundrive_open/upload.go | 7 +- drivers/crypt/driver.go | 4 +- drivers/doubao/util.go | 8 +- drivers/ftp/driver.go | 68 ++++++++--- drivers/ftp/meta.go | 2 +- drivers/ftp/util.go | 132 +++++++-------------- drivers/local/driver.go | 9 +- drivers/netease_music/types.go | 4 +- drivers/quark_open/util.go | 12 +- drivers/sftp/driver.go | 18 +-- drivers/smb/driver.go | 19 +-- drivers/virtual/driver.go | 4 - internal/model/args.go | 8 -- internal/op/archive.go | 14 ++- internal/op/fs.go | 19 ++- internal/stream/stream.go | 180 +++++++++++------------------ internal/stream/util.go | 24 ++-- pkg/utils/io.go | 52 +++++---- 21 files changed, 337 insertions(+), 393 deletions(-) diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index 00fbe297..e38f636a 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -550,9 +550,9 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo return err } silceMd5.Reset() - w, _ := utils.CopyWithBuffer(writers, reader) + w, err := utils.CopyWithBuffer(writers, reader) if w != size { - return fmt.Errorf("can't read data, expected=%d, got=%d", size, w) + return fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", size, w, err) } // 计算块md5并进行hex和base64编码 md5Bytes := silceMd5.Sum(nil) diff --git a/drivers/alias/driver.go b/drivers/alias/driver.go index 5a1b6930..284cdc40 100644 --- a/drivers/alias/driver.go +++ b/drivers/alias/driver.go @@ -78,10 +78,18 @@ func (d *Alias) Get(ctx context.Context, path string) (model.Obj, error) { return nil, errs.ObjectNotFound } for _, dst := range dsts { - obj, err := d.get(ctx, path, dst, sub) - if err == nil { - return obj, nil + obj, err := fs.Get(ctx, stdpath.Join(dst, sub), &fs.GetArgs{NoLog: true}) + if err != nil { + continue } + return &model.Object{ + Path: path, + Name: obj.GetName(), + Size: obj.GetSize(), + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + HashInfo: obj.GetHash(), + }, nil } return nil, errs.ObjectNotFound } @@ -99,7 +107,27 @@ func (d *Alias) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([ var objs []model.Obj fsArgs := &fs.ListArgs{NoLog: true, Refresh: args.Refresh} for _, dst := range dsts { - tmp, err := d.list(ctx, dst, sub, fsArgs) + tmp, err := fs.List(ctx, stdpath.Join(dst, sub), fsArgs) + if err == nil { + tmp, err = utils.SliceConvert(tmp, func(obj model.Obj) (model.Obj, error) { + thumb, ok := model.GetThumb(obj) + objRes := model.Object{ + Name: obj.GetName(), + Size: obj.GetSize(), + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + } + if !ok { + return &objRes, nil + } + return &model.ObjThumb{ + Object: objRes, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + }, nil + }) + } if err == nil { objs = append(objs, tmp...) } @@ -113,43 +141,50 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if !ok { return nil, errs.ObjectNotFound } + // proxy || ftp,s3 + if common.GetApiUrl(ctx) == "" { + args.Redirect = false + } for _, dst := range dsts { reqPath := stdpath.Join(dst, sub) - link, file, err := d.link(ctx, reqPath, args) + link, fi, err := d.link(ctx, reqPath, args) if err != nil { continue } - var resultLink *model.Link - if link != nil { - resultLink = &model.Link{ - URL: link.URL, - Header: link.Header, - RangeReader: link.RangeReader, - SyncClosers: utils.NewSyncClosers(link), - ContentLength: link.ContentLength, - } - if link.MFile != nil { - resultLink.RangeReader = &model.FileRangeReader{ - RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile), - } - } - - } else { - resultLink = &model.Link{ + if link == nil { + // 重定向且需要通过代理 + return &model.Link{ URL: fmt.Sprintf("%s/p%s?sign=%s", common.GetApiUrl(ctx), utils.EncodePath(reqPath, true), sign.Sign(reqPath)), - } - + }, nil } - if !args.Redirect { - if d.DownloadConcurrency > 0 { - resultLink.Concurrency = d.DownloadConcurrency - } - if d.DownloadPartSize > 0 { - resultLink.PartSize = d.DownloadPartSize * utils.KB - } + if args.Redirect { + return link, nil + } + + resultLink := &model.Link{ + URL: link.URL, + Header: link.Header, + RangeReader: link.RangeReader, + MFile: link.MFile, + Concurrency: link.Concurrency, + PartSize: link.PartSize, + ContentLength: link.ContentLength, + SyncClosers: utils.NewSyncClosers(link), + } + if resultLink.ContentLength == 0 { + resultLink.ContentLength = fi.GetSize() + } + if resultLink.MFile != nil { + return resultLink, nil + } + if d.DownloadConcurrency > 0 { + resultLink.Concurrency = d.DownloadConcurrency + } + if d.DownloadPartSize > 0 { + resultLink.PartSize = d.DownloadPartSize * utils.KB } return resultLink, nil } diff --git a/drivers/alias/util.go b/drivers/alias/util.go index 1ae9c798..a31ec1c5 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -54,55 +54,12 @@ func (d *Alias) getRootAndPath(path string) (string, string) { return parts[0], parts[1] } -func (d *Alias) get(ctx context.Context, path string, dst, sub string) (model.Obj, error) { - obj, err := fs.Get(ctx, stdpath.Join(dst, sub), &fs.GetArgs{NoLog: true}) - if err != nil { - return nil, err - } - return &model.Object{ - Path: path, - Name: obj.GetName(), - Size: obj.GetSize(), - Modified: obj.ModTime(), - IsFolder: obj.IsDir(), - HashInfo: obj.GetHash(), - }, nil -} - -func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([]model.Obj, error) { - objs, err := fs.List(ctx, stdpath.Join(dst, sub), args) - // the obj must implement the model.SetPath interface - // return objs, err - if err != nil { - return nil, err - } - return utils.SliceConvert(objs, func(obj model.Obj) (model.Obj, error) { - thumb, ok := model.GetThumb(obj) - objRes := model.Object{ - Name: obj.GetName(), - Size: obj.GetSize(), - Modified: obj.ModTime(), - IsFolder: obj.IsDir(), - } - if !ok { - return &objRes, nil - } - return &model.ObjThumb{ - Object: objRes, - Thumbnail: model.Thumbnail{ - Thumbnail: thumb, - }, - }, nil - }) -} - func (d *Alias) link(ctx context.Context, reqPath string, args model.LinkArgs) (*model.Link, model.Obj, error) { storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath) if err != nil { return nil, nil, err } - // proxy || ftp,s3 - if !args.Redirect || len(common.GetApiUrl(ctx)) == 0 { + if !args.Redirect { return op.Link(ctx, storage, reqActualPath, args) } obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index 98852706..369e9ddb 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -137,11 +137,8 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error } buf := make([]byte, length) n, err := io.ReadFull(reader, buf) - if err == io.ErrUnexpectedEOF { - return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n) - } - if err != nil { - return "", err + if n != int(length) { + return "", fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) } return base64.StdEncoding.EncodeToString(buf), nil } diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index a480db86..4cd64348 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -292,10 +292,10 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if offset == 0 && limit > 0 { fileHeader = make([]byte, fileHeaderSize) - n, _ := io.ReadFull(remoteReader, fileHeader) + n, err := io.ReadFull(remoteReader, fileHeader) if n != fileHeaderSize { fileHeader = nil - return nil, fmt.Errorf("can't read data, expected=%d, got=%d", fileHeaderSize, n) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", fileHeaderSize, n, err) } if limit <= fileHeaderSize { remoteReader.Close() diff --git a/drivers/doubao/util.go b/drivers/doubao/util.go index 39d55134..7dd1da2c 100644 --- a/drivers/doubao/util.go +++ b/drivers/doubao/util.go @@ -460,9 +460,9 @@ func (d *Doubao) Upload(ctx context.Context, config *UploadConfig, dstDir model. // 计算CRC32 crc32Hash := crc32.NewIEEE() - w, _ := utils.CopyWithBuffer(crc32Hash, reader) + w, err := utils.CopyWithBuffer(crc32Hash, reader) if w != file.GetSize() { - return nil, fmt.Errorf("can't read data, expected=%d, got=%d", file.GetSize(), w) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", file.GetSize(), w, err) } crc32Value := hex.EncodeToString(crc32Hash.Sum(nil)) @@ -588,9 +588,9 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi return err } hash.Reset() - w, _ := utils.CopyWithBuffer(hash, reader) + w, err := utils.CopyWithBuffer(hash, reader) if w != size { - return fmt.Errorf("can't read data, expected=%d, got=%d", size, w) + return fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", size, w, err) } crc32Value = hex.EncodeToString(hash.Sum(nil)) rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader) diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go index 647ee1b7..0ed6ac2a 100644 --- a/drivers/ftp/driver.go +++ b/drivers/ftp/driver.go @@ -2,12 +2,16 @@ package ftp import ( "context" + "io" stdpath "path" + "sync" + "time" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/jlaffaye/ftp" ) @@ -16,6 +20,9 @@ type FTP struct { model.Storage Addition conn *ftp.ServerConn + + ctx context.Context + cancel context.CancelFunc } func (d *FTP) Config() driver.Config { @@ -27,12 +34,16 @@ func (d *FTP) GetAddition() driver.Additional { } func (d *FTP) Init(ctx context.Context) error { - return d._login() + d.ctx, d.cancel = context.WithCancel(context.Background()) + var err error + d.conn, err = d._login(ctx) + return err } func (d *FTP) Drop(ctx context.Context) error { if d.conn != nil { - _ = d.conn.Logout() + _ = d.conn.Quit() + d.cancel() } return nil } @@ -61,26 +72,53 @@ func (d *FTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m return res, nil } -func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { - if err := d.login(); err != nil { +func (d *FTP) Link(_ context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + ctx, cancel := context.WithCancel(context.Background()) + conn, err := d._login(ctx) + if err != nil { + cancel() return nil, err } + close := func() error { + _ = conn.Quit() + cancel() + return nil + } - remoteFile := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize()) - if remoteFile != nil && !d.Config().OnlyLinkMFile { - return &model.Link{ - RangeReader: &model.FileRangeReader{ - RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)), - }, - SyncClosers: utils.NewSyncClosers(remoteFile), + path := encode(file.GetPath(), d.Encoding) + size := file.GetSize() + mu := &sync.Mutex{} + resultRangeReader := func(context context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + length := httpRange.Length + if length < 0 || httpRange.Start+length > size { + length = size - httpRange.Start + } + mu.Lock() + defer mu.Unlock() + r, err := conn.RetrFrom(path, uint64(httpRange.Start)) + if err != nil { + _ = conn.Quit() + conn, err = d._login(ctx) + if err == nil { + r, err = conn.RetrFrom(path, uint64(httpRange.Start)) + } + if err != nil { + return nil, err + } + } + r.SetDeadline(time.Now().Add(time.Second)) + return &FileReader{ + Response: r, + Reader: io.LimitReader(r, length), + ctx: context, }, nil } + return &model.Link{ - MFile: &stream.RateLimitFile{ - File: remoteFile, - Limiter: stream.ServerDownloadLimit, - Ctx: ctx, + RangeReader: &model.FileRangeReader{ + RangeReaderIF: stream.RateLimitRangeReaderFunc(resultRangeReader), }, + SyncClosers: utils.NewSyncClosers(utils.CloseFunc(close)), }, nil } diff --git a/drivers/ftp/meta.go b/drivers/ftp/meta.go index 6e8cc107..8f30776c 100644 --- a/drivers/ftp/meta.go +++ b/drivers/ftp/meta.go @@ -33,7 +33,7 @@ type Addition struct { var config = driver.Config{ Name: "FTP", LocalSort: true, - OnlyLinkMFile: true, + OnlyLinkMFile: false, DefaultRoot: "/", NoLinkURL: true, } diff --git a/drivers/ftp/util.go b/drivers/ftp/util.go index c81803d6..5945a218 100644 --- a/drivers/ftp/util.go +++ b/drivers/ftp/util.go @@ -1,14 +1,15 @@ package ftp import ( + "context" + "errors" "fmt" "io" "os" - "sync" - "sync/atomic" "time" "github.com/OpenListTeam/OpenList/v4/pkg/singleflight" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/jlaffaye/ftp" ) @@ -16,111 +17,56 @@ import ( func (d *FTP) login() error { _, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("FTP.login:%p", d), func() (any, error) { - return nil, d._login() + var err error + if d.conn != nil { + err = d.conn.NoOp() + if err != nil { + d.conn.Quit() + d.conn = nil + } + } + if d.conn == nil { + d.conn, err = d._login(d.ctx) + } + return nil, err }) return err } -func (d *FTP) _login() error { - - if d.conn != nil { - _, err := d.conn.CurrentDir() - if err == nil { - return nil - } - } - conn, err := ftp.Dial(d.Address, ftp.DialWithShutTimeout(10*time.Second)) +func (d *FTP) _login(ctx context.Context) (*ftp.ServerConn, error) { + conn, err := ftp.Dial(d.Address, ftp.DialWithShutTimeout(10*time.Second), ftp.DialWithContext(ctx)) if err != nil { - return err + return nil, err } err = conn.Login(d.Username, d.Password) if err != nil { - return err + conn.Quit() + return nil, err } - d.conn = conn - return nil + return conn, nil } -// FileReader An FTP file reader that implements io.MFile for seeking. type FileReader struct { - conn *ftp.ServerConn - resp *ftp.Response - offset atomic.Int64 - readAtOffset int64 - mu sync.Mutex - path string - size int64 + *ftp.Response + io.Reader + ctx context.Context } -func NewFileReader(conn *ftp.ServerConn, path string, size int64) *FileReader { - return &FileReader{ - conn: conn, - path: path, - size: size, - } -} - -func (r *FileReader) Read(buf []byte) (n int, err error) { - n, err = r.ReadAt(buf, r.offset.Load()) - r.offset.Add(int64(n)) - return -} - -func (r *FileReader) ReadAt(buf []byte, off int64) (n int, err error) { - if off < 0 { - return -1, os.ErrInvalid - } - r.mu.Lock() - defer r.mu.Unlock() - - if off != r.readAtOffset { - //have to restart the connection, to correct offset - _ = r.resp.Close() - r.resp = nil - } - - if r.resp == nil { - r.resp, err = r.conn.RetrFrom(r.path, uint64(off)) - r.readAtOffset = off - if err != nil { - return 0, err +func (r *FileReader) Read(buf []byte) (int, error) { + n := 0 + for n < len(buf) { + w, err := r.Reader.Read(buf[n:]) + if utils.IsCanceled(r.ctx) { + return n, r.ctx.Err() + } + n += w + if errors.Is(err, os.ErrDeadlineExceeded) { + r.Response.SetDeadline(time.Now().Add(time.Second)) + continue + } + if err != nil || w == 0 { + return n, err } } - - n, err = r.resp.Read(buf) - r.readAtOffset += int64(n) - return -} - -func (r *FileReader) Seek(offset int64, whence int) (int64, error) { - oldOffset := r.offset.Load() - var newOffset int64 - switch whence { - case io.SeekStart: - newOffset = offset - case io.SeekCurrent: - newOffset = oldOffset + offset - case io.SeekEnd: - return r.size, nil - default: - return -1, os.ErrInvalid - } - - if newOffset < 0 { - // offset out of range - return oldOffset, os.ErrInvalid - } - if newOffset == oldOffset { - // offset not changed, so return directly - return oldOffset, nil - } - r.offset.Store(newOffset) - return newOffset, nil -} - -func (r *FileReader) Close() error { - if r.resp != nil { - return r.resp.Close() - } - return nil + return n, nil } diff --git a/drivers/local/driver.go b/drivers/local/driver.go index a19534b7..5defd647 100644 --- a/drivers/local/driver.go +++ b/drivers/local/driver.go @@ -245,13 +245,12 @@ func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if err != nil { return nil, err } + link.ContentLength = file.GetSize() link.MFile = open } - if link.MFile != nil && !d.Config().OnlyLinkMFile { - link.AddIfCloser(link.MFile) - link.RangeReader = &model.FileRangeReader{ - RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile), - } + link.AddIfCloser(link.MFile) + if !d.Config().OnlyLinkMFile { + link.RangeReader = stream.GetRangeReaderFromMFile(link.ContentLength, link.MFile) link.MFile = nil } return link, nil diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index 1175ff60..c3898c2f 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -55,9 +55,7 @@ func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link { func (lrc *LyricObj) getLyricLink() *model.Link { return &model.Link{ - RangeReader: &model.FileRangeReader{ - RangeReaderIF: stream.GetRangeReaderFromMFile(int64(len(lrc.lyric)), strings.NewReader(lrc.lyric)), - }, + RangeReader: stream.GetRangeReaderFromMFile(int64(len(lrc.lyric)), strings.NewReader(lrc.lyric)), } } diff --git a/drivers/quark_open/util.go b/drivers/quark_open/util.go index 98e76e8d..78f4e4a2 100644 --- a/drivers/quark_open/util.go +++ b/drivers/quark_open/util.go @@ -8,14 +8,15 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/OpenListTeam/OpenList/v4/pkg/http_range" - "github.com/google/uuid" "io" "net/http" "strconv" "strings" "time" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/google/uuid" + "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" @@ -244,11 +245,8 @@ func (d *QuarkOpen) generateProofCode(file model.FileStreamer, proofSeed string, // 读取数据 buf := make([]byte, length) n, err := io.ReadFull(reader, buf) - if errors.Is(err, io.ErrUnexpectedEOF) { - return "", fmt.Errorf("can't read data, expected=%d, got=%d", length, n) - } - if err != nil { - return "", fmt.Errorf("failed to read data: %w", err) + if n != int(length) { + return "", fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) } // Base64编码 diff --git a/drivers/sftp/driver.go b/drivers/sftp/driver.go index e0cdda86..7de24248 100644 --- a/drivers/sftp/driver.go +++ b/drivers/sftp/driver.go @@ -63,20 +63,20 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* if err != nil { return nil, err } - if remoteFile != nil && !d.Config().OnlyLinkMFile { + mFile := &stream.RateLimitFile{ + File: remoteFile, + Limiter: stream.ServerDownloadLimit, + Ctx: ctx, + } + if !d.Config().OnlyLinkMFile { return &model.Link{ - RangeReader: &model.FileRangeReader{ - RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)), - }, + RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile), SyncClosers: utils.NewSyncClosers(remoteFile), }, nil } return &model.Link{ - MFile: &stream.RateLimitFile{ - File: remoteFile, - Limiter: stream.ServerDownloadLimit, - Ctx: ctx, - }, + MFile: mFile, + SyncClosers: utils.NewSyncClosers(remoteFile), }, nil } diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index d38c9cef..3e12f122 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -81,19 +81,20 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m return nil, err } d.updateLastConnTime() - if remoteFile != nil && !d.Config().OnlyLinkMFile { + mFile := &stream.RateLimitFile{ + File: remoteFile, + Limiter: stream.ServerDownloadLimit, + Ctx: ctx, + } + if !d.Config().OnlyLinkMFile { return &model.Link{ - RangeReader: &model.FileRangeReader{ - RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)), - }, + RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile), + SyncClosers: utils.NewSyncClosers(remoteFile), }, nil } return &model.Link{ - MFile: &stream.RateLimitFile{ - File: remoteFile, - Limiter: stream.ServerDownloadLimit, - Ctx: ctx, - }, + MFile: mFile, + SyncClosers: utils.NewSyncClosers(remoteFile), }, nil } diff --git a/drivers/virtual/driver.go b/drivers/virtual/driver.go index a66f6f77..1d14427c 100644 --- a/drivers/virtual/driver.go +++ b/drivers/virtual/driver.go @@ -54,10 +54,6 @@ func (f DummyMFile) ReadAt(p []byte, off int64) (n int, err error) { return f.Reader.Read(p) } -func (f DummyMFile) Close() error { - return nil -} - func (DummyMFile) Seek(offset int64, whence int) (int64, error) { return offset, nil } diff --git a/internal/model/args.go b/internal/model/args.go index 05cdf02b..e655882a 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -2,7 +2,6 @@ package model import ( "context" - "errors" "io" "net/http" "time" @@ -40,13 +39,6 @@ type Link struct { utils.SyncClosers `json:"-"` } -func (l *Link) Close() error { - if clr, ok := l.MFile.(io.Closer); ok { - return errors.Join(clr.Close(), l.SyncClosers.Close()) - } - return l.SyncClosers.Close() -} - type OtherArgs struct { Obj Obj Method string diff --git a/internal/op/archive.go b/internal/op/archive.go index 8c5d2bab..964e9397 100644 --- a/internal/op/archive.go +++ b/internal/op/archive.go @@ -372,11 +372,16 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args } var forget any + var linkM *extractLink fn := func() (*extractLink, error) { link, err := driverExtract(ctx, storage, path, args) if err != nil { return nil, errors.Wrapf(err, "failed extract archive") } + if link.MFile != nil && forget != nil { + linkM = link + return nil, errLinkMFileCache + } if link.Link.Expiration != nil { extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration)) } @@ -406,11 +411,18 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args link.AcquireReference() } } + if err == errLinkMFileCache { + if linkM != nil { + return linkM.Link, linkM.Obj, nil + } + forget = nil + link, err = fn() + } if err != nil { return nil, nil, err } - return link.Link, link.Obj, err + return link.Link, link.Obj, nil } func driverExtract(ctx context.Context, storage driver.Driver, path string, args model.ArchiveInnerArgs) (*extractLink, error) { diff --git a/internal/op/fs.go b/internal/op/fs.go index 0b3f5283..bdf0567b 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -2,6 +2,7 @@ package op import ( "context" + stderrors "errors" stdpath "path" "slices" "strings" @@ -250,6 +251,7 @@ func GetUnwrap(ctx context.Context, storage driver.Driver, path string) (model.O var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16)) var linkG = singleflight.Group[*model.Link]{Remember: true} +var errLinkMFileCache = stderrors.New("ErrLinkMFileCache") // Link get link, if is an url. should have an expiry time func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { @@ -292,11 +294,16 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li } var forget any + var linkM *model.Link fn := func() (*model.Link, error) { link, err := storage.Link(ctx, file, args) if err != nil { return nil, errors.Wrapf(err, "failed get link") } + if link.MFile != nil && forget != nil { + linkM = link + return nil, errLinkMFileCache + } if link.Expiration != nil { linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration)) } @@ -326,11 +333,19 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li link.AcquireReference() } } + + if err == errLinkMFileCache { + if linkM != nil { + return linkM, file, nil + } + forget = nil + link, err = fn() + } + if err != nil { return nil, nil, err } - - return link, file, err + return link, file, nil } // Other api diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 387bf036..932975a4 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -8,13 +8,13 @@ import ( "io" "math" "os" + "sync" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/sirupsen/logrus" "go4.org/readerutil" ) @@ -127,10 +127,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { buf := make([]byte, bufSize) n, err := io.ReadFull(f.Reader, buf) if err != nil { - return nil, err - } - if n != int(bufSize) { - return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) } f.peekBuff = bytes.NewReader(buf) f.Reader = io.MultiReader(f.peekBuff, f.Reader) @@ -234,7 +231,7 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) { } rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1}) if err != nil { - return 0, nil + return 0, err } ss.Reader = rc } @@ -299,70 +296,48 @@ func (r *ReaderUpdatingProgress) Close() error { return r.Reader.Close() } -type readerCur struct { - reader io.Reader - cur int64 -} - type RangeReadReadAtSeeker struct { ss *SeekableStream masterOff int64 - readers []*readerCur + readerMap sync.Map headCache *headCache } type headCache struct { - *readerCur - bufs [][]byte + reader io.Reader + bufs [][]byte } -func (c *headCache) read(p []byte) (n int, err error) { - pL := len(p) - logrus.Debugf("headCache read_%d", pL) - if c.cur < int64(pL) { - bufL := int64(pL) - c.cur - buf := make([]byte, bufL) - lr := io.LimitReader(c.reader, bufL) - off := 0 - for c.cur < int64(pL) { - n, err = lr.Read(buf[off:]) - off += n - c.cur += int64(n) - if err == io.EOF && off == int(bufL) { - err = nil - } - if err != nil { - break - } +func (c *headCache) head(p []byte) (int, error) { + n := 0 + for _, buf := range c.bufs { + if len(buf)+n >= len(p) { + n += copy(p[n:], buf[:len(p)-n]) + return n, nil + } else { + n += copy(p[n:], buf) } + } + w, err := io.ReadAtLeast(c.reader, p[n:], 1) + if w > 0 { + buf := make([]byte, w) + copy(buf, p[n:n+w]) c.bufs = append(c.bufs, buf) + n += w } - n = 0 - if c.cur >= int64(pL) { - for i := 0; n < pL; i++ { - buf := c.bufs[i] - r := len(buf) - if n+r > pL { - r = pL - n - } - n += copy(p[n:], buf[:r]) - } - } - return + return n, err } + func (r *headCache) Close() error { - for i := range r.bufs { - r.bufs[i] = nil - } + clear(r.bufs) r.bufs = nil return nil } func (r *RangeReadReadAtSeeker) InitHeadCache() { if r.ss.GetFile() == nil && r.masterOff == 0 { - reader := r.readers[0] - r.readers = r.readers[1:] - r.headCache = &headCache{readerCur: reader} + value, _ := r.readerMap.LoadAndDelete(int64(0)) + r.headCache = &headCache{reader: value.(io.Reader)} r.ss.Closers.Add(r.headCache) } } @@ -388,8 +363,7 @@ func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (mode return nil, err } } else { - rc := &readerCur{reader: ss, cur: offset} - r.readers = append(r.readers, rc) + r.readerMap.Store(int64(offset), ss) } return r, nil } @@ -406,72 +380,64 @@ func NewMultiReaderAt(ss []*SeekableStream) (readerutil.SizeReaderAt, error) { return readerutil.NewMultiReaderAt(readers...), nil } -func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) { - var rc *readerCur - for _, reader := range r.readers { - if reader.cur == -1 { - continue +func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (io.Reader, error) { + var rr io.Reader + var cur int64 = -1 + r.readerMap.Range(func(key, value any) bool { + k := key.(int64) + if off == k { + cur = k + rr = value.(io.Reader) + return false } - if reader.cur == off { - // logrus.Debugf("getReaderAtOffset match_%d", off) - return reader, nil - } - if reader.cur > 0 && off >= reader.cur && (rc == nil || reader.cur < rc.cur) { - rc = reader + if off > k && off-k <= 4*utils.MB && (rr == nil || k < cur) { + rr = value.(io.Reader) + cur = k } + return true + }) + if cur >= 0 { + r.readerMap.Delete(int64(cur)) } - if rc != nil && off-rc.cur <= utils.MB { - n, err := utils.CopyWithBufferN(io.Discard, rc.reader, off-rc.cur) - rc.cur += n - if err == io.EOF && rc.cur == off { - err = nil - } - if err == nil { - logrus.Debugf("getReaderAtOffset old_%d", off) - return rc, nil - } - rc.cur = -1 + if off == int64(cur) { + // logrus.Debugf("getReaderAtOffset match_%d", off) + return rr, nil } - logrus.Debugf("getReaderAtOffset new_%d", off) - // Range请求不能超过文件大小,有些云盘处理不了就会返回整个文件 - reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: r.ss.GetSize() - off}) + if rr != nil { + n, _ := utils.CopyWithBufferN(io.Discard, rr, off-cur) + cur += n + if cur == off { + // logrus.Debugf("getReaderAtOffset old_%d", off) + return rr, nil + } + } + // logrus.Debugf("getReaderAtOffset new_%d", off) + + reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: -1}) if err != nil { return nil, err } - rc = &readerCur{reader: reader, cur: off} - r.readers = append(r.readers, rc) - return rc, nil + return reader, nil } -func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (int, error) { +func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) { if off == 0 && r.headCache != nil { - return r.headCache.read(p) + return r.headCache.head(p) } - rc, err := r.getReaderAtOffset(off) + var rr io.Reader + rr, err = r.getReaderAtOffset(off) if err != nil { return 0, err } - n, num := 0, 0 - for num < len(p) { - n, err = rc.reader.Read(p[num:]) - rc.cur += int64(n) - num += n - if err == nil { - continue - } - if err == io.EOF { - // io.EOF是reader读取完了 - rc.cur = -1 - // yeka/zip包 没有处理EOF,我们要兼容 - // https://github.com/yeka/zip/blob/03d6312748a9d6e0bc0c9a7275385c09f06d9c14/reader.go#L433 - if num == len(p) { - err = nil - } - } - break + n, err = io.ReadAtLeast(rr, p, 1) + off += int64(n) + if err == nil { + r.readerMap.Store(int64(off), rr) + } else { + rr = nil } - return num, err + return n, err } func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) { @@ -498,15 +464,7 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) { } func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) { - if r.masterOff == 0 && r.headCache != nil { - return r.headCache.read(p) - } - rc, err := r.getReaderAtOffset(r.masterOff) - if err != nil { - return 0, err - } - n, err = rc.reader.Read(p) - rc.cur += int64(n) + n, err = r.ReadAt(p, r.masterOff) r.masterOff += int64(n) return n, err } diff --git a/internal/stream/util.go b/internal/stream/util.go index 77b23802..d2de46ac 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -26,7 +26,7 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) { if link.MFile != nil { - return &model.FileRangeReader{RangeReaderIF: GetRangeReaderFromMFile(size, link.MFile)}, nil + return GetRangeReaderFromMFile(size, link.MFile), nil } if link.Concurrency > 0 || link.PartSize > 0 { down := net.NewDownloader(func(d *net.Downloader) { @@ -97,13 +97,16 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, return RateLimitRangeReaderFunc(rangeReader), nil } -func GetRangeReaderFromMFile(size int64, file model.File) RangeReaderFunc { - return func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { - length := httpRange.Length - if length < 0 || httpRange.Start+length > size { - length = size - httpRange.Start - } - return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil +// RangeReaderIF.RangeRead返回的io.ReadCloser保留file的签名。 +func GetRangeReaderFromMFile(size int64, file model.File) model.RangeReaderIF { + return &model.FileRangeReader{ + RangeReaderIF: RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + length := httpRange.Length + if length < 0 || httpRange.Start+length > size { + length = size - httpRange.Start + } + return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil + }), } } @@ -227,11 +230,8 @@ func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionRead tempBuf := ss.bufPool.Get().([]byte) buf = tempBuf[:length] n, err := io.ReadFull(ss.file, buf) - if err != nil { - return nil, err - } if int64(n) != length { - return nil, fmt.Errorf("can't read data, expected=%d, got=%d", length, n) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) } ss.off += int64(n) off = 0 diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 93b3ddbc..dd0e3fac 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io" + "math" "sync" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -164,6 +166,7 @@ func (c *Closers) Close() error { errs = append(errs, closer.Close()) } } + clear(*c) *c = (*c)[:0] return errors.Join(errs...) } @@ -191,32 +194,32 @@ type SyncClosersIF interface { type SyncClosers struct { closers []io.Closer - mu sync.Mutex - ref int + ref atomic.Int32 } var _ SyncClosersIF = (*SyncClosers)(nil) func (c *SyncClosers) AcquireReference() bool { - c.mu.Lock() - defer c.mu.Unlock() - if len(c.closers) == 0 { - return false + ref := c.ref.Add(1) + if ref > 0 { + // log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, ref) + return true } - c.ref++ - log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, c.ref) - return true + c.ref.Store(math.MinInt16) + return false } func (c *SyncClosers) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - defer log.Debugf("SyncClosers.Close %p,ref=%d\n", c, c.ref) - if c.ref > 1 { - c.ref-- + ref := c.ref.Add(-1) + if ref < -1 { + c.ref.Store(math.MinInt16) return nil } - c.ref = 0 + // log.Debugf("SyncClosers.Close %p,ref=%d\n", c, ref+1) + if ref > 0 { + return nil + } + c.ref.Store(math.MinInt16) var errs []error for _, closer := range c.closers { @@ -224,23 +227,26 @@ func (c *SyncClosers) Close() error { errs = append(errs, closer.Close()) } } - c.closers = c.closers[:0] + clear(c.closers) + c.closers = nil return errors.Join(errs...) } func (c *SyncClosers) Add(closer io.Closer) { if closer != nil { - c.mu.Lock() + if c.ref.Load() < 0 { + panic("Not reusable") + } c.closers = append(c.closers, closer) - c.mu.Unlock() } } func (c *SyncClosers) AddIfCloser(a any) { if closer, ok := a.(io.Closer); ok { - c.mu.Lock() + if c.ref.Load() < 0 { + panic("Not reusable") + } c.closers = append(c.closers, closer) - c.mu.Unlock() } } @@ -278,11 +284,7 @@ var IoBuffPool = &sync.Pool{ func CopyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { buff := IoBuffPool.Get().([]byte) defer IoBuffPool.Put(buff) - written, err = io.CopyBuffer(dst, src, buff) - if err != nil { - return - } - return written, nil + return io.CopyBuffer(dst, src, buff) } func CopyWithBufferN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {