perf(ftp): improve concurrent Link response; fix alias/local driver issues (#974)

Author: j2rong4cn
Date: 2025-08-06 13:32:37 +08:00 (committed by GitHub)
Parent: 8cf15183a0
Commit: 9ac0484bc0
21 changed files with 337 additions and 393 deletions


@@ -550,9 +550,9 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
 		return err
 	}
 	silceMd5.Reset()
-	w, _ := utils.CopyWithBuffer(writers, reader)
+	w, err := utils.CopyWithBuffer(writers, reader)
 	if w != size {
-		return fmt.Errorf("can't read data, expected=%d, got=%d", size, w)
+		return fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", size, w, err)
 	}
 	// compute the chunk MD5, then hex- and base64-encode it
 	md5Bytes := silceMd5.Sum(nil)
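Note: this hunk is the first of several in this commit that replace a silently discarded copy error with a short-read check that wraps the underlying cause via %w. A minimal, runnable sketch of the pattern, independent of the repository (readExactly is a hypothetical helper):

package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
)

// readExactly fills a buffer of the wanted size; on a short read it wraps
// the underlying error so errors.Is can still match the cause.
func readExactly(r io.Reader, want int) ([]byte, error) {
	buf := make([]byte, want)
	n, err := io.ReadFull(r, buf)
	if n != want {
		return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", want, n, err)
	}
	return buf, nil
}

func main() {
	_, err := readExactly(strings.NewReader("short"), 10)
	fmt.Println(err)                                 // reports expect=10, actual=5
	fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // true: the cause survives the wrapping
}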


@@ -78,10 +78,18 @@ func (d *Alias) Get(ctx context.Context, path string) (model.Obj, error) {
 		return nil, errs.ObjectNotFound
 	}
 	for _, dst := range dsts {
-		obj, err := d.get(ctx, path, dst, sub)
-		if err == nil {
-			return obj, nil
+		obj, err := fs.Get(ctx, stdpath.Join(dst, sub), &fs.GetArgs{NoLog: true})
+		if err != nil {
+			continue
 		}
+		return &model.Object{
+			Path:     path,
+			Name:     obj.GetName(),
+			Size:     obj.GetSize(),
+			Modified: obj.ModTime(),
+			IsFolder: obj.IsDir(),
+			HashInfo: obj.GetHash(),
+		}, nil
 	}
 	return nil, errs.ObjectNotFound
 }
@@ -99,7 +107,27 @@ func (d *Alias) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([
 	var objs []model.Obj
 	fsArgs := &fs.ListArgs{NoLog: true, Refresh: args.Refresh}
 	for _, dst := range dsts {
-		tmp, err := d.list(ctx, dst, sub, fsArgs)
+		tmp, err := fs.List(ctx, stdpath.Join(dst, sub), fsArgs)
+		if err == nil {
+			tmp, err = utils.SliceConvert(tmp, func(obj model.Obj) (model.Obj, error) {
+				thumb, ok := model.GetThumb(obj)
+				objRes := model.Object{
+					Name:     obj.GetName(),
+					Size:     obj.GetSize(),
+					Modified: obj.ModTime(),
+					IsFolder: obj.IsDir(),
+				}
+				if !ok {
+					return &objRes, nil
+				}
+				return &model.ObjThumb{
+					Object: objRes,
+					Thumbnail: model.Thumbnail{
+						Thumbnail: thumb,
+					},
+				}, nil
+			})
+		}
 		if err == nil {
 			objs = append(objs, tmp...)
 		}
@@ -113,44 +141,51 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 	if !ok {
 		return nil, errs.ObjectNotFound
 	}
+	// proxy || ftp,s3
+	if common.GetApiUrl(ctx) == "" {
+		args.Redirect = false
+	}
 	for _, dst := range dsts {
 		reqPath := stdpath.Join(dst, sub)
-		link, file, err := d.link(ctx, reqPath, args)
+		link, fi, err := d.link(ctx, reqPath, args)
 		if err != nil {
 			continue
 		}
-		var resultLink *model.Link
-		if link != nil {
-			resultLink = &model.Link{
-				URL:           link.URL,
-				Header:        link.Header,
-				RangeReader:   link.RangeReader,
-				SyncClosers:   utils.NewSyncClosers(link),
-				ContentLength: link.ContentLength,
-			}
-			if link.MFile != nil {
-				resultLink.RangeReader = &model.FileRangeReader{
-					RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile),
-				}
-			}
-		} else {
-			resultLink = &model.Link{
+		if link == nil {
+			// redirected, but it has to go through this proxy
+			return &model.Link{
 				URL: fmt.Sprintf("%s/p%s?sign=%s",
 					common.GetApiUrl(ctx),
 					utils.EncodePath(reqPath, true),
 					sign.Sign(reqPath)),
+			}, nil
+		}
+		if args.Redirect {
+			return link, nil
 		}
+		resultLink := &model.Link{
+			URL:           link.URL,
+			Header:        link.Header,
+			RangeReader:   link.RangeReader,
+			MFile:         link.MFile,
+			Concurrency:   link.Concurrency,
+			PartSize:      link.PartSize,
+			ContentLength: link.ContentLength,
+			SyncClosers:   utils.NewSyncClosers(link),
+		}
+		if resultLink.ContentLength == 0 {
+			resultLink.ContentLength = fi.GetSize()
+		}
+		if resultLink.MFile != nil {
+			return resultLink, nil
+		}
+		if !args.Redirect {
 			if d.DownloadConcurrency > 0 {
 				resultLink.Concurrency = d.DownloadConcurrency
 			}
			if d.DownloadPartSize > 0 {
				resultLink.PartSize = d.DownloadPartSize * utils.KB
			}
+		}
 		return resultLink, nil
 	}
 	return nil, errs.ObjectNotFound


@@ -54,55 +54,12 @@ func (d *Alias) getRootAndPath(path string) (string, string) {
 	return parts[0], parts[1]
 }
 
-func (d *Alias) get(ctx context.Context, path string, dst, sub string) (model.Obj, error) {
-	obj, err := fs.Get(ctx, stdpath.Join(dst, sub), &fs.GetArgs{NoLog: true})
-	if err != nil {
-		return nil, err
-	}
-	return &model.Object{
-		Path:     path,
-		Name:     obj.GetName(),
-		Size:     obj.GetSize(),
-		Modified: obj.ModTime(),
-		IsFolder: obj.IsDir(),
-		HashInfo: obj.GetHash(),
-	}, nil
-}
-
-func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([]model.Obj, error) {
-	objs, err := fs.List(ctx, stdpath.Join(dst, sub), args)
-	// the obj must implement the model.SetPath interface
-	// return objs, err
-	if err != nil {
-		return nil, err
-	}
-	return utils.SliceConvert(objs, func(obj model.Obj) (model.Obj, error) {
-		thumb, ok := model.GetThumb(obj)
-		objRes := model.Object{
-			Name:     obj.GetName(),
-			Size:     obj.GetSize(),
-			Modified: obj.ModTime(),
-			IsFolder: obj.IsDir(),
-		}
-		if !ok {
-			return &objRes, nil
-		}
-		return &model.ObjThumb{
-			Object: objRes,
-			Thumbnail: model.Thumbnail{
-				Thumbnail: thumb,
-			},
-		}, nil
-	})
-}
-
 func (d *Alias) link(ctx context.Context, reqPath string, args model.LinkArgs) (*model.Link, model.Obj, error) {
 	storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath)
 	if err != nil {
 		return nil, nil, err
 	}
-	// proxy || ftp,s3
-	if !args.Redirect || len(common.GetApiUrl(ctx)) == 0 {
+	if !args.Redirect {
 		return op.Link(ctx, storage, reqActualPath, args)
 	}
 	obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true})


@@ -137,11 +137,8 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error
 	}
 	buf := make([]byte, length)
 	n, err := io.ReadFull(reader, buf)
-	if err == io.ErrUnexpectedEOF {
-		return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n)
-	}
-	if err != nil {
-		return "", err
+	if n != int(length) {
+		return "", fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err)
 	}
 	return base64.StdEncoding.EncodeToString(buf), nil
 }


@@ -292,10 +292,10 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 	if offset == 0 && limit > 0 {
 		fileHeader = make([]byte, fileHeaderSize)
-		n, _ := io.ReadFull(remoteReader, fileHeader)
+		n, err := io.ReadFull(remoteReader, fileHeader)
 		if n != fileHeaderSize {
 			fileHeader = nil
-			return nil, fmt.Errorf("can't read data, expected=%d, got=%d", fileHeaderSize, n)
+			return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", fileHeaderSize, n, err)
 		}
 		if limit <= fileHeaderSize {
 			remoteReader.Close()


@@ -460,9 +460,9 @@ func (d *Doubao) Upload(ctx context.Context, config *UploadConfig, dstDir model.
 	// compute the CRC32
 	crc32Hash := crc32.NewIEEE()
-	w, _ := utils.CopyWithBuffer(crc32Hash, reader)
+	w, err := utils.CopyWithBuffer(crc32Hash, reader)
 	if w != file.GetSize() {
-		return nil, fmt.Errorf("can't read data, expected=%d, got=%d", file.GetSize(), w)
+		return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", file.GetSize(), w, err)
 	}
 	crc32Value := hex.EncodeToString(crc32Hash.Sum(nil))

@@ -588,9 +588,9 @@ func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fi
 		return err
 	}
 	hash.Reset()
-	w, _ := utils.CopyWithBuffer(hash, reader)
+	w, err := utils.CopyWithBuffer(hash, reader)
 	if w != size {
-		return fmt.Errorf("can't read data, expected=%d, got=%d", size, w)
+		return fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", size, w, err)
 	}
 	crc32Value = hex.EncodeToString(hash.Sum(nil))
 	rateLimitedRd = driver.NewLimitedUploadStream(ctx, reader)


@@ -2,12 +2,16 @@ package ftp
 
 import (
 	"context"
+	"io"
 	stdpath "path"
+	"sync"
+	"time"
 
 	"github.com/OpenListTeam/OpenList/v4/internal/driver"
 	"github.com/OpenListTeam/OpenList/v4/internal/errs"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/stream"
+	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/jlaffaye/ftp"
 )

@@ -16,6 +20,9 @@ type FTP struct {
 	model.Storage
 	Addition
 	conn *ftp.ServerConn
+
+	ctx    context.Context
+	cancel context.CancelFunc
 }
 
 func (d *FTP) Config() driver.Config {

@@ -27,12 +34,16 @@ func (d *FTP) GetAddition() driver.Additional {
 }
 
 func (d *FTP) Init(ctx context.Context) error {
-	return d._login()
+	d.ctx, d.cancel = context.WithCancel(context.Background())
+	var err error
+	d.conn, err = d._login(ctx)
+	return err
 }
 
 func (d *FTP) Drop(ctx context.Context) error {
 	if d.conn != nil {
-		_ = d.conn.Logout()
+		_ = d.conn.Quit()
+		d.cancel()
 	}
 	return nil
 }

@@ -61,26 +72,53 @@ func (d *FTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m
 	return res, nil
 }
 
-func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
-	if err := d.login(); err != nil {
+func (d *FTP) Link(_ context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	ctx, cancel := context.WithCancel(context.Background())
+	conn, err := d._login(ctx)
+	if err != nil {
+		cancel()
 		return nil, err
 	}
+	close := func() error {
+		_ = conn.Quit()
+		cancel()
+		return nil
+	}
 
-	remoteFile := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize())
-	if remoteFile != nil && !d.Config().OnlyLinkMFile {
-		return &model.Link{
-			RangeReader: &model.FileRangeReader{
-				RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
-			},
-			SyncClosers: utils.NewSyncClosers(remoteFile),
+	path := encode(file.GetPath(), d.Encoding)
+	size := file.GetSize()
+	mu := &sync.Mutex{}
+	resultRangeReader := func(context context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+		length := httpRange.Length
+		if length < 0 || httpRange.Start+length > size {
+			length = size - httpRange.Start
+		}
+		mu.Lock()
+		defer mu.Unlock()
+		r, err := conn.RetrFrom(path, uint64(httpRange.Start))
+		if err != nil {
+			_ = conn.Quit()
+			conn, err = d._login(ctx)
+			if err == nil {
+				r, err = conn.RetrFrom(path, uint64(httpRange.Start))
+			}
+			if err != nil {
+				return nil, err
+			}
+		}
+		r.SetDeadline(time.Now().Add(time.Second))
+		return &FileReader{
+			Response: r,
+			Reader:   io.LimitReader(r, length),
+			ctx:      context,
 		}, nil
 	}
+
 	return &model.Link{
-		MFile: &stream.RateLimitFile{
-			File:    remoteFile,
-			Limiter: stream.ServerDownloadLimit,
-			Ctx:     ctx,
+		RangeReader: &model.FileRangeReader{
+			RangeReaderIF: stream.RateLimitRangeReaderFunc(resultRangeReader),
 		},
+		SyncClosers: utils.NewSyncClosers(utils.CloseFunc(close)),
 	}, nil
 }
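Note: the rewritten Link above holds one dedicated control connection per link, serializes ranged reads behind a mutex, and redials once if RetrFrom fails on a stale connection. A self-contained sketch of that reconnect-once idea, using hypothetical stand-ins (conn.openAt, session.redial) rather than the driver's real API:

package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
	"sync"
)

// conn.openAt stands in for ftp's conn.RetrFrom: it fails when the control
// connection has gone stale.
type conn struct{ stale bool }

func (c *conn) openAt(off int64) (io.Reader, error) {
	if c.stale {
		return nil, errors.New("stale control connection")
	}
	return strings.NewReader("example payload"[off:]), nil
}

type session struct {
	mu     sync.Mutex
	conn   *conn
	redial func() (*conn, error) // stand-in for the driver's _login
}

// rangeRead serializes access to the single data connection and redials
// exactly once when the first attempt fails.
func (s *session) rangeRead(off int64) (io.Reader, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	r, err := s.conn.openAt(off)
	if err != nil {
		s.conn, err = s.redial()
		if err == nil {
			r, err = s.conn.openAt(off)
		}
		if err != nil {
			return nil, err
		}
	}
	return r, nil
}

func main() {
	s := &session{
		conn:   &conn{stale: true},
		redial: func() (*conn, error) { return &conn{}, nil },
	}
	r, err := s.rangeRead(8)
	if err != nil {
		panic(err)
	}
	b, _ := io.ReadAll(r)
	fmt.Printf("%s\n", b) // payload
}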


@@ -33,7 +33,7 @@ type Addition struct {
 var config = driver.Config{
 	Name:          "FTP",
 	LocalSort:     true,
-	OnlyLinkMFile: true,
+	OnlyLinkMFile: false,
 	DefaultRoot:   "/",
 	NoLinkURL:     true,
 }


@@ -1,14 +1,15 @@
 package ftp
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"io"
 	"os"
-	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
+	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
 	"github.com/jlaffaye/ftp"
 )

@@ -16,111 +17,56 @@ import (
 func (d *FTP) login() error {
 	_, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("FTP.login:%p", d), func() (any, error) {
-		return nil, d._login()
+		var err error
+		if d.conn != nil {
+			err = d.conn.NoOp()
+			if err != nil {
+				d.conn.Quit()
+				d.conn = nil
+			}
+		}
+		if d.conn == nil {
+			d.conn, err = d._login(d.ctx)
+		}
+		return nil, err
 	})
 	return err
 }
 
-func (d *FTP) _login() error {
-	if d.conn != nil {
-		_, err := d.conn.CurrentDir()
-		if err == nil {
-			return nil
-		}
-	}
-	conn, err := ftp.Dial(d.Address, ftp.DialWithShutTimeout(10*time.Second))
+func (d *FTP) _login(ctx context.Context) (*ftp.ServerConn, error) {
+	conn, err := ftp.Dial(d.Address, ftp.DialWithShutTimeout(10*time.Second), ftp.DialWithContext(ctx))
 	if err != nil {
-		return err
+		return nil, err
 	}
 	err = conn.Login(d.Username, d.Password)
 	if err != nil {
-		return err
+		conn.Quit()
+		return nil, err
 	}
-	d.conn = conn
-	return nil
+	return conn, nil
 }
 
-// FileReader An FTP file reader that implements io.MFile for seeking.
 type FileReader struct {
-	conn         *ftp.ServerConn
-	resp         *ftp.Response
-	offset       atomic.Int64
-	readAtOffset int64
-	mu           sync.Mutex
-	path         string
-	size         int64
+	*ftp.Response
+	io.Reader
+	ctx context.Context
 }
 
-func NewFileReader(conn *ftp.ServerConn, path string, size int64) *FileReader {
-	return &FileReader{
-		conn: conn,
-		path: path,
-		size: size,
-	}
-}
-
-func (r *FileReader) Read(buf []byte) (n int, err error) {
-	n, err = r.ReadAt(buf, r.offset.Load())
-	r.offset.Add(int64(n))
-	return
-}
-
-func (r *FileReader) ReadAt(buf []byte, off int64) (n int, err error) {
-	if off < 0 {
-		return -1, os.ErrInvalid
-	}
-	r.mu.Lock()
-	defer r.mu.Unlock()
-
-	if off != r.readAtOffset {
-		// have to restart the connection, to correct offset
-		_ = r.resp.Close()
-		r.resp = nil
-	}
-
-	if r.resp == nil {
-		r.resp, err = r.conn.RetrFrom(r.path, uint64(off))
-		r.readAtOffset = off
-		if err != nil {
-			return 0, err
-		}
-	}
-
-	n, err = r.resp.Read(buf)
-	r.readAtOffset += int64(n)
-	return
-}
-
-func (r *FileReader) Seek(offset int64, whence int) (int64, error) {
-	oldOffset := r.offset.Load()
-	var newOffset int64
-	switch whence {
-	case io.SeekStart:
-		newOffset = offset
-	case io.SeekCurrent:
-		newOffset = oldOffset + offset
-	case io.SeekEnd:
-		return r.size, nil
-	default:
-		return -1, os.ErrInvalid
-	}
-	if newOffset < 0 {
-		// offset out of range
-		return oldOffset, os.ErrInvalid
-	}
-	if newOffset == oldOffset {
-		// offset not changed, so return directly
-		return oldOffset, nil
-	}
-	r.offset.Store(newOffset)
-	return newOffset, nil
-}
-
-func (r *FileReader) Close() error {
-	if r.resp != nil {
-		return r.resp.Close()
+func (r *FileReader) Read(buf []byte) (int, error) {
+	n := 0
+	for n < len(buf) {
+		w, err := r.Reader.Read(buf[n:])
+		if utils.IsCanceled(r.ctx) {
+			return n, r.ctx.Err()
+		}
+		n += w
+		if errors.Is(err, os.ErrDeadlineExceeded) {
+			r.Response.SetDeadline(time.Now().Add(time.Second))
+			continue
+		}
+		if err != nil || w == 0 {
+			return n, err
+		}
 	}
-	return nil
+	return n, nil
 }
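Note: the new FileReader.Read makes a blocking FTP data-connection read cancelable by setting a short read deadline and re-checking the context whenever the deadline fires. The same idea works on any net.Conn; a runnable sketch (cancelableRead is a hypothetical name, not part of the repository):

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"time"
)

// cancelableRead blocks on conn.Read but wakes up whenever the short read
// deadline fires, re-checks ctx, and blocks again: cancellation is observed
// within about one deadline period even though conn.Read itself cannot be
// interrupted.
func cancelableRead(ctx context.Context, conn net.Conn, p []byte) (int, error) {
	for {
		if err := ctx.Err(); err != nil {
			return 0, err
		}
		_ = conn.SetReadDeadline(time.Now().Add(time.Second))
		n, err := conn.Read(p)
		if errors.Is(err, os.ErrDeadlineExceeded) {
			continue // deadline fired: loop to honor ctx, then wait again
		}
		return n, err
	}
}

func main() {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel() // nothing is ever written, so only cancellation ends the read
	}()

	buf := make([]byte, 16)
	_, err := cancelableRead(ctx, client, buf)
	fmt.Println(err) // context canceled
}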


@@ -245,13 +245,12 @@ func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 		if err != nil {
 			return nil, err
 		}
+		link.ContentLength = file.GetSize()
 		link.MFile = open
 	}
-	if link.MFile != nil && !d.Config().OnlyLinkMFile {
-		link.AddIfCloser(link.MFile)
-		link.RangeReader = &model.FileRangeReader{
-			RangeReaderIF: stream.GetRangeReaderFromMFile(file.GetSize(), link.MFile),
-		}
+	link.AddIfCloser(link.MFile)
+	if !d.Config().OnlyLinkMFile {
+		link.RangeReader = stream.GetRangeReaderFromMFile(link.ContentLength, link.MFile)
 		link.MFile = nil
 	}
 	return link, nil


@@ -55,9 +55,7 @@ func (lrc *LyricObj) getProxyLink(ctx context.Context) *model.Link {
 
 func (lrc *LyricObj) getLyricLink() *model.Link {
 	return &model.Link{
-		RangeReader: &model.FileRangeReader{
-			RangeReaderIF: stream.GetRangeReaderFromMFile(int64(len(lrc.lyric)), strings.NewReader(lrc.lyric)),
-		},
+		RangeReader: stream.GetRangeReaderFromMFile(int64(len(lrc.lyric)), strings.NewReader(lrc.lyric)),
 	}
 }


@@ -8,14 +8,15 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
-	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
-	"github.com/google/uuid"
 	"io"
 	"net/http"
 	"strconv"
 	"strings"
 	"time"
 
+	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
+	"github.com/google/uuid"
+
 	"github.com/OpenListTeam/OpenList/v4/drivers/base"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/internal/op"

@@ -244,11 +245,8 @@ func (d *QuarkOpen) generateProofCode(file model.FileStreamer, proofSeed string,
 	// read the data
 	buf := make([]byte, length)
 	n, err := io.ReadFull(reader, buf)
-	if errors.Is(err, io.ErrUnexpectedEOF) {
-		return "", fmt.Errorf("can't read data, expected=%d, got=%d", length, n)
-	}
-	if err != nil {
-		return "", fmt.Errorf("failed to read data: %w", err)
+	if n != int(length) {
+		return "", fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err)
 	}
 
 	// Base64-encode it


@@ -63,20 +63,20 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
 	if err != nil {
 		return nil, err
 	}
-	if remoteFile != nil && !d.Config().OnlyLinkMFile {
+	mFile := &stream.RateLimitFile{
+		File:    remoteFile,
+		Limiter: stream.ServerDownloadLimit,
+		Ctx:     ctx,
+	}
+	if !d.Config().OnlyLinkMFile {
 		return &model.Link{
-			RangeReader: &model.FileRangeReader{
-				RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
-			},
+			RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
 			SyncClosers: utils.NewSyncClosers(remoteFile),
 		}, nil
 	}
 	return &model.Link{
-		MFile: &stream.RateLimitFile{
-			File:    remoteFile,
-			Limiter: stream.ServerDownloadLimit,
-			Ctx:     ctx,
-		},
+		MFile:       mFile,
+		SyncClosers: utils.NewSyncClosers(remoteFile),
 	}, nil
 }


@@ -81,19 +81,20 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
 		return nil, err
 	}
 	d.updateLastConnTime()
-	if remoteFile != nil && !d.Config().OnlyLinkMFile {
-		return &model.Link{
-			RangeReader: &model.FileRangeReader{
-				RangeReaderIF: stream.RateLimitRangeReaderFunc(stream.GetRangeReaderFromMFile(file.GetSize(), remoteFile)),
-			},
-		}, nil
-	}
-	return &model.Link{
-		MFile: &stream.RateLimitFile{
-			File:    remoteFile,
-			Limiter: stream.ServerDownloadLimit,
-			Ctx:     ctx,
-		},
+	mFile := &stream.RateLimitFile{
+		File:    remoteFile,
+		Limiter: stream.ServerDownloadLimit,
+		Ctx:     ctx,
+	}
+	if !d.Config().OnlyLinkMFile {
+		return &model.Link{
+			RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
+			SyncClosers: utils.NewSyncClosers(remoteFile),
+		}, nil
+	}
+	return &model.Link{
+		MFile:       mFile,
+		SyncClosers: utils.NewSyncClosers(remoteFile),
 	}, nil
 }


@@ -54,10 +54,6 @@ func (f DummyMFile) ReadAt(p []byte, off int64) (n int, err error) {
 	return f.Reader.Read(p)
 }
 
-func (f DummyMFile) Close() error {
-	return nil
-}
-
 func (DummyMFile) Seek(offset int64, whence int) (int64, error) {
 	return offset, nil
 }


@@ -2,7 +2,6 @@ package model
 
 import (
 	"context"
-	"errors"
 	"io"
 	"net/http"
 	"time"

@@ -40,13 +39,6 @@ type Link struct {
 	utils.SyncClosers `json:"-"`
 }
 
-func (l *Link) Close() error {
-	if clr, ok := l.MFile.(io.Closer); ok {
-		return errors.Join(clr.Close(), l.SyncClosers.Close())
-	}
-	return l.SyncClosers.Close()
-}
-
 type OtherArgs struct {
 	Obj    Obj
 	Method string


@@ -372,11 +372,16 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
 	}
 
 	var forget any
+	var linkM *extractLink
 	fn := func() (*extractLink, error) {
 		link, err := driverExtract(ctx, storage, path, args)
 		if err != nil {
 			return nil, errors.Wrapf(err, "failed extract archive")
 		}
+		if link.MFile != nil && forget != nil {
+			linkM = link
+			return nil, errLinkMFileCache
+		}
 		if link.Link.Expiration != nil {
 			extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration))
 		}

@@ -406,11 +411,18 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
 			link.AcquireReference()
 		}
 	}
+	if err == errLinkMFileCache {
+		if linkM != nil {
+			return linkM.Link, linkM.Obj, nil
+		}
+		forget = nil
+		link, err = fn()
+	}
 	if err != nil {
 		return nil, nil, err
 	}
-	return link.Link, link.Obj, err
+	return link.Link, link.Obj, nil
 }
 
 func driverExtract(ctx context.Context, storage driver.Driver, path string, args model.ArchiveInnerArgs) (*extractLink, error) {


@@ -2,6 +2,7 @@ package op
 
 import (
 	"context"
+	stderrors "errors"
 	stdpath "path"
 	"slices"
 	"strings"

@@ -250,6 +251,7 @@ func GetUnwrap(ctx context.Context, storage driver.Driver, path string) (model.O
 
 var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
 var linkG = singleflight.Group[*model.Link]{Remember: true}
+var errLinkMFileCache = stderrors.New("ErrLinkMFileCache")
 
 // Link get link, if is an url. should have an expiry time
 func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {

@@ -292,11 +294,16 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
 	}
 
 	var forget any
+	var linkM *model.Link
 	fn := func() (*model.Link, error) {
 		link, err := storage.Link(ctx, file, args)
 		if err != nil {
 			return nil, errors.Wrapf(err, "failed get link")
 		}
+		if link.MFile != nil && forget != nil {
+			linkM = link
+			return nil, errLinkMFileCache
+		}
 		if link.Expiration != nil {
 			linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration))
 		}

@@ -326,11 +333,19 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
 			link.AcquireReference()
 		}
 	}
+	if err == errLinkMFileCache {
+		if linkM != nil {
+			return linkM, file, nil
+		}
+		forget = nil
+		link, err = fn()
+	}
 	if err != nil {
 		return nil, nil, err
 	}
-	return link, file, err
+	return link, file, nil
 }
 
 // Other api
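Note: errLinkMFileCache is a sentinel shared by DriverExtract and Link above: a link that carries an open file handle (MFile) must not be shared through the singleflight cache, so the cached path bails out and the caller retries uncached. A compact sketch of that bail-out-and-retry shape (fetch and cachedFetch are illustrative names, not the repository's API):

package main

import (
	"errors"
	"fmt"
)

var errNotCacheable = errors.New("result must not be cached")

type result struct {
	handle *struct{} // stands in for a link's open MFile handle
}

// cachedFetch plays the role of the singleflight+cache path: shared=true
// means the result would be handed out to multiple callers.
func cachedFetch(fetch func(shared bool) (*result, error)) (*result, error) {
	res, err := fetch(true)
	if errors.Is(err, errNotCacheable) {
		// Fall back to a direct, uncached call so this caller gets a
		// private handle.
		return fetch(false)
	}
	return res, err
}

func main() {
	fetch := func(shared bool) (*result, error) {
		r := &result{handle: &struct{}{}}
		if r.handle != nil && shared {
			return nil, errNotCacheable // mirrors: link.MFile != nil && forget != nil
		}
		return r, nil
	}
	r, err := cachedFetch(fetch)
	fmt.Println(r != nil, err) // true <nil>
}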


@@ -8,13 +8,13 @@ import (
 	"io"
 	"math"
 	"os"
+	"sync"
 
 	"github.com/OpenListTeam/OpenList/v4/internal/conf"
 	"github.com/OpenListTeam/OpenList/v4/internal/errs"
 	"github.com/OpenListTeam/OpenList/v4/internal/model"
 	"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
 	"github.com/OpenListTeam/OpenList/v4/pkg/utils"
-	"github.com/sirupsen/logrus"
 	"go4.org/readerutil"
 )

@@ -127,10 +127,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 		buf := make([]byte, bufSize)
 		n, err := io.ReadFull(f.Reader, buf)
 		if err != nil {
-			return nil, err
-		}
-		if n != int(bufSize) {
-			return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
+			return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err)
 		}
 		f.peekBuff = bytes.NewReader(buf)
 		f.Reader = io.MultiReader(f.peekBuff, f.Reader)

@@ -234,7 +231,7 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
 		}
 		rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1})
 		if err != nil {
-			return 0, nil
+			return 0, err
 		}
 		ss.Reader = rc
 	}

@@ -299,70 +296,48 @@ func (r *ReaderUpdatingProgress) Close() error {
 	return r.Reader.Close()
 }
 
-type readerCur struct {
-	reader io.Reader
-	cur    int64
-}
-
 type RangeReadReadAtSeeker struct {
 	ss        *SeekableStream
 	masterOff int64
-	readers   []*readerCur
+	readerMap sync.Map
 	headCache *headCache
 }
 
 type headCache struct {
-	*readerCur
-	bufs [][]byte
+	reader io.Reader
+	bufs   [][]byte
 }
 
-func (c *headCache) read(p []byte) (n int, err error) {
-	pL := len(p)
-	logrus.Debugf("headCache read_%d", pL)
-	if c.cur < int64(pL) {
-		bufL := int64(pL) - c.cur
-		buf := make([]byte, bufL)
-		lr := io.LimitReader(c.reader, bufL)
-		off := 0
-		for c.cur < int64(pL) {
-			n, err = lr.Read(buf[off:])
-			off += n
-			c.cur += int64(n)
-			if err == io.EOF && off == int(bufL) {
-				err = nil
-			}
-			if err != nil {
-				break
-			}
+func (c *headCache) head(p []byte) (int, error) {
+	n := 0
+	for _, buf := range c.bufs {
+		if len(buf)+n >= len(p) {
+			n += copy(p[n:], buf[:len(p)-n])
+			return n, nil
+		} else {
+			n += copy(p[n:], buf)
 		}
+	}
+	w, err := io.ReadAtLeast(c.reader, p[n:], 1)
+	if w > 0 {
+		buf := make([]byte, w)
+		copy(buf, p[n:n+w])
 		c.bufs = append(c.bufs, buf)
+		n += w
 	}
-	n = 0
-	if c.cur >= int64(pL) {
-		for i := 0; n < pL; i++ {
-			buf := c.bufs[i]
-			r := len(buf)
-			if n+r > pL {
-				r = pL - n
-			}
-			n += copy(p[n:], buf[:r])
-		}
-	}
-	return
+	return n, err
 }
 
 func (r *headCache) Close() error {
-	for i := range r.bufs {
-		r.bufs[i] = nil
-	}
+	clear(r.bufs)
 	r.bufs = nil
 	return nil
 }
 
 func (r *RangeReadReadAtSeeker) InitHeadCache() {
 	if r.ss.GetFile() == nil && r.masterOff == 0 {
-		reader := r.readers[0]
-		r.readers = r.readers[1:]
-		r.headCache = &headCache{readerCur: reader}
+		value, _ := r.readerMap.LoadAndDelete(int64(0))
+		r.headCache = &headCache{reader: value.(io.Reader)}
 		r.ss.Closers.Add(r.headCache)
 	}
 }

@@ -388,8 +363,7 @@ func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (mode
 			return nil, err
 		}
 	} else {
-		rc := &readerCur{reader: ss, cur: offset}
-		r.readers = append(r.readers, rc)
+		r.readerMap.Store(int64(offset), ss)
 	}
 	return r, nil
 }

@@ -406,72 +380,64 @@ func NewMultiReaderAt(ss []*SeekableStream) (readerutil.SizeReaderAt, error) {
 	return readerutil.NewMultiReaderAt(readers...), nil
 }
 
-func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) {
-	var rc *readerCur
-	for _, reader := range r.readers {
-		if reader.cur == -1 {
-			continue
-		}
-		if reader.cur == off {
-			// logrus.Debugf("getReaderAtOffset match_%d", off)
-			return reader, nil
-		}
-		if reader.cur > 0 && off >= reader.cur && (rc == nil || reader.cur < rc.cur) {
-			rc = reader
-		}
-	}
-	if rc != nil && off-rc.cur <= utils.MB {
-		n, err := utils.CopyWithBufferN(io.Discard, rc.reader, off-rc.cur)
-		rc.cur += n
-		if err == io.EOF && rc.cur == off {
-			err = nil
-		}
-		if err == nil {
-			logrus.Debugf("getReaderAtOffset old_%d", off)
-			return rc, nil
-		}
-		rc.cur = -1
-	}
-	logrus.Debugf("getReaderAtOffset new_%d", off)
-	// a Range request must not exceed the file size; some cloud drives can't handle that and return the whole file
-	reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: r.ss.GetSize() - off})
+func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (io.Reader, error) {
+	var rr io.Reader
+	var cur int64 = -1
+	r.readerMap.Range(func(key, value any) bool {
+		k := key.(int64)
+		if off == k {
+			cur = k
+			rr = value.(io.Reader)
+			return false
+		}
+		if off > k && off-k <= 4*utils.MB && (rr == nil || k < cur) {
+			rr = value.(io.Reader)
+			cur = k
+		}
+		return true
+	})
+	if cur >= 0 {
+		r.readerMap.Delete(int64(cur))
+	}
+	if off == int64(cur) {
+		// logrus.Debugf("getReaderAtOffset match_%d", off)
+		return rr, nil
+	}
+
+	if rr != nil {
+		n, _ := utils.CopyWithBufferN(io.Discard, rr, off-cur)
+		cur += n
+		if cur == off {
+			// logrus.Debugf("getReaderAtOffset old_%d", off)
+			return rr, nil
+		}
+	}
+	// logrus.Debugf("getReaderAtOffset new_%d", off)
+	reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: -1})
 	if err != nil {
 		return nil, err
 	}
-	rc = &readerCur{reader: reader, cur: off}
-	r.readers = append(r.readers, rc)
-	return rc, nil
+	return reader, nil
 }
 
-func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (int, error) {
+func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) {
 	if off == 0 && r.headCache != nil {
-		return r.headCache.read(p)
+		return r.headCache.head(p)
 	}
-	rc, err := r.getReaderAtOffset(off)
+	var rr io.Reader
+	rr, err = r.getReaderAtOffset(off)
 	if err != nil {
 		return 0, err
 	}
-	n, num := 0, 0
-	for num < len(p) {
-		n, err = rc.reader.Read(p[num:])
-		rc.cur += int64(n)
-		num += n
-		if err == nil {
-			continue
-		}
-		if err == io.EOF {
-			// io.EOF means the reader is exhausted
-			rc.cur = -1
-			// the yeka/zip package doesn't handle EOF, so we have to stay compatible
-			// https://github.com/yeka/zip/blob/03d6312748a9d6e0bc0c9a7275385c09f06d9c14/reader.go#L433
-			if num == len(p) {
-				err = nil
-			}
-		}
-		break
-	}
-	return num, err
+	n, err = io.ReadAtLeast(rr, p, 1)
+	off += int64(n)
+	if err == nil {
+		r.readerMap.Store(int64(off), rr)
+	} else {
+		rr = nil
+	}
+	return n, err
 }

@@ -498,15 +464,7 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
 }
 
 func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) {
-	if r.masterOff == 0 && r.headCache != nil {
-		return r.headCache.read(p)
-	}
-	rc, err := r.getReaderAtOffset(r.masterOff)
-	if err != nil {
-		return 0, err
-	}
-	n, err = rc.reader.Read(p)
-	rc.cur += int64(n)
+	n, err = r.ReadAt(p, r.masterOff)
 	r.masterOff += int64(n)
 	return n, err
 }
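Note: the readerMap change above parks sequential readers under the offset where they stopped, so a later ReadAt at (or slightly past) that offset can resume the same underlying stream instead of issuing a new ranged request. A small sketch of the idea under that assumption (pool and openRange are hypothetical names):

package main

import (
	"fmt"
	"io"
	"strings"
	"sync"
)

// pool keeps finished readers keyed by the offset where they stopped, so a
// later read at that offset resumes the same stream instead of dialing again.
type pool struct {
	m         sync.Map // next offset -> io.Reader
	openRange func(off int64) io.Reader
}

func (p *pool) at(off int64) io.Reader {
	if v, ok := p.m.LoadAndDelete(off); ok {
		return v.(io.Reader) // exact resume, no new request needed
	}
	return p.openRange(off)
}

// readAt reads len(b) bytes at off and parks the reader for the next call.
func (p *pool) readAt(b []byte, off int64) (int, error) {
	r := p.at(off)
	n, err := io.ReadFull(r, b)
	if err == nil {
		p.m.Store(off+int64(n), r)
	}
	return n, err
}

func main() {
	data := "0123456789abcdef"
	opened := 0
	p := &pool{openRange: func(off int64) io.Reader {
		opened++
		return strings.NewReader(data[off:])
	}}
	b := make([]byte, 4)
	p.readAt(b, 0) // opens a new reader
	p.readAt(b, 4) // resumes the parked one
	fmt.Println(string(b), "ranged requests:", opened) // 4567 ranged requests: 1
}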


@@ -26,7 +26,7 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran
 
 func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
 	if link.MFile != nil {
-		return &model.FileRangeReader{RangeReaderIF: GetRangeReaderFromMFile(size, link.MFile)}, nil
+		return GetRangeReaderFromMFile(size, link.MFile), nil
 	}
 	if link.Concurrency > 0 || link.PartSize > 0 {
 		down := net.NewDownloader(func(d *net.Downloader) {

@@ -97,13 +97,16 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF,
 	return RateLimitRangeReaderFunc(rangeReader), nil
 }
 
-func GetRangeReaderFromMFile(size int64, file model.File) RangeReaderFunc {
-	return func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+// The io.ReadCloser returned by RangeReaderIF.RangeRead retains file's signature.
+func GetRangeReaderFromMFile(size int64, file model.File) model.RangeReaderIF {
+	return &model.FileRangeReader{
+		RangeReaderIF: RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 			length := httpRange.Length
 			if length < 0 || httpRange.Start+length > size {
 				length = size - httpRange.Start
 			}
 			return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil
+		}),
 	}
 }

@@ -227,11 +230,8 @@ func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionRead
 		tempBuf := ss.bufPool.Get().([]byte)
 		buf = tempBuf[:length]
 		n, err := io.ReadFull(ss.file, buf)
-		if err != nil {
-			return nil, err
-		}
 		if int64(n) != length {
-			return nil, fmt.Errorf("can't read data, expected=%d, got=%d", length, n)
+			return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err)
 		}
 		ss.off += int64(n)
 		off = 0
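Note: GetRangeReaderFromMFile ultimately leans on io.NewSectionReader: any io.ReaderAt can serve an arbitrary byte range without extra buffering, with open-ended or oversized ranges clamped to the file size. A minimal illustration (section is a hypothetical helper mirroring the clamp above):

package main

import (
	"fmt"
	"io"
	"strings"
)

// section clamps an open-ended or oversized range and hands back a reader
// over just that slice of the underlying io.ReaderAt, with no buffering.
func section(f io.ReaderAt, size, start, length int64) io.Reader {
	if length < 0 || start+length > size {
		length = size - start
	}
	return io.NewSectionReader(f, start, length)
}

func main() {
	data := "The quick brown fox"
	r := section(strings.NewReader(data), int64(len(data)), 4, -1)
	b, _ := io.ReadAll(r)
	fmt.Println(string(b)) // quick brown fox
}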


@@ -6,7 +6,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	log "github.com/sirupsen/logrus"

@@ -164,6 +166,7 @@ func (c *Closers) Close() error {
 			errs = append(errs, closer.Close())
 		}
 	}
+	clear(*c)
 	*c = (*c)[:0]
 	return errors.Join(errs...)
 }

@@ -191,32 +194,32 @@ type SyncClosersIF interface {
 
 type SyncClosers struct {
 	closers []io.Closer
-	mu      sync.Mutex
-	ref     int
+	ref     atomic.Int32
 }
 
 var _ SyncClosersIF = (*SyncClosers)(nil)
 
 func (c *SyncClosers) AcquireReference() bool {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if len(c.closers) == 0 {
-		return false
-	}
-	c.ref++
-	log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, c.ref)
-	return true
+	ref := c.ref.Add(1)
+	if ref > 0 {
+		// log.Debugf("SyncClosers.AcquireReference %p,ref=%d\n", c, ref)
+		return true
+	}
+	c.ref.Store(math.MinInt16)
+	return false
 }
 
 func (c *SyncClosers) Close() error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	defer log.Debugf("SyncClosers.Close %p,ref=%d\n", c, c.ref)
-	if c.ref > 1 {
-		c.ref--
+	ref := c.ref.Add(-1)
+	if ref < -1 {
+		c.ref.Store(math.MinInt16)
 		return nil
 	}
-	c.ref = 0
+	// log.Debugf("SyncClosers.Close %p,ref=%d\n", c, ref+1)
+	if ref > 0 {
+		return nil
+	}
+	c.ref.Store(math.MinInt16)
 
 	var errs []error
 	for _, closer := range c.closers {

@@ -224,23 +227,26 @@ func (c *SyncClosers) Close() error {
 			errs = append(errs, closer.Close())
 		}
 	}
-	c.closers = c.closers[:0]
+	clear(c.closers)
+	c.closers = nil
 	return errors.Join(errs...)
 }
 
 func (c *SyncClosers) Add(closer io.Closer) {
 	if closer != nil {
-		c.mu.Lock()
+		if c.ref.Load() < 0 {
+			panic("Not reusable")
+		}
 		c.closers = append(c.closers, closer)
-		c.mu.Unlock()
 	}
 }
 
 func (c *SyncClosers) AddIfCloser(a any) {
 	if closer, ok := a.(io.Closer); ok {
-		c.mu.Lock()
+		if c.ref.Load() < 0 {
+			panic("Not reusable")
+		}
 		c.closers = append(c.closers, closer)
-		c.mu.Unlock()
 	}
 }

@@ -278,11 +284,7 @@ var IoBuffPool = &sync.Pool{
 func CopyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
 	buff := IoBuffPool.Get().([]byte)
 	defer IoBuffPool.Put(buff)
-	written, err = io.CopyBuffer(dst, src, buff)
-	if err != nil {
-		return
-	}
-	return written, nil
+	return io.CopyBuffer(dst, src, buff)
 }
 
 func CopyWithBufferN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
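Note: the SyncClosers rewrite above replaces the mutex-guarded counter with a single atomic: Add(1) to acquire, Add(-1) to release, and a large negative sentinel (math.MinInt16) to poison the object once closed so it cannot be reused. A runnable sketch of that lifecycle (refCloser is a hypothetical type, not the package's API):

package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

type refCloser struct {
	ref     atomic.Int32
	closeFn func()
}

func (c *refCloser) Acquire() bool {
	if c.ref.Add(1) > 0 {
		return true
	}
	c.ref.Store(math.MinInt16) // already closed: stay poisoned
	return false
}

func (c *refCloser) Release() {
	ref := c.ref.Add(-1)
	if ref < -1 {
		c.ref.Store(math.MinInt16)
		return // already closed earlier
	}
	if ref > 0 {
		return // other holders remain
	}
	c.ref.Store(math.MinInt16) // last holder: poison, then close
	c.closeFn()
}

func main() {
	c := &refCloser{closeFn: func() { fmt.Println("closed") }}
	c.Acquire()
	c.Acquire()
	c.Release()              // still held
	c.Release()              // prints "closed"
	fmt.Println(c.Acquire()) // false: poisoned, not reusable
}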