perf(ftp): improve concurrent Link response; fix alias/local driver issues (#974)

j2rong4cn
2025-08-06 13:32:37 +08:00
committed by GitHub
parent 8cf15183a0
commit 9ac0484bc0
21 changed files with 337 additions and 393 deletions

View File

@@ -2,7 +2,6 @@ package model
import (
"context"
"errors"
"io"
"net/http"
"time"
@@ -40,13 +39,6 @@ type Link struct {
utils.SyncClosers `json:"-"`
}
func (l *Link) Close() error {
if clr, ok := l.MFile.(io.Closer); ok {
return errors.Join(clr.Close(), l.SyncClosers.Close())
}
return l.SyncClosers.Close()
}
type OtherArgs struct {
Obj Obj
Method string
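With the explicit Close gone, an MFile-backed link is no longer special-cased at close time; per-request resources are expected to be registered with the link's embedded utils.SyncClosers instead. A minimal sketch of that aggregation idea, assuming SyncClosers behaves roughly like this simplified, hypothetical type:

```go
package sketch

import (
	"errors"
	"io"
	"sync"
)

// syncClosers is a simplified, hypothetical stand-in for the project's
// utils.SyncClosers; the real type may differ.
type syncClosers struct {
	mu      sync.Mutex
	closers []io.Closer
}

// Add registers a per-request resource (e.g. an MFile) for later cleanup.
func (s *syncClosers) Add(c io.Closer) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.closers = append(s.closers, c)
}

// Close closes everything registered and joins the errors, which is all
// the deleted Link.Close did by hand for the MFile special case.
func (s *syncClosers) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	var errs []error
	for _, c := range s.closers {
		errs = append(errs, c.Close())
	}
	s.closers = nil
	return errors.Join(errs...)
}
```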

View File

@@ -372,11 +372,16 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
}
var forget any
var linkM *extractLink
fn := func() (*extractLink, error) {
link, err := driverExtract(ctx, storage, path, args)
if err != nil {
return nil, errors.Wrapf(err, "failed extract archive")
}
if link.MFile != nil && forget != nil {
linkM = link
return nil, errLinkMFileCache
}
if link.Link.Expiration != nil {
extractCache.Set(key, link, cache.WithEx[*extractLink](*link.Link.Expiration))
}
@@ -406,11 +411,18 @@ func DriverExtract(ctx context.Context, storage driver.Driver, path string, args
link.AcquireReference()
}
}
if err == errLinkMFileCache {
if linkM != nil {
return linkM.Link, linkM.Obj, nil
}
forget = nil
link, err = fn()
}
if err != nil {
return nil, nil, err
}
return link.Link, link.Obj, err
return link.Link, link.Obj, nil
}
func driverExtract(ctx context.Context, storage driver.Driver, path string, args model.ArchiveInnerArgs) (*extractLink, error) {
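DriverExtract here and op.Link below now share the same shape: the closure run under singleflight stashes any MFile-carrying result in a local variable (linkM), then fails with the sentinel errLinkMFileCache so that nothing holding an open file handle gets remembered or cached; the producing caller keeps its private linkM, and every other caller re-runs the work for itself (the patch clears `forget`, whose assignment is not shown in these hunks, and calls fn again). A self-contained sketch of the pattern, using golang.org/x/sync/singleflight rather than the project's generic, Remember-capable fork, with *os.File standing in for link.MFile; all names below are illustrative:

```go
package sketch

import (
	"errors"
	"os"

	"golang.org/x/sync/singleflight"
)

var errNotShareable = errors.New("result holds a private file handle")

var group singleflight.Group

// openDedup deduplicates concurrent opens per path, but never shares the
// resulting handle: the goroutine that actually ran fn keeps its copy via
// the captured `mine` (the patch's linkM); everyone else sees the sentinel
// and redoes the work privately.
func openDedup(path string) (*os.File, error) {
	var mine *os.File
	fn := func() (any, error) {
		f, err := os.Open(path)
		if err != nil {
			return nil, err
		}
		mine = f
		return nil, errNotShareable // keep the handle out of the shared result
	}
	_, err, _ := group.Do(path, fn)
	if errors.Is(err, errNotShareable) {
		if mine != nil {
			return mine, nil // we were the producer; the handle is ours alone
		}
		return os.Open(path) // late joiner: do the work privately
	}
	return nil, err // a real open error from the shared call
}
```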

View File

@@ -2,6 +2,7 @@ package op
import (
"context"
stderrors "errors"
stdpath "path"
"slices"
"strings"
@@ -250,6 +251,7 @@ func GetUnwrap(ctx context.Context, storage driver.Driver, path string) (model.O
var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
var linkG = singleflight.Group[*model.Link]{Remember: true}
var errLinkMFileCache = stderrors.New("ErrLinkMFileCache")
// Link gets a link; if it is a URL, it should have an expiry time
func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
@@ -292,11 +294,16 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
}
var forget any
var linkM *model.Link
fn := func() (*model.Link, error) {
link, err := storage.Link(ctx, file, args)
if err != nil {
return nil, errors.Wrapf(err, "failed get link")
}
if link.MFile != nil && forget != nil {
linkM = link
return nil, errLinkMFileCache
}
if link.Expiration != nil {
linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration))
}
@@ -326,11 +333,19 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li
link.AcquireReference()
}
}
if err == errLinkMFileCache {
if linkM != nil {
return linkM, file, nil
}
forget = nil
link, err = fn()
}
if err != nil {
return nil, nil, err
}
return link, file, err
return link, file, nil
}
// Other api

View File

@@ -8,13 +8,13 @@ import (
"io"
"math"
"os"
"sync"
"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/sirupsen/logrus"
"go4.org/readerutil"
)
@@ -127,10 +127,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
buf := make([]byte, bufSize)
n, err := io.ReadFull(f.Reader, buf)
if err != nil {
return nil, err
}
if n != int(bufSize) {
return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err)
}
f.peekBuff = bytes.NewReader(buf)
f.Reader = io.MultiReader(f.peekBuff, f.Reader)
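The collapsed error path leans on io.ReadFull's contract: it returns a nil error only when the buffer was filled completely (a short read surfaces as io.ErrUnexpectedEOF), so the separate n != bufSize check was redundant, and wrapping with %w keeps the cause inspectable. A quick stdlib-only illustration:

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
)

func main() {
	buf := make([]byte, 8)
	// Only 3 of the requested 8 bytes are available.
	n, err := io.ReadFull(strings.NewReader("abc"), buf)
	if err != nil {
		wrapped := fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", len(buf), n, err)
		fmt.Println(wrapped)                                  // ... unexpected EOF
		fmt.Println(errors.Is(wrapped, io.ErrUnexpectedEOF)) // true, thanks to %w
	}
}
```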
@@ -234,7 +231,7 @@ func (ss *SeekableStream) Read(p []byte) (n int, err error) {
}
rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1})
if err != nil {
return 0, nil
return 0, err
}
ss.Reader = rc
}
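The change from `return 0, nil` to `return 0, err` matters because of the io.Reader contract: a Read that returns (0, nil) tells callers to try again, so swallowing the RangeRead error could spin io.Copy-style loops forever instead of failing. A hypothetical minimal reproduction:

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

// badReader models the old behaviour for a few calls, then the fixed one.
type badReader struct{ tries int }

func (r *badReader) Read(p []byte) (int, error) {
	r.tries++
	if r.tries > 3 {
		return 0, errors.New("range read failed") // fixed: surface the error
	}
	return 0, nil // old: error swallowed, caller just retries
}

func main() {
	// With a reader that always returned (0, nil), io.Copy would never return.
	n, err := io.Copy(io.Discard, &badReader{})
	fmt.Println(n, err) // 0 range read failed
}
```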
@@ -299,70 +296,48 @@ func (r *ReaderUpdatingProgress) Close() error {
return r.Reader.Close()
}
type readerCur struct {
reader io.Reader
cur int64
}
type RangeReadReadAtSeeker struct {
ss *SeekableStream
masterOff int64
readers []*readerCur
readerMap sync.Map
headCache *headCache
}
type headCache struct {
*readerCur
bufs [][]byte
reader io.Reader
bufs [][]byte
}
func (c *headCache) read(p []byte) (n int, err error) {
pL := len(p)
logrus.Debugf("headCache read_%d", pL)
if c.cur < int64(pL) {
bufL := int64(pL) - c.cur
buf := make([]byte, bufL)
lr := io.LimitReader(c.reader, bufL)
off := 0
for c.cur < int64(pL) {
n, err = lr.Read(buf[off:])
off += n
c.cur += int64(n)
if err == io.EOF && off == int(bufL) {
err = nil
}
if err != nil {
break
}
func (c *headCache) head(p []byte) (int, error) {
n := 0
for _, buf := range c.bufs {
if len(buf)+n >= len(p) {
n += copy(p[n:], buf[:len(p)-n])
return n, nil
} else {
n += copy(p[n:], buf)
}
}
w, err := io.ReadAtLeast(c.reader, p[n:], 1)
if w > 0 {
buf := make([]byte, w)
copy(buf, p[n:n+w])
c.bufs = append(c.bufs, buf)
n += w
}
n = 0
if c.cur >= int64(pL) {
for i := 0; n < pL; i++ {
buf := c.bufs[i]
r := len(buf)
if n+r > pL {
r = pL - n
}
n += copy(p[n:], buf[:r])
}
}
return
return n, err
}
func (r *headCache) Close() error {
for i := range r.bufs {
r.bufs[i] = nil
}
clear(r.bufs)
r.bufs = nil
return nil
}
func (r *RangeReadReadAtSeeker) InitHeadCache() {
if r.ss.GetFile() == nil && r.masterOff == 0 {
reader := r.readers[0]
r.readers = r.readers[1:]
r.headCache = &headCache{readerCur: reader}
value, _ := r.readerMap.LoadAndDelete(int64(0))
r.headCache = &headCache{reader: value.(io.Reader)}
r.ss.Closers.Add(r.headCache)
}
}
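The rewritten headCache drops the readerCur bookkeeping entirely: it remembers every chunk it has ever read (bufs), replays that cached prefix on each call, and only then tops up from the underlying reader via io.ReadAtLeast. A self-contained analogue of that logic (the replayHead name is illustrative, not from the patch):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

type replayHead struct {
	r    io.Reader
	bufs [][]byte // everything read from the head so far, in order
}

func (c *replayHead) head(p []byte) (int, error) {
	n := 0
	for _, buf := range c.bufs { // replay the cached prefix first
		if len(buf)+n >= len(p) {
			return n + copy(p[n:], buf[:len(p)-n]), nil
		}
		n += copy(p[n:], buf)
	}
	w, err := io.ReadAtLeast(c.r, p[n:], 1) // then read fresh bytes
	if w > 0 {
		c.bufs = append(c.bufs, append([]byte(nil), p[n:n+w]...))
	}
	return n + w, err
}

func main() {
	c := &replayHead{r: strings.NewReader("HEADERbody")}
	a := make([]byte, 6)
	c.head(a) // consumes "HEADER" from the stream and caches it
	b := make([]byte, 6)
	c.head(b) // served from the cache; the stream is not touched again
	fmt.Println(string(a), string(b)) // HEADER HEADER
}
```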
@@ -388,8 +363,7 @@ func NewReadAtSeeker(ss *SeekableStream, offset int64, forceRange ...bool) (mode
return nil, err
}
} else {
rc := &readerCur{reader: ss, cur: offset}
r.readers = append(r.readers, rc)
r.readerMap.Store(int64(offset), ss)
}
return r, nil
}
@@ -406,72 +380,64 @@ func NewMultiReaderAt(ss []*SeekableStream) (readerutil.SizeReaderAt, error) {
return readerutil.NewMultiReaderAt(readers...), nil
}
func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (*readerCur, error) {
var rc *readerCur
for _, reader := range r.readers {
if reader.cur == -1 {
continue
func (r *RangeReadReadAtSeeker) getReaderAtOffset(off int64) (io.Reader, error) {
var rr io.Reader
var cur int64 = -1
r.readerMap.Range(func(key, value any) bool {
k := key.(int64)
if off == k {
cur = k
rr = value.(io.Reader)
return false
}
if reader.cur == off {
// logrus.Debugf("getReaderAtOffset match_%d", off)
return reader, nil
}
if reader.cur > 0 && off >= reader.cur && (rc == nil || reader.cur < rc.cur) {
rc = reader
if off > k && off-k <= 4*utils.MB && (rr == nil || k < cur) {
rr = value.(io.Reader)
cur = k
}
return true
})
if cur >= 0 {
r.readerMap.Delete(int64(cur))
}
if rc != nil && off-rc.cur <= utils.MB {
n, err := utils.CopyWithBufferN(io.Discard, rc.reader, off-rc.cur)
rc.cur += n
if err == io.EOF && rc.cur == off {
err = nil
}
if err == nil {
logrus.Debugf("getReaderAtOffset old_%d", off)
return rc, nil
}
rc.cur = -1
if off == int64(cur) {
// logrus.Debugf("getReaderAtOffset match_%d", off)
return rr, nil
}
logrus.Debugf("getReaderAtOffset new_%d", off)
// A Range request must not exceed the file size; some cloud drives can't handle that and will return the entire file
reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: r.ss.GetSize() - off})
if rr != nil {
n, _ := utils.CopyWithBufferN(io.Discard, rr, off-cur)
cur += n
if cur == off {
// logrus.Debugf("getReaderAtOffset old_%d", off)
return rr, nil
}
}
// logrus.Debugf("getReaderAtOffset new_%d", off)
reader, err := r.ss.RangeRead(http_range.Range{Start: off, Length: -1})
if err != nil {
return nil, err
}
rc = &readerCur{reader: reader, cur: off}
r.readers = append(r.readers, rc)
return rc, nil
return reader, nil
}
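Replacing the `[]*readerCur` slice (with its cur == -1 tombstones) by a sync.Map keyed on offset turns claiming a reader into an atomic delete, which is what lets several concurrent FTP range requests share one stream's reader pool without a global lock. The selection logic: an exact offset match wins; otherwise a parked reader at most 4 MiB behind can be fast-forwarded by discarding bytes; otherwise a fresh RangeRead is opened. A sketch of the pool half of that idea (names are illustrative; io.CopyN stands in for utils.CopyWithBufferN, and this sketch prefers the reader needing the least skipping):

```go
package sketch

import (
	"io"
	"sync"
)

const maxSkip = 4 << 20 // mirrors the 4*utils.MB window in the patch

// readerPool parks io.Readers under the stream offset they currently
// point at: offset(int64) -> io.Reader.
type readerPool struct{ m sync.Map }

func (p *readerPool) park(off int64, r io.Reader) { p.m.Store(off, r) }

// claim removes and returns a reader positioned exactly at off, fast-
// forwarding the best parked reader that is at most maxSkip bytes behind.
// ok == false means the caller should open a fresh ranged reader.
func (p *readerPool) claim(off int64) (r io.Reader, ok bool) {
	cur := int64(-1)
	p.m.Range(func(k, v any) bool {
		ko := k.(int64)
		if ko == off {
			r, cur = v.(io.Reader), ko
			return false // exact hit, stop scanning
		}
		if off > ko && off-ko <= maxSkip && (r == nil || ko > cur) {
			r, cur = v.(io.Reader), ko // closest reader behind off
		}
		return true
	})
	if r == nil {
		return nil, false
	}
	p.m.Delete(cur) // claiming is a delete, so no other request reuses it
	if cur < off {
		// Skip ahead to the requested offset by discarding bytes.
		if _, err := io.CopyN(io.Discard, r, off-cur); err != nil {
			return nil, false
		}
	}
	return r, true
}
```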
func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (int, error) {
func (r *RangeReadReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) {
if off == 0 && r.headCache != nil {
return r.headCache.read(p)
return r.headCache.head(p)
}
rc, err := r.getReaderAtOffset(off)
var rr io.Reader
rr, err = r.getReaderAtOffset(off)
if err != nil {
return 0, err
}
n, num := 0, 0
for num < len(p) {
n, err = rc.reader.Read(p[num:])
rc.cur += int64(n)
num += n
if err == nil {
continue
}
if err == io.EOF {
// io.EOF means the reader has been fully consumed
rc.cur = -1
// the yeka/zip package doesn't handle EOF, so we stay compatible with it
// https://github.com/yeka/zip/blob/03d6312748a9d6e0bc0c9a7275385c09f06d9c14/reader.go#L433
if num == len(p) {
err = nil
}
}
break
n, err = io.ReadAtLeast(rr, p, 1)
off += int64(n)
if err == nil {
r.readerMap.Store(int64(off), rr)
} else {
rr = nil
}
return num, err
return n, err
}
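ReadAt's hand-rolled read loop, including the EOF workaround for the yeka/zip reader, collapses into a single io.ReadAtLeast(rr, p, 1): with min = 1 it keeps retrying readers that return (0, nil), suppresses an EOF that arrives together with data (so it is reported on the next call instead), and returns io.EOF only for a genuinely empty read; the advanced reader is then parked back in the map under its new offset. Those semantics in isolation, stdlib only:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	r := strings.NewReader("abc")
	p := make([]byte, 8)

	n, err := io.ReadAtLeast(r, p, 1)
	fmt.Println(n, err) // 3 <nil>: the EOF is swallowed because n >= 1

	n, err = io.ReadAtLeast(r, p, 1)
	fmt.Println(n, err) // 0 EOF: nothing left, the error finally surfaces
}
```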
func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
@@ -498,15 +464,7 @@ func (r *RangeReadReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
}
func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) {
if r.masterOff == 0 && r.headCache != nil {
return r.headCache.read(p)
}
rc, err := r.getReaderAtOffset(r.masterOff)
if err != nil {
return 0, err
}
n, err = rc.reader.Read(p)
rc.cur += int64(n)
n, err = r.ReadAt(p, r.masterOff)
r.masterOff += int64(n)
return n, err
}

View File

@@ -26,7 +26,7 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran
func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) {
if link.MFile != nil {
return &model.FileRangeReader{RangeReaderIF: GetRangeReaderFromMFile(size, link.MFile)}, nil
return GetRangeReaderFromMFile(size, link.MFile), nil
}
if link.Concurrency > 0 || link.PartSize > 0 {
down := net.NewDownloader(func(d *net.Downloader) {
@@ -97,13 +97,16 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF,
return RateLimitRangeReaderFunc(rangeReader), nil
}
func GetRangeReaderFromMFile(size int64, file model.File) RangeReaderFunc {
return func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
length := httpRange.Length
if length < 0 || httpRange.Start+length > size {
length = size - httpRange.Start
}
return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil
// The io.ReadCloser returned by RangeReaderIF.RangeRead preserves file's signature.
func GetRangeReaderFromMFile(size int64, file model.File) model.RangeReaderIF {
return &model.FileRangeReader{
RangeReaderIF: RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
length := httpRange.Length
if length < 0 || httpRange.Start+length > size {
length = size - httpRange.Start
}
return &model.FileCloser{File: io.NewSectionReader(file, httpRange.Start, length)}, nil
}),
}
}
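GetRangeReaderFromMFile now returns a model.RangeReaderIF itself (wrapped in FileRangeReader so the MFile origin stays detectable), but the core of serving a range from an io.ReaderAt is unchanged. Stripped of the project wrappers, the clamp-and-section idea looks like this (rangeReader is a hypothetical name):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// rangeReader returns a reader over file[start, start+length), clamping
// open-ended (length < 0) or oversized ranges to the file size, just as
// the patched function does before building the SectionReader.
func rangeReader(file io.ReaderAt, size, start, length int64) io.Reader {
	if length < 0 || start+length > size {
		length = size - start
	}
	return io.NewSectionReader(file, start, length)
}

func main() {
	f := strings.NewReader("0123456789") // any io.ReaderAt works here
	b, _ := io.ReadAll(rangeReader(f, 10, 4, -1))
	fmt.Println(string(b)) // 456789: the open-ended range is clamped to EOF
}
```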
@@ -227,11 +230,8 @@ func (ss *StreamSectionReader) GetSectionReader(off, length int64) (*SectionRead
tempBuf := ss.bufPool.Get().([]byte)
buf = tempBuf[:length]
n, err := io.ReadFull(ss.file, buf)
if err != nil {
return nil, err
}
if int64(n) != length {
return nil, fmt.Errorf("can't read data, expected=%d, got=%d", length, n)
return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err)
}
ss.off += int64(n)
off = 0
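GetSectionReader draws its scratch space from a pool and re-slices it to the requested length (tempBuf[:length]), so the same ReadFull-based "all or error" check applies per section without a fresh allocation each call. A minimal sketch of that pooling idiom with a plain sync.Pool; the names and the 64 KiB size are illustrative, not the project's:

```go
package main

import (
	"fmt"
	"sync"
)

var bufPool = sync.Pool{
	New: func() any { return make([]byte, 64<<10) }, // 64 KiB scratch chunks
}

// withChunk lends the caller a length-sized view of a pooled buffer.
func withChunk(length int64, fn func([]byte)) {
	tempBuf := bufPool.Get().([]byte)
	defer bufPool.Put(tempBuf) // return the full-capacity slice to the pool
	fn(tempBuf[:length])       // hand out only the bytes that were asked for
}

func main() {
	withChunk(10, func(b []byte) { fmt.Println(len(b), cap(b)) }) // 10 65536
}
```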