diff --git a/cmd/bootstrap/main.go b/cmd/bootstrap/main.go index 3445a86..516b6cd 100644 --- a/cmd/bootstrap/main.go +++ b/cmd/bootstrap/main.go @@ -99,7 +99,7 @@ func run(ctx context.Context, cfg *config.Config, logger *zap.Logger) error { strings.TrimSpace(cfg.Chain.ChainData.ChainID), cfg.Chain.ChainData.Database, ) - result, err := snapDl.SyncIfNeeded(ctx, cfg.Chain.SnapshotURL, chainDest) + result, err := snapDl.SyncIfNeeded(ctx, cfg.Chain.SnapshotURL, chainDest, strings.TrimSpace(cfg.Chain.ChainData.ChainID)) if err != nil { return fmt.Errorf("chain snapshot: %w", err) } @@ -110,7 +110,7 @@ func run(ctx context.Context, cfg *config.Config, logger *zap.Logger) error { strings.TrimSpace(cfg.RelayChain.ChainData.ChainID), cfg.RelayChain.ChainData.Database, ) - result, err := snapDl.SyncIfNeeded(ctx, cfg.RelayChain.SnapshotURL, relayDest) + result, err := snapDl.SyncIfNeeded(ctx, cfg.RelayChain.SnapshotURL, relayDest, strings.TrimSpace(cfg.RelayChain.ChainData.ChainID)) if err != nil { return fmt.Errorf("relay chain snapshot: %w", err) } diff --git a/internal/config/config.go b/internal/config/config.go index 57f20f1..4777993 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,7 +29,7 @@ func RelayChainDataPath() string { return filepath.Join(DataDir(), "relaychain-d func ChainspecPath() string { return filepath.Join(ChainDataPath(), "chainspec.json") } func RelayChainspecPath() string { return filepath.Join(RelayChainDataPath(), "chainspec.json") } -// DatabaseStorageDir returns the per-chain database directory name under chains/<chain_id>/. +// DatabaseStorageDir returns the per-chain database directory name under chains/<SubstrateChainsDirName(chain_id)>/. // Matches Parity helm node.databasePath: paritydb -> "paritydb", else "db" (rocksdb). 
func DatabaseStorageDir(database string) string { if strings.EqualFold(strings.TrimSpace(database), "paritydb") { @@ -38,15 +38,21 @@ func DatabaseStorageDir(database string) string { return "db" } +// SubstrateChainsDirName returns the directory segment under base-path/chains/ for chain_id. +// Substrate normalizes hyphens to underscores (e.g. avn-paseo-v2 -> avn_paseo_v2), matching on-disk layout. +func SubstrateChainsDirName(chainID string) string { + return strings.ReplaceAll(chainID, "-", "_") +} + // ChainDBDataPath returns the chain snapshot / DB path: -// base-path/chains/<chain_id>/<db|paritydb>/ +// base-path/chains/<SubstrateChainsDirName(chain_id)>/<db|paritydb>/ func ChainDBDataPath(chainID, database string) string { - return filepath.Join(ChainDataPath(), "chains", chainID, DatabaseStorageDir(database)) + return filepath.Join(ChainDataPath(), "chains", SubstrateChainsDirName(chainID), DatabaseStorageDir(database)) } // RelayChainDBDataPath returns the relay chain snapshot / DB path. func RelayChainDBDataPath(chainID, database string) string { - return filepath.Join(RelayChainDataPath(), "chains", chainID, DatabaseStorageDir(database)) + return filepath.Join(RelayChainDataPath(), "chains", SubstrateChainsDirName(chainID), DatabaseStorageDir(database)) } func KeystorePath() string { return filepath.Join(DataDir(), "keystore") } @@ -73,7 +79,7 @@ type NodeConfig struct { // ChainDataConfig mirrors Parity node chart chainData (database backend + chain id directory segment). type ChainDataConfig struct { Database string `yaml:"database"` // rocksdb (default) or paritydb - ChainID string `yaml:"chain_id"` // Substrate chains/<chain_id>/ segment; not a full path + ChainID string `yaml:"chain_id"` // Logical chain id; on disk under chains/ uses SubstrateChainsDirName (hyphens -> underscores) } // ChainConfig holds chain-specific settings. 
diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 35760dc..5cfb37e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -125,6 +125,7 @@ func TestFixedPaths(t *testing.T) { assert.Equal(t, "paritydb", DatabaseStorageDir("paritydb")) assert.Equal(t, "db", DatabaseStorageDir("rocksdb")) assert.Equal(t, "/data/chain-data/chains/avn_staging_dev_testnet/paritydb", ChainDBDataPath("avn_staging_dev_testnet", "paritydb")) + assert.Equal(t, "/data/chain-data/chains/avn_paseo_v2/db", ChainDBDataPath("avn-paseo-v2", "rocksdb")) assert.Equal(t, "/data/chain-data/chains/foo/db", ChainDBDataPath("foo", "rocksdb")) assert.Equal(t, "/data/relaychain-data/chains/paseo/paritydb", RelayChainDBDataPath("paseo", "paritydb")) assert.Equal(t, "/data/keystore", KeystorePath()) diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 122f6ef..008d17f 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "os/exec" + "path" "path/filepath" "regexp" "runtime" @@ -18,6 +19,7 @@ import ( "strings" "time" + "github.com/AventusDAO/substrate-bootstrap/internal/config" "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" "github.com/ulikunitz/xz" @@ -68,7 +70,8 @@ type SyncResult struct { // - All other URLs (e.g. Polkadot snapshots.polkadot.io) use rclone with a files.txt manifest. // // For Polkadot-style base URLs (no version suffix), fetches latest_version.meta.txt to resolve latest snapshot. -func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath string) (*SyncResult, error) { +// configChainID is the YAML chain_id; used for tar member path rewriting (may differ from the on-disk chains/ segment). 
+func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath, configChainID string) (*SyncResult, error) { if snapshotURL == "" { return nil, nil } @@ -101,7 +104,7 @@ func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath str if isTarURL(resolvedURL) { result.Method = "tar" - err = d.downloadAndExtractTar(ctx, resolvedURL, dataPath) + err = d.downloadAndExtractTar(ctx, resolvedURL, dataPath, configChainID) } else { result.Method = "rclone" err = d.downloadWithRclone(ctx, resolvedURL, dataPath) @@ -314,7 +317,7 @@ func (d *Downloader) downloadWithRclone(ctx context.Context, snapshotURL, destPa return nil } -func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath string) error { +func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath, configChainID string) error { d.logger.Info("streaming snapshot via tar", zap.String("dest", destPath)) @@ -346,7 +349,7 @@ func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath st } defer decomp.Close() - if err := extractTarSecure(destPath, decomp); err != nil { + if err := extractTarSecure(destPath, decomp, configChainID); err != nil { return fmt.Errorf("extracting snapshot: %w", err) } @@ -417,7 +420,87 @@ func newDecompressorMaybeCompressedTar(r io.Reader) (io.ReadCloser, error) { return io.NopCloser(br), nil } -func extractTarSecure(destPath string, r io.Reader) error { +// mapTarEntryToDestRel strips a leading chains/<chain>/<db|paritydb>/ when member paths mirror the +// snapshot layout. destPath uses config.SubstrateChainsDirName; configChainID is the YAML chain_id +// so hyphenated archive paths (chains/avn-paseo-v2/...) still match. +func mapTarEntryToDestRel(archiveName, destPath, configChainID string) (rel string, mapped bool) { + storageDir := filepath.Base(destPath) + chainDir := filepath.Dir(destPath) + chainSeg := filepath.Base(chainDir) + + norm := path.Clean(strings.ReplaceAll(archiveName, `\`, `/`)) + if norm == "." 
|| norm == "/" { + return archiveName, false + } + + seen := make(map[string]struct{}) + var candidates []string + add := func(s string) { + if _, ok := seen[s]; ok { + return + } + seen[s] = struct{}{} + candidates = append(candidates, s) + } + + add(path.Join("chains", chainSeg, storageDir)) + if configChainID != "" { + add(path.Join("chains", configChainID, storageDir)) + normalized := config.SubstrateChainsDirName(configChainID) + if normalized != configChainID { + add(path.Join("chains", normalized, storageDir)) + } + } + + for _, c := range candidates { + if norm == c { + return ".", true + } + prefix := c + "/" + if strings.HasPrefix(norm, prefix) { + return norm[len(prefix):], true + } + } + return archiveName, false +} + +func normalizeMappedTarRel(rel string) string { + if rel == "." { + return "." + } + return filepath.FromSlash(path.Clean(rel)) +} + +func validateTarRelUnderDest(destPath, rel string) error { + if filepath.IsAbs(rel) { + return fmt.Errorf("rejecting absolute path in archive: %s", rel) + } + clean := filepath.Clean(rel) + if clean == "." { + return nil + } + target := filepath.Join(destPath, clean) + out, err := filepath.Rel(destPath, target) + if err != nil || strings.HasPrefix(out, "..") || out == ".." { + return fmt.Errorf("rejecting path outside destination: %s", rel) + } + return nil +} + +func validateSymlinkTargetForRel(destPath, linkRel, linkname string) error { + if filepath.IsAbs(linkname) { + return fmt.Errorf("rejecting absolute symlink target: %s -> %s", linkRel, linkname) + } + linkDir := filepath.Dir(filepath.Join(destPath, linkRel)) + resolved := filepath.Clean(filepath.Join(linkDir, linkname)) + rel, err := filepath.Rel(destPath, resolved) + if err != nil || strings.HasPrefix(rel, "..") || rel == ".." 
{ + return fmt.Errorf("rejecting symlink escaping destination: %s -> %s", linkRel, linkname) + } + return nil +} + +func extractTarSecure(destPath string, r io.Reader, configChainID string) error { absDest, err := filepath.Abs(destPath) if err != nil { return fmt.Errorf("resolving dest path: %w", err) @@ -433,11 +516,21 @@ func extractTarSecure(destPath string, r io.Reader) error { return fmt.Errorf("reading tar: %w", err) } - if err := validateTarPath(absDest, hdr); err != nil { - return err + mappedRel, mapped := mapTarEntryToDestRel(hdr.Name, absDest, configChainID) + var joinRel string + if mapped { + joinRel = normalizeMappedTarRel(mappedRel) + if err := validateTarRelUnderDest(absDest, joinRel); err != nil { + return err + } + } else { + if err := validateTarPath(absDest, hdr); err != nil { + return err + } + joinRel = filepath.Clean(hdr.Name) } - target := filepath.Join(absDest, filepath.Clean(hdr.Name)) + target := filepath.Join(absDest, joinRel) switch hdr.Typeflag { case tar.TypeDir: if err := os.MkdirAll(target, 0o750); err != nil { @@ -451,8 +544,14 @@ func extractTarSecure(destPath string, r io.Reader) error { return err } case tar.TypeSymlink: - if err := validateSymlinkTarget(absDest, hdr.Name, hdr.Linkname); err != nil { - return err + if mapped { + if err := validateSymlinkTargetForRel(absDest, joinRel, hdr.Linkname); err != nil { + return err + } + } else { + if err := validateSymlinkTarget(absDest, hdr.Name, hdr.Linkname); err != nil { + return err + } } if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil { return fmt.Errorf("creating parent for symlink %s: %w", hdr.Name, err) diff --git a/internal/snapshot/snapshot_test.go b/internal/snapshot/snapshot_test.go index 95a3d20..5bb4870 100644 --- a/internal/snapshot/snapshot_test.go +++ b/internal/snapshot/snapshot_test.go @@ -20,6 +20,53 @@ import ( "go.uber.org/zap" ) +func TestMapTarEntryToDestRel_UnderscoreChainInArchive(t *testing.T) { + dest := filepath.Join("/data", "chain-data", 
"chains", "avn_paseo_v2", "db") + cfgID := "avn-paseo-v2" + + rel, mapped := mapTarEntryToDestRel("chains/avn_paseo_v2/db/rocksdb/MANIFEST", dest, cfgID) + require.True(t, mapped) + assert.Equal(t, "rocksdb/MANIFEST", rel) + + rel, mapped = mapTarEntryToDestRel("chains/avn_paseo_v2/db", dest, cfgID) + require.True(t, mapped) + assert.Equal(t, ".", rel) + + rel, mapped = mapTarEntryToDestRel("chains/avn-paseo-v2/db/sub/file", dest, cfgID) + require.True(t, mapped) + assert.Equal(t, "sub/file", rel) +} + +func TestMapTarEntryToDestRel_UnchangedWhenNoPrefixMatch(t *testing.T) { + dest := filepath.Join(t.TempDir(), "chains", "mychain", "db") + name := "probe.txt" + rel, mapped := mapTarEntryToDestRel(name, dest, "") + assert.False(t, mapped) + assert.Equal(t, name, rel) +} + +func TestSyncIfNeeded_TarHyphenChainIDUnderscorePathsInArchive(t *testing.T) { + d := testDownloader(t) + chainRoot := filepath.Join(t.TempDir(), "chain-data", "chains", "avn_paseo_v2", "db") + _ = os.MkdirAll(chainRoot, 0o750) + + files := map[string]string{ + "chains/avn_paseo_v2/db/ready.txt": "ok", + } + server := createTarGzServer(t, files) + defer server.Close() + + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snapshot.tar.gz", chainRoot, "avn-paseo-v2") + require.NoError(t, err) + assert.True(t, result.Downloaded) + + data, err := os.ReadFile(filepath.Join(chainRoot, "ready.txt")) + require.NoError(t, err) + assert.Equal(t, "ok", string(data)) + _, err = os.Stat(filepath.Join(chainRoot, "chains")) + require.Error(t, err, "archive prefix must not create a nested chains/ under db") +} + func testDownloader(t *testing.T) *Downloader { t.Helper() logger, err := zap.NewDevelopment() @@ -51,7 +98,7 @@ func createTarGzServer(t *testing.T, files map[string]string) *httptest.Server { func TestSyncIfNeeded_EmptyURL(t *testing.T) { d := testDownloader(t) - result, err := d.SyncIfNeeded(context.Background(), "", "/any/path") + result, err := d.SyncIfNeeded(context.Background(), "", 
"/any/path", "") require.NoError(t, err) assert.Nil(t, result) } @@ -64,7 +111,7 @@ func TestSyncIfNeeded_AlreadyHasData(t *testing.T) { server := createTarGzServer(t, map[string]string{"test.txt": "should not download"}) defer server.Close() - result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir, "") require.NoError(t, err) assert.True(t, result.Skipped) assert.False(t, result.Downloaded) @@ -84,7 +131,7 @@ func TestSyncIfNeeded_DownloadsAndExtracts(t *testing.T) { }) defer server.Close() - result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir, "") require.NoError(t, err) assert.True(t, result.Downloaded) assert.False(t, result.Skipped) @@ -107,7 +154,7 @@ func TestSyncIfNeeded_TarExtensionDetectsGzip(t *testing.T) { server := createTarGzServer(t, map[string]string{"db/metadata": "gzipped"}) defer server.Close() - result, err := d.SyncIfNeeded(context.Background(), server.URL+"/data.tar", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/data.tar", dir, "") require.NoError(t, err) assert.True(t, result.Downloaded) assert.Equal(t, "tar", result.Method) @@ -190,7 +237,7 @@ func TestSyncIfNeeded_TarExtensionDetectsCompressedMagic(t *testing.T) { d := testDownloader(t) dir := filepath.Join(t.TempDir(), "chaindata") - result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snapshot.tar", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snapshot.tar", dir, "") require.NoError(t, err) assert.True(t, result.Downloaded) assert.Equal(t, "tar", result.Method) @@ -226,7 +273,7 @@ func TestSyncIfNeeded_TarExtensionUncompressedTar(t *testing.T) { server := createRawTarServer(t, map[string]string{"plain.txt": "raw"}) defer server.Close() - result, err := d.SyncIfNeeded(context.Background(), 
server.URL+"/snapshot.tar", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snapshot.tar", dir, "") require.NoError(t, err) assert.True(t, result.Downloaded) assert.Equal(t, "tar", result.Method) @@ -243,7 +290,7 @@ func TestSyncIfNeeded_NonexistentDir(t *testing.T) { server := createTarGzServer(t, map[string]string{"data.txt": "hello"}) defer server.Close() - result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir) + result, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir, "") require.NoError(t, err) assert.True(t, result.Downloaded) @@ -261,7 +308,7 @@ func TestSyncIfNeeded_HTTPError(t *testing.T) { })) defer server.Close() - _, err := d.SyncIfNeeded(context.Background(), server.URL+"/missing.tar.gz", dir) + _, err := d.SyncIfNeeded(context.Background(), server.URL+"/missing.tar.gz", dir, "") require.Error(t, err) assert.Contains(t, err.Error(), "status 404") } @@ -270,7 +317,7 @@ func TestSyncIfNeeded_InvalidURL(t *testing.T) { d := testDownloader(t) dir := filepath.Join(t.TempDir(), "chaindata") - _, err := d.SyncIfNeeded(context.Background(), "http://localhost:1/definitely-not-running.tar.gz", dir) + _, err := d.SyncIfNeeded(context.Background(), "http://localhost:1/definitely-not-running.tar.gz", dir, "") require.Error(t, err) } @@ -286,7 +333,7 @@ func TestSyncIfNeeded_ContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := d.SyncIfNeeded(ctx, server.URL+"/snap.tar.gz", dir) + _, err := d.SyncIfNeeded(ctx, server.URL+"/snap.tar.gz", dir, "") require.Error(t, err) } @@ -498,7 +545,7 @@ func TestExtractTarSecure_RejectsPathTraversal(t *testing.T) { defer malicious.Close() d := testDownloader(t) - _, err := d.SyncIfNeeded(context.Background(), malicious.URL+"/evil.tar.gz", dir) + _, err := d.SyncIfNeeded(context.Background(), malicious.URL+"/evil.tar.gz", dir, "") require.Error(t, err) assert.Contains(t, err.Error(), "rejecting 
path outside destination") } @@ -517,7 +564,7 @@ func TestExtractTarSecure_AcceptsSymlinkWithDoubleDotsInName(t *testing.T) { defer server.Close() d := testDownloader(t) - _, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir) + _, err := d.SyncIfNeeded(context.Background(), server.URL+"/snap.tar.gz", dir, "") require.NoError(t, err) // Symlink should exist (target "foo..bar" resolves within dir, so it's valid) @@ -542,7 +589,7 @@ func TestExtractTarSecure_RejectsAbsolutePath(t *testing.T) { defer server.Close() d := testDownloader(t) - _, err := d.SyncIfNeeded(context.Background(), server.URL+"/evil.tar.gz", dir) + _, err := d.SyncIfNeeded(context.Background(), server.URL+"/evil.tar.gz", dir, "") require.Error(t, err) assert.Contains(t, err.Error(), "absolute path") }