Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/bootstrap/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func run(ctx context.Context, cfg *config.Config, logger *zap.Logger) error {
strings.TrimSpace(cfg.Chain.ChainData.ChainID),
cfg.Chain.ChainData.Database,
)
result, err := snapDl.SyncIfNeeded(ctx, cfg.Chain.SnapshotURL, chainDest)
result, err := snapDl.SyncIfNeeded(ctx, cfg.Chain.SnapshotURL, chainDest, strings.TrimSpace(cfg.Chain.ChainData.ChainID))
if err != nil {
return fmt.Errorf("chain snapshot: %w", err)
}
Expand All @@ -110,7 +110,7 @@ func run(ctx context.Context, cfg *config.Config, logger *zap.Logger) error {
strings.TrimSpace(cfg.RelayChain.ChainData.ChainID),
cfg.RelayChain.ChainData.Database,
)
result, err := snapDl.SyncIfNeeded(ctx, cfg.RelayChain.SnapshotURL, relayDest)
result, err := snapDl.SyncIfNeeded(ctx, cfg.RelayChain.SnapshotURL, relayDest, strings.TrimSpace(cfg.RelayChain.ChainData.ChainID))
if err != nil {
return fmt.Errorf("relay chain snapshot: %w", err)
}
Expand Down
16 changes: 11 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func RelayChainDataPath() string { return filepath.Join(DataDir(), "relaychain-d
func ChainspecPath() string { return filepath.Join(ChainDataPath(), "chainspec.json") }
func RelayChainspecPath() string { return filepath.Join(RelayChainDataPath(), "chainspec.json") }

// DatabaseStorageDir returns the per-chain database directory name under chains/<chain_id>/.
// DatabaseStorageDir returns the per-chain database directory name under chains/<chain_dir>/.
// Matches Parity helm node.databasePath: paritydb -> "paritydb", else "db" (rocksdb).
func DatabaseStorageDir(database string) string {
if strings.EqualFold(strings.TrimSpace(database), "paritydb") {
Expand All @@ -38,15 +38,21 @@ func DatabaseStorageDir(database string) string {
return "db"
}

// SubstrateChainsDirName returns the directory segment under base-path/chains/ for chain_id.
// Substrate normalizes hyphens to underscores (e.g. avn-paseo-v2 -> avn_paseo_v2), matching on-disk layout.
func SubstrateChainsDirName(chainID string) string {
return strings.ReplaceAll(chainID, "-", "_")
}

// ChainDBDataPath returns the chain snapshot / DB path:
// base-path/chains/<chainID>/<storageDir>/
// base-path/chains/<SubstrateChainsDirName(chainID)>/<storageDir>/
func ChainDBDataPath(chainID, database string) string {
return filepath.Join(ChainDataPath(), "chains", chainID, DatabaseStorageDir(database))
return filepath.Join(ChainDataPath(), "chains", SubstrateChainsDirName(chainID), DatabaseStorageDir(database))
}

// RelayChainDBDataPath returns the relay chain snapshot / DB path.
func RelayChainDBDataPath(chainID, database string) string {
return filepath.Join(RelayChainDataPath(), "chains", chainID, DatabaseStorageDir(database))
return filepath.Join(RelayChainDataPath(), "chains", SubstrateChainsDirName(chainID), DatabaseStorageDir(database))
}

func KeystorePath() string { return filepath.Join(DataDir(), "keystore") }
Expand All @@ -73,7 +79,7 @@ type NodeConfig struct {
// ChainDataConfig mirrors Parity node chart chainData (database backend + chain id directory segment).
type ChainDataConfig struct {
Database string `yaml:"database"` // rocksdb (default) or paritydb
ChainID string `yaml:"chain_id"` // Substrate chains/<chain_id>/ segment; not a full path
ChainID string `yaml:"chain_id"` // Logical chain id; on disk under chains/ uses SubstrateChainsDirName (hyphens -> underscores)
Comment thread
vukomir marked this conversation as resolved.
}

// ChainConfig holds chain-specific settings.
Expand Down
1 change: 1 addition & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func TestFixedPaths(t *testing.T) {
assert.Equal(t, "paritydb", DatabaseStorageDir("paritydb"))
assert.Equal(t, "db", DatabaseStorageDir("rocksdb"))
assert.Equal(t, "/data/chain-data/chains/avn_staging_dev_testnet/paritydb", ChainDBDataPath("avn_staging_dev_testnet", "paritydb"))
assert.Equal(t, "/data/chain-data/chains/avn_paseo_v2/db", ChainDBDataPath("avn-paseo-v2", "rocksdb"))
assert.Equal(t, "/data/chain-data/chains/foo/db", ChainDBDataPath("foo", "rocksdb"))
assert.Equal(t, "/data/relaychain-data/chains/paseo/paritydb", RelayChainDBDataPath("paseo", "paritydb"))
assert.Equal(t, "/data/keystore", KeystorePath())
Expand Down
119 changes: 109 additions & 10 deletions internal/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"time"

"github.com/AventusDAO/substrate-bootstrap/internal/config"
"github.com/klauspost/compress/zstd"
"github.com/pierrec/lz4/v4"
"github.com/ulikunitz/xz"
Expand Down Expand Up @@ -68,7 +70,8 @@ type SyncResult struct {
// - All other URLs (e.g. Polkadot snapshots.polkadot.io) use rclone with a files.txt manifest.
//
// For Polkadot-style base URLs (no version suffix), fetches latest_version.meta.txt to resolve latest snapshot.
func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath string) (*SyncResult, error) {
// configChainID is the YAML chain_id; used for tar member path rewriting (may differ from the on-disk chains/ segment).
func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath, configChainID string) (*SyncResult, error) {
if snapshotURL == "" {
return nil, nil
}
Expand Down Expand Up @@ -101,7 +104,7 @@ func (d *Downloader) SyncIfNeeded(ctx context.Context, snapshotURL, dataPath str

if isTarURL(resolvedURL) {
result.Method = "tar"
err = d.downloadAndExtractTar(ctx, resolvedURL, dataPath)
err = d.downloadAndExtractTar(ctx, resolvedURL, dataPath, configChainID)
} else {
result.Method = "rclone"
err = d.downloadWithRclone(ctx, resolvedURL, dataPath)
Expand Down Expand Up @@ -314,7 +317,7 @@ func (d *Downloader) downloadWithRclone(ctx context.Context, snapshotURL, destPa
return nil
}

func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath string) error {
func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath, configChainID string) error {
d.logger.Info("streaming snapshot via tar",
zap.String("dest", destPath))

Expand Down Expand Up @@ -346,7 +349,7 @@ func (d *Downloader) downloadAndExtractTar(ctx context.Context, url, destPath st
}
defer decomp.Close()

if err := extractTarSecure(destPath, decomp); err != nil {
if err := extractTarSecure(destPath, decomp, configChainID); err != nil {
return fmt.Errorf("extracting snapshot: %w", err)
}

Expand Down Expand Up @@ -417,7 +420,87 @@ func newDecompressorMaybeCompressedTar(r io.Reader) (io.ReadCloser, error) {
return io.NopCloser(br), nil
}

func extractTarSecure(destPath string, r io.Reader) error {
// mapTarEntryToDestRel strips a leading chains/<id>/<db|paritydb>/ when member paths mirror the
// snapshot layout. destPath uses config.SubstrateChainsDirName; configChainID is the YAML chain_id
// so hyphenated archive paths (chains/avn-paseo-v2/...) still match.
func mapTarEntryToDestRel(archiveName, destPath, configChainID string) (rel string, mapped bool) {
storageDir := filepath.Base(destPath)
chainDir := filepath.Dir(destPath)
chainSeg := filepath.Base(chainDir)

norm := path.Clean(strings.ReplaceAll(archiveName, `\`, `/`))
if norm == "." || norm == "/" {
return archiveName, false
}

seen := make(map[string]struct{})
var candidates []string
add := func(s string) {
if _, ok := seen[s]; ok {
return
}
seen[s] = struct{}{}
candidates = append(candidates, s)
}

add(path.Join("chains", chainSeg, storageDir))
if configChainID != "" {
add(path.Join("chains", configChainID, storageDir))
normalized := config.SubstrateChainsDirName(configChainID)
if normalized != configChainID {
add(path.Join("chains", normalized, storageDir))
}
}

for _, c := range candidates {
if norm == c {
return ".", true
}
prefix := c + "/"
if strings.HasPrefix(norm, prefix) {
return norm[len(prefix):], true
}
}
return archiveName, false
}

func normalizeMappedTarRel(rel string) string {
if rel == "." {
return "."
}
return filepath.FromSlash(path.Clean(rel))
}

func validateTarRelUnderDest(destPath, rel string) error {
if filepath.IsAbs(rel) {
return fmt.Errorf("rejecting absolute path in archive: %s", rel)
}
clean := filepath.Clean(rel)
if clean == "." {
return nil
}
target := filepath.Join(destPath, clean)
out, err := filepath.Rel(destPath, target)
if err != nil || strings.HasPrefix(out, "..") || out == ".." {
return fmt.Errorf("rejecting path outside destination: %s", rel)
}
return nil
}

func validateSymlinkTargetForRel(destPath, linkRel, linkname string) error {
if filepath.IsAbs(linkname) {
return fmt.Errorf("rejecting absolute symlink target: %s -> %s", linkRel, linkname)
}
linkDir := filepath.Dir(filepath.Join(destPath, linkRel))
resolved := filepath.Clean(filepath.Join(linkDir, linkname))
rel, err := filepath.Rel(destPath, resolved)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
return fmt.Errorf("rejecting symlink escaping destination: %s -> %s", linkRel, linkname)
}
return nil
}

func extractTarSecure(destPath string, r io.Reader, configChainID string) error {
absDest, err := filepath.Abs(destPath)
if err != nil {
return fmt.Errorf("resolving dest path: %w", err)
Expand All @@ -433,11 +516,21 @@ func extractTarSecure(destPath string, r io.Reader) error {
return fmt.Errorf("reading tar: %w", err)
}

if err := validateTarPath(absDest, hdr); err != nil {
return err
mappedRel, mapped := mapTarEntryToDestRel(hdr.Name, absDest, configChainID)
var joinRel string
if mapped {
joinRel = normalizeMappedTarRel(mappedRel)
if err := validateTarRelUnderDest(absDest, joinRel); err != nil {
return err
}
} else {
if err := validateTarPath(absDest, hdr); err != nil {
return err
}
joinRel = filepath.Clean(hdr.Name)
}

target := filepath.Join(absDest, filepath.Clean(hdr.Name))
target := filepath.Join(absDest, joinRel)
switch hdr.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0o750); err != nil {
Expand All @@ -451,8 +544,14 @@ func extractTarSecure(destPath string, r io.Reader) error {
return err
}
case tar.TypeSymlink:
if err := validateSymlinkTarget(absDest, hdr.Name, hdr.Linkname); err != nil {
return err
if mapped {
if err := validateSymlinkTargetForRel(absDest, joinRel, hdr.Linkname); err != nil {
return err
}
} else {
if err := validateSymlinkTarget(absDest, hdr.Name, hdr.Linkname); err != nil {
return err
}
}
if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
return fmt.Errorf("creating parent for symlink %s: %w", hdr.Name, err)
Expand Down
Loading
Loading