package storage import ( "bytes" "fmt" "io" "sync" "github.com/klauspost/compress/zstd" ) /* zstd compression wrapper for larc blob storage. * Uses encoder/decoder pools for efficiency. * Default compression level balances speed and ratio. */ const ( // DefaultCompressionLevel is the default zstd level (3 = fast, good ratio) DefaultCompressionLevel = 3 ) var ( encoderPool sync.Pool decoderPool sync.Pool initOnce sync.Once ) func initPools() { initOnce.Do(func() { encoderPool = sync.Pool{ New: func() any { enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(DefaultCompressionLevel)), zstd.WithEncoderConcurrency(1), ) if err != nil { panic(fmt.Sprintf("failed to create zstd encoder: %v", err)) } return enc }, } decoderPool = sync.Pool{ New: func() any { dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1), ) if err != nil { panic(fmt.Sprintf("failed to create zstd decoder: %v", err)) } return dec }, } }) } // Compress compresses data using zstd func Compress(data []byte) ([]byte, error) { initPools() enc := encoderPool.Get().(*zstd.Encoder) defer encoderPool.Put(enc) var buf bytes.Buffer enc.Reset(&buf) if _, err := enc.Write(data); err != nil { return nil, fmt.Errorf("zstd compress write: %w", err) } if err := enc.Close(); err != nil { return nil, fmt.Errorf("zstd compress close: %w", err) } return buf.Bytes(), nil } // Decompress decompresses zstd data func Decompress(data []byte) ([]byte, error) { initPools() dec := decoderPool.Get().(*zstd.Decoder) defer decoderPool.Put(dec) if err := dec.Reset(bytes.NewReader(data)); err != nil { return nil, fmt.Errorf("zstd decompress reset: %w", err) } result, err := io.ReadAll(dec) if err != nil { return nil, fmt.Errorf("zstd decompress read: %w", err) } return result, nil } // CompressReader compresses from reader to writer func CompressReader(r io.Reader, w io.Writer) (int64, error) { initPools() enc := encoderPool.Get().(*zstd.Encoder) defer encoderPool.Put(enc) enc.Reset(w) n, err := io.Copy(enc, r) if err != nil { return n, fmt.Errorf("zstd compress copy: %w", err) } if err := enc.Close(); err != nil { return n, fmt.Errorf("zstd compress close: %w", err) } return n, nil } // DecompressReader decompresses from reader to writer func DecompressReader(r io.Reader, w io.Writer) (int64, error) { initPools() dec := decoderPool.Get().(*zstd.Decoder) defer decoderPool.Put(dec) if err := dec.Reset(r); err != nil { return 0, fmt.Errorf("zstd decompress reset: %w", err) } n, err := io.Copy(w, dec) if err != nil { return n, fmt.Errorf("zstd decompress copy: %w", err) } return n, nil } // CompressLevel compresses with specific compression level (1-19) func CompressLevel(data []byte, level int) ([]byte, error) { enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(level)), ) if err != nil { return nil, fmt.Errorf("zstd new encoder: %w", err) } defer enc.Close() var buf bytes.Buffer enc.Reset(&buf) if _, err := enc.Write(data); err != nil { return nil, fmt.Errorf("zstd compress write: %w", err) } if err := enc.Close(); err != nil { return nil, fmt.Errorf("zstd compress close: %w", err) } return buf.Bytes(), nil }