diff --git a/BUILD.bazel b/BUILD.bazel index 79614f8..2aefb8c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -27,6 +27,7 @@ go_library( "@com_github_pkg_errors//:go_default_library", "@com_github_ulikunitz_xz//:go_default_library", "@com_github_ulikunitz_xz//lzma:go_default_library", + "@com_github_klauspost_compress//zstd:go_default_library", ], ) diff --git a/deps.bzl b/deps.bzl index 68a51f3..d3a985f 100644 --- a/deps.bzl +++ b/deps.bzl @@ -33,3 +33,10 @@ def rpmpack_dependencies(): sum = "h1:YvTNdFzX6+W5m9msiYg/zpkSURPPtOlzbqYjrFn7Yt4=", version = "v0.5.7", ) + + go_repository( + name = "com_github_klauspost_compress", + importpath = "github.com/klauspost/compress", + sum = "h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc=", + version = "v1.13.6", + ) diff --git a/go.mod b/go.mod index 0d631d1..472585a 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.12 require ( github.com/cavaliercoder/go-cpio v0.0.0-20180626203310-925f9528c45e github.com/google/go-cmp v0.3.1 + github.com/klauspost/compress v1.13.6 github.com/pkg/errors v0.9.1 github.com/ulikunitz/xz v0.5.8 ) diff --git a/go.sum b/go.sum index 25b79d6..2905997 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/cavaliercoder/go-cpio v0.0.0-20180626203310-925f9528c45e h1:hHg27A0RS github.com/cavaliercoder/go-cpio v0.0.0-20180626203310-925f9528c45e/go.mod h1:oDpT4efm8tSYHXV5tHSdRvBet/b/QzxZ+XyyPehvm3A= github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/ulikunitz/xz v0.5.8 h1:ERv8V6GKqVi23rgu5cj9pVfVzJbOqAY2Ntl88O6c2nQ= diff --git a/rpm.go b/rpm.go index a0268ea..bbfdaaa 100644 --- a/rpm.go +++ b/rpm.go @@ -25,9 +25,12 @@ import ( "io" "path" "sort" + "strconv" + "strings" "time" cpio "github.com/cavaliercoder/go-cpio" + "github.com/klauspost/compress/zstd" "github.com/pkg/errors" "github.com/ulikunitz/xz" "github.com/ulikunitz/xz/lzma" @@ -110,24 +113,15 @@ func NewRPM(m RPMMetaData) (*RPM, error) { } p := &bytes.Buffer{} - var z io.WriteCloser - switch m.Compressor { - case "": - m.Compressor = "gzip" - fallthrough - case "gzip": - z, err = gzip.NewWriterLevel(p, 9) - case "lzma": - z, err = lzma.NewWriter(p) - case "xz": - z, err = xz.NewWriter(p) - default: - err = fmt.Errorf("unknown compressor type %s", m.Compressor) - } + + z, compressorName, err := setupCompressor(m.Compressor, p) if err != nil { - return nil, errors.Wrap(err, "failed to create compression writer") + return nil, err } + // only use compressor name for the rpm tag, not the level + m.Compressor = compressorName + rpm := &RPM{ RPMMetaData: m, di: newDirIndex(), @@ -149,6 +143,73 @@ func NewRPM(m RPMMetaData) (*RPM, error) { return rpm, nil } +func setupCompressor(compressorSetting string, w io.Writer) (wc io.WriteCloser, + compressorType string, err error) { + + parts := strings.Split(compressorSetting, ":") + if len(parts) > 2 { + return nil, "", fmt.Errorf("malformed compressor setting: %s", compressorSetting) + } + + compressorType = parts[0] + compressorLevel := "" + if len(parts) == 2 { + compressorLevel = parts[1] + } + + switch compressorType { + case "": + compressorType = "gzip" + fallthrough + case "gzip": + level := 9 + + if compressorLevel != "" { + var err error + + level, err = strconv.Atoi(compressorLevel) + if err != nil { + return nil, "", fmt.Errorf("parse gzip compressor level: %w", err) + } + } + + wc, err = gzip.NewWriterLevel(w, level) + case "lzma": + if compressorLevel != "" { + return nil, "", fmt.Errorf("no compressor level supported for lzma: %s", compressorLevel) + } + + wc, err = lzma.NewWriter(w) + case "xz": + if compressorLevel != "" { + return nil, "", fmt.Errorf("no compressor level supported for xz: %s", compressorLevel) + } + + wc, err = xz.NewWriter(w) + case "zstd": + level := zstd.SpeedBetterCompression + + if compressorLevel != "" { + var ok bool + + if intLevel, err := strconv.Atoi(compressorLevel); err == nil { + level = zstd.EncoderLevelFromZstd(intLevel) + } else { + ok, level = zstd.EncoderLevelFromString(compressorLevel) + if !ok { + return nil, "", fmt.Errorf("invalid zstd compressor level: %s", compressorLevel) + } + } + } + + wc, err = zstd.NewWriter(w, zstd.WithEncoderLevel(level)) + default: + return nil, "", fmt.Errorf("unknown compressor type: %s", compressorType) + } + + return wc, compressorType, err +} + // FullVersion properly combines version and release fields to a version string func (r *RPM) FullVersion() string { if r.Release != "" { @@ -218,7 +279,7 @@ func (r *RPM) Write(w io.Writer) error { if _, err := w.Write(sb); err != nil { return errors.Wrap(err, "failed to write signature bytes") } - //Signatures are padded to 8-byte boundaries + // Signatures are padded to 8-byte boundaries if _, err := w.Write(make([]byte, (8-len(sb)%8)%8)); err != nil { return errors.Wrap(err, "failed to write signature padding") } @@ -227,7 +288,6 @@ func (r *RPM) Write(w io.Writer) error { } _, err = w.Write(r.payload.Bytes()) return errors.Wrap(err, "failed to write payload") - } // SetPGPSigner registers a function that will accept the header and payload as bytes, diff --git a/rpm_test.go b/rpm_test.go index d8e27b3..16d54d0 100644 --- a/rpm_test.go +++ b/rpm_test.go @@ -1,8 +1,15 @@ package rpmpack import ( + "compress/gzip" + "io" "io/ioutil" + "reflect" "testing" + + "github.com/klauspost/compress/zstd" + "github.com/ulikunitz/xz" + "github.com/ulikunitz/xz/lzma" ) func TestFileOwner(t *testing.T) { @@ -52,5 +59,97 @@ func Test100644(t *testing.T) { if r.filelinktos[0] != "" { t.Errorf("linktos want empty (not a symlink), got %q", r.filelinktos[0]) } +} + +func TestCompression(t *testing.T) { + testCases := []struct { + Type string + Compressors []string + ExpectedWriter io.Writer + }{ + { + Type: "gzip", + Compressors: []string{ + "", "gzip", "gzip:1", "gzip:2", "gzip:3", + "gzip:4", "gzip:5", "gzip:6", "gzip:7", "gzip:8", "gzip:9", + }, + ExpectedWriter: &gzip.Writer{}, + }, + { + Type: "gzip", + Compressors: []string{"gzip:fast", "gzip:10"}, + ExpectedWriter: nil, // gzip requires an integer level from -2 to 9 + }, + { + Type: "lzma", + Compressors: []string{"lzma"}, + ExpectedWriter: &lzma.Writer{}, + }, + { + Type: "lzma", + Compressors: []string{"lzma:fast", "lzma:1"}, + ExpectedWriter: nil, // lzma does not support specifying the compression level + }, + { + Type: "xz", + Compressors: []string{"xz"}, + ExpectedWriter: &xz.Writer{}, + }, + { + Type: "xz", + Compressors: []string{"xz:fast", "xz:1"}, + ExpectedWriter: nil, // xz does not support specifying the compression level + }, + { + Type: "zstd", + Compressors: []string{ + "zstd", "zstd:fastest", "zstd:default", "zstd:better", + "zstd:best", "zstd:BeSt", "zstd:0", "zstd:4", "zstd:8", "zstd:15", + }, + ExpectedWriter: &zstd.Encoder{}, + }, + { + Type: "zstd", + Compressors: []string{"xz:worst"}, + ExpectedWriter: nil, // only integers levels or one of the pre-defined string values + }, + } + + for _, testCase := range testCases { + testCase := testCase + + for _, compressor := range testCase.Compressors { + t.Run(compressor, func(t *testing.T) { + r, err := NewRPM(RPMMetaData{ + Compressor: compressor, + }) + if err != nil { + if testCase.ExpectedWriter == nil { + return // an error is expected + } + t.Fatalf("NewRPM returned error %v", err) + } + + if testCase.ExpectedWriter == nil { + t.Fatalf("compressor %q should have produced an error", compressor) + } + + if r.RPMMetaData.Compressor != testCase.Type { + t.Fatalf("expected compressor %q, got %q", compressor, + r.RPMMetaData.Compressor) + } + + expectedWriterType := reflect.Indirect(reflect.ValueOf( + testCase.ExpectedWriter)).String() + actualWriterType := reflect.Indirect(reflect.ValueOf( + r.compressedPayload)).String() + + if expectedWriterType != actualWriterType { + t.Fatalf("expected writer to be %T, got %T instead", + testCase.ExpectedWriter, r.compressedPayload) + } + }) + } + } }