diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index fc95007d..39bc3322 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -785,8 +785,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo // step.4 上传切片 uploadUrl := uploadUrls[0] - _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, - driver.NewLimitedUploadStream(ctx, rateLimitedRd), isFamily) + _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, rateLimitedRd, isFamily) if err != nil { return err } diff --git a/internal/op/fs.go b/internal/op/fs.go index 114c26fc..c5a5b52d 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -630,6 +630,11 @@ func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file mod up = func(p float64) {} } + // 如果小于0,则通过缓存获取完整大小,可能发生于流式上传 + if file.GetSize() < 0 { + log.Warnf("file size < 0, try to get full size from cache") + file.CacheFullAndWriter(nil, nil) + } switch s := storage.(type) { case driver.PutResult: var newObj model.Obj diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 94772761..8d2f504f 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -137,6 +137,60 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ if writer != nil { reader = io.TeeReader(reader, writer) } + + if f.GetSize() < 0 { + if f.peekBuff == nil { + f.peekBuff = &buffer.Reader{} + } + // 检查是否有数据 + buf := []byte{0} + n, err := io.ReadFull(reader, buf) + if n > 0 { + f.peekBuff.Append(buf[:n]) + } + if err == io.ErrUnexpectedEOF { + f.size = f.peekBuff.Size() + f.Reader = f.peekBuff + return f.peekBuff, nil + } else if err != nil { + return nil, err + } + if conf.MaxBufferLimit-n > conf.MmapThreshold && conf.MmapThreshold > 0 { + m, err := mmap.Alloc(conf.MaxBufferLimit - n) + if err == nil { + f.Add(utils.CloseFunc(func() error { + return mmap.Free(m) + })) + n, err = io.ReadFull(reader, m) + if n > 0 { + f.peekBuff.Append(m[:n]) + } + if err == io.ErrUnexpectedEOF { + f.size = f.peekBuff.Size() + f.Reader = f.peekBuff + return f.peekBuff, nil + } else if err != nil { + return nil, err + } + } + } + + tmpF, err := utils.CreateTempFile(reader, 0) + if err != nil { + return nil, err + } + f.Add(utils.CloseFunc(func() error { + return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) + })) + peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) + if err != nil { + return nil, err + } + f.size = peekF.Size() + f.Reader = peekF + return peekF, nil + } + f.Reader = reader return f.cache(f.GetSize()) } @@ -162,7 +216,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { } size := httpRange.Start + httpRange.Length - if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) { + if f.peekBuff != nil && size <= int64(f.peekBuff.Size()) { return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil } @@ -194,7 +248,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { f.peekBuff = &buffer.Reader{} f.oriReader = f.Reader } - bufSize := maxCacheSize - int64(f.peekBuff.Len()) + bufSize := maxCacheSize - int64(f.peekBuff.Size()) var buf []byte if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) { m, err := mmap.Alloc(int(bufSize)) @@ -213,7 +267,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) } f.peekBuff.Append(buf) - if int64(f.peekBuff.Len()) >= f.GetSize() { + if int64(f.peekBuff.Size()) >= f.GetSize() { f.Reader = f.peekBuff f.oriReader = nil } else { diff --git a/pkg/buffer/bytes.go b/pkg/buffer/bytes.go index 3ee10747..3e6cb540 100644 --- a/pkg/buffer/bytes.go +++ b/pkg/buffer/bytes.go @@ -8,83 +8,86 @@ import ( // 用于存储不复用的[]byte type Reader struct { bufs [][]byte - length int - offset int + size int64 + offset int64 } -func (r *Reader) Len() int { - return r.length +func (r *Reader) Size() int64 { + return r.size } func (r *Reader) Append(buf []byte) { - r.length += len(buf) + r.size += int64(len(buf)) r.bufs = append(r.bufs, buf) } func (r *Reader) Read(p []byte) (int, error) { - n, err := r.ReadAt(p, int64(r.offset)) + n, err := r.ReadAt(p, r.offset) if n > 0 { - r.offset += n + r.offset += int64(n) } return n, err } func (r *Reader) ReadAt(p []byte, off int64) (int, error) { - if off < 0 || off >= int64(r.length) { + if off < 0 || off >= r.size { return 0, io.EOF } - n, length := 0, int64(0) + n := 0 readFrom := false for _, buf := range r.bufs { - newLength := length + int64(len(buf)) if readFrom { - w := copy(p[n:], buf) - n += w - } else if off < newLength { + nn := copy(p[n:], buf) + n += nn + if n == len(p) { + return n, nil + } + } else if newOff := off - int64(len(buf)); newOff >= 0 { + off = newOff + } else { + nn := copy(p, buf[off:]) + if nn == len(p) { + return nn, nil + } + n += nn readFrom = true - w := copy(p[n:], buf[int(off-length):]) - n += w } - if n == len(p) { - return n, nil - } - length = newLength } return n, io.EOF } func (r *Reader) Seek(offset int64, whence int) (int64, error) { - var abs int switch whence { case io.SeekStart: - abs = int(offset) case io.SeekCurrent: - abs = r.offset + int(offset) + offset = r.offset + offset case io.SeekEnd: - abs = r.length + int(offset) + offset = r.size + offset default: return 0, errors.New("Seek: invalid whence") } - if abs < 0 || abs > r.length { + if offset < 0 || offset > r.size { return 0, errors.New("Seek: invalid offset") } - r.offset = abs - return int64(abs), nil + r.offset = offset + return offset, nil } func (r *Reader) Reset() { clear(r.bufs) r.bufs = nil - r.length = 0 + r.size = 0 r.offset = 0 } func NewReader(buf ...[]byte) *Reader { - b := &Reader{} + b := &Reader{ + bufs: make([][]byte, 0, len(buf)), + } for _, b1 := range buf { b.Append(b1) } diff --git a/pkg/buffer/bytes_test.go b/pkg/buffer/bytes_test.go index b66af229..3f4d8556 100644 --- a/pkg/buffer/bytes_test.go +++ b/pkg/buffer/bytes_test.go @@ -13,8 +13,7 @@ func TestReader_ReadAt(t *testing.T) { } bs := &Reader{} bs.Append([]byte("github.com")) - bs.Append([]byte("/")) - bs.Append([]byte("OpenList")) + bs.Append([]byte("/OpenList")) bs.Append([]byte("Team/")) bs.Append([]byte("OpenList")) tests := []struct { @@ -71,7 +70,7 @@ func TestReader_ReadAt(t *testing.T) { off: 24, }, want: func(a args, n int, err error) error { - if n != bs.Len()-int(a.off) { + if n != int(bs.Size()-a.off) { return errors.New("read length not match") } if string(a.p[:n]) != "OpenList" { diff --git a/pkg/buffer/file.go b/pkg/buffer/file.go new file mode 100644 index 00000000..48edf5a4 --- /dev/null +++ b/pkg/buffer/file.go @@ -0,0 +1,88 @@ +package buffer + +import ( + "errors" + "io" + "os" +) + +type PeekFile struct { + peek *Reader + file *os.File + offset int64 + size int64 +} + +func (p *PeekFile) Read(b []byte) (n int, err error) { + n, err = p.ReadAt(b, p.offset) + if n > 0 { + p.offset += int64(n) + } + return n, err +} + +func (p *PeekFile) ReadAt(b []byte, off int64) (n int, err error) { + if off < p.peek.Size() { + n, err = p.peek.ReadAt(b, off) + if err == nil || n == len(b) { + return n, nil + } + // EOF + } + var nn int + nn, err = p.file.ReadAt(b[n:], off+int64(n)-p.peek.Size()) + return n + nn, err +} + +func (p *PeekFile) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + case io.SeekCurrent: + if offset == 0 { + return p.offset, nil + } + offset = p.offset + offset + case io.SeekEnd: + offset = p.size + offset + default: + return 0, errors.New("Seek: invalid whence") + } + + if offset < 0 || offset > p.size { + return 0, errors.New("Seek: invalid offset") + } + if offset <= p.peek.Size() { + _, err := p.peek.Seek(offset, io.SeekStart) + if err != nil { + return 0, err + } + _, err = p.file.Seek(0, io.SeekStart) + if err != nil { + return 0, err + } + } else { + _, err := p.peek.Seek(p.peek.Size(), io.SeekStart) + if err != nil { + return 0, err + } + _, err = p.file.Seek(offset-p.peek.Size(), io.SeekStart) + if err != nil { + return 0, err + } + } + + p.offset = offset + return offset, nil +} + +func (p *PeekFile) Size() int64 { + return p.size +} + +func NewPeekFile(peek *Reader, file *os.File) (*PeekFile, error) { + stat, err := file.Stat() + if err == nil { + return &PeekFile{peek: peek, file: file, size: stat.Size() + peek.Size()}, nil + } + return nil, err +} diff --git a/server/handles/fsup.go b/server/handles/fsup.go index 087a58a9..71d9dbae 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -56,14 +56,17 @@ func FsStream(c *gin.Context) { } } dir, name := stdpath.Split(path) - sizeStr := c.GetHeader("Content-Length") - if sizeStr == "" { - sizeStr = "0" - } - size, err := strconv.ParseInt(sizeStr, 10, 64) - if err != nil { - common.ErrorResp(c, err, 400) - return + // 如果请求头 Content-Length 和 X-File-Size 都没有,则 size=-1,表示未知大小的流式上传 + size := c.Request.ContentLength + if size < 0 { + sizeStr := c.GetHeader("X-File-Size") + if sizeStr != "" { + size, err = strconv.ParseInt(sizeStr, 10, 64) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + } } h := make(map[*utils.HashType]string) if md5 := c.GetHeader("X-File-Md5"); md5 != "" { diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index b6f7cdac..802947eb 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "path" + "strconv" "strings" "time" @@ -341,9 +342,19 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, if err != nil { return http.StatusForbidden, err } + size := r.ContentLength + if size < 0 { + sizeStr := r.Header.Get("X-File-Size") + if sizeStr != "" { + size, err = strconv.ParseInt(sizeStr, 10, 64) + if err != nil { + return http.StatusBadRequest, err + } + } + } obj := model.Object{ Name: path.Base(reqPath), - Size: r.ContentLength, + Size: size, Modified: h.getModTime(r), Ctime: h.getCreateTime(r), }