diff --git a/internal/partedio/nwriter.go b/internal/partedio/nwriter.go index 57971bc..080b496 100644 --- a/internal/partedio/nwriter.go +++ b/internal/partedio/nwriter.go @@ -38,7 +38,8 @@ func NewNWriter(partSize int64, concurrency int, handler PartHandler) (io.WriteC initOnce.Do(func() { pool = &sync.Pool{ New: func() interface{} { - return make([]byte, partSize) + buf := make([]byte, partSize) + return &buf }, } }) @@ -87,8 +88,9 @@ func (nw *NWriter) startWriting(src io.Reader) { go func() { defer nw.wg.Done() - buffer := pool.Get().([]byte) - defer pool.Put(buffer) + bufferPtr := pool.Get().(*[]byte) + defer pool.Put(bufferPtr) + buffer := *bufferPtr for nw.getErr() == nil { n, err := reader.Read(buffer) diff --git a/internal/partedio/nwriter_test.go b/internal/partedio/nwriter_test.go index f0a1ad6..ba14395 100644 --- a/internal/partedio/nwriter_test.go +++ b/internal/partedio/nwriter_test.go @@ -2,7 +2,6 @@ package partedio import ( "errors" - "io" "strings" "sync" "sync/atomic" @@ -20,22 +19,22 @@ func TestNewNWriter(t *testing.T) { }{ { name: "valid parameters", - partSize: 100, - handler: func(int, int64, io.Reader) error { return nil }, + partSize: 4, + handler: func(int, int64, []byte) error { return nil }, concurrency: 5, wantErr: false, }, { name: "zero part size", partSize: 0, - handler: func(int, int64, io.Reader) error { return nil }, + handler: func(int, int64, []byte) error { return nil }, concurrency: 5, wantErr: true, }, { name: "negative part size", partSize: -1, - handler: func(int, int64, io.Reader) error { return nil }, + handler: func(int, int64, []byte) error { return nil }, concurrency: 5, wantErr: true, }, @@ -76,7 +75,7 @@ func TestNWriterWrite(t *testing.T) { wantParts: 0, wantErr: false, handler: func(t *testing.T) PartHandler { - return func(partNum int, size int64, r io.Reader) error { + return func(partNum int, size int64, data []byte) error { t.Error("handler should not be called for empty input") return nil } @@ -90,11 +89,7 @@ func TestNWriterWrite(t *testing.T) { wantParts: 1, wantErr: false, handler: func(t *testing.T) PartHandler { - return func(partNum int, size int64, r io.Reader) error { - data, err := io.ReadAll(r) - if err != nil { - t.Errorf("failed to read part: %v", err) - } + return func(partNum int, size int64, data []byte) error { if string(data) != "test" { t.Errorf("part data = %s, want %s", string(data), "test") } @@ -112,11 +107,7 @@ func TestNWriterWrite(t *testing.T) { handler: func(t *testing.T) PartHandler { var mu sync.Mutex parts := make(map[int]string) - return func(partNum int, size int64, r io.Reader) error { - data, err := io.ReadAll(r) - if err != nil { - t.Errorf("failed to read part: %v", err) - } + return func(partNum int, size int64, data []byte) error { mu.Lock() parts[partNum] = string(data) mu.Unlock() @@ -132,7 +123,7 @@ func TestNWriterWrite(t *testing.T) { wantParts: 1, wantErr: true, handler: func(t *testing.T) PartHandler { - return func(partNum int, size int64, r io.Reader) error { + return func(partNum int, size int64, data []byte) error { return errors.New("handler error") } }, @@ -142,9 +133,9 @@ func TestNWriterWrite(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var processedParts int64 - handler := func(partNum int, size int64, r io.Reader) error { + handler := func(partNum int, size int64, data []byte) error { atomic.AddInt64(&processedParts, 1) - return tt.handler(t)(partNum, size, r) + return tt.handler(t)(partNum, size, data) } w, err := NewNWriter(tt.partSize, tt.concurrency, handler) @@ -179,7 +170,7 @@ func TestNWriterConcurrentWrites(t *testing.T) { var mu sync.Mutex processed := make(map[int]bool) - handler := func(partNum int, size int64, r io.Reader) error { + handler := func(partNum int, size int64, data []byte) error { current := atomic.AddInt32(¤tConcurrent, 1) defer atomic.AddInt32(¤tConcurrent, -1)