From 08e024eb7dab621d1a37d5e55db3c2a0b8b08500 Mon Sep 17 00:00:00 2001 From: Shashwat Hiregoudar Date: Wed, 10 Jun 2026 00:16:46 +0530 Subject: [PATCH] feat: add recursive download and sequential fallback - Implemented recursive crawling logic in pget.go and crawler.go. - Added --recursive and --level flags to Options. - Added fallback sequential download for non-resumable or unknown-size files. - Improved path resolution to mirror directory structures and handle conflicts. - Fixed nil pointer panic in resume logic. - Updated README and added unit tests for the crawler. --- README.md | 15 +++++- crawler.go | 68 ++++++++++++++++++++++++++ crawler_test.go | 52 ++++++++++++++++++++ download.go | 57 +++++++++++++++++++--- go.mod | 8 ++-- go.sum | 6 +++ option.go | 10 ++-- pget.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++++ requests.go | 14 +++--- util.go | 7 ++- 10 files changed, 337 insertions(+), 25 deletions(-) create mode 100644 crawler.go create mode 100644 crawler_test.go diff --git a/README.md b/README.md index 03f2f5e..023472a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Multi-Connection Download using parallel requests. - Fast - Resumable +- Recursive (mirroring directory structure) - Cross-compiled (windows, linux, macOS) This is an example to download [linux kernel](https://www.kernel.org/). It will be finished between 15s. @@ -59,6 +60,16 @@ You can do this cat list.txt | pget -p 2 +### Recursive Download + +Download a directory recursively (similar to `wget -r`). + + $ pget -r http://example.com/dir/ + +You can also specify the recursion depth (default is 5). + + $ pget -r -l 2 http://example.com/dir/ + ## Options ``` @@ -68,7 +79,9 @@ You can do this -o, --output output file to -t, --timeout timeout of checking request in seconds -u, --user-agent identify as - -r, --referer identify as + --referer identify as + -r, --recursive recursive download + -l, --level maximum recursion depth (default 5) --check-update check if there is update available --trace display detail error messages ``` diff --git a/crawler.go b/crawler.go new file mode 100644 index 0000000..295ca20 --- /dev/null +++ b/crawler.go @@ -0,0 +1,68 @@ +package pget + +import ( + "io" + "net/url" + "strings" + + "golang.org/x/net/html" +) + +// extractLinks parses HTML and returns a list of unique absolute URLs found in href and src attributes. +func extractLinks(r io.Reader, baseURL *url.URL) ([]string, error) { + links := make(map[string]struct{}) + tokenizer := html.NewTokenizer(r) + + for { + tokenType := tokenizer.Next() + if tokenType == html.ErrorToken { + err := tokenizer.Err() + if err == io.EOF { + break + } + return nil, err + } + + token := tokenizer.Token() + if tokenType == html.StartTagToken || tokenType == html.SelfClosingTagToken { + for _, attr := range token.Attr { + if attr.Key == "href" || attr.Key == "src" { + link := strings.TrimSpace(attr.Val) + if link == "" || strings.HasPrefix(link, "#") { + continue + } + + absURL := resolveURL(baseURL, link) + if absURL != "" { + links[absURL] = struct{}{} + } + } + } + } + } + + result := make([]string, 0, len(links)) + for link := range links { + result = append(result, link) + } + return result, nil +} + +// resolveURL converts a relative URL to an absolute one based on the base URL. +func resolveURL(base *url.URL, relative string) string { + u, err := url.Parse(relative) + if err != nil { + return "" + } + return base.ResolveReference(u).String() +} + +// shouldCrawl checks if a URL should be crawled based on the initial base URL. +// It ensures we stay on the same domain and under the same path. +func shouldCrawl(baseURL, targetURL *url.URL) bool { + if baseURL.Host != targetURL.Host { + return false + } + // Ensure target is under the base path + return strings.HasPrefix(targetURL.Path, baseURL.Path) +} diff --git a/crawler_test.go b/crawler_test.go new file mode 100644 index 0000000..6a54d52 --- /dev/null +++ b/crawler_test.go @@ -0,0 +1,52 @@ +package pget + +import ( + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractLinks(t *testing.T) { + htmlContent := ` + + + Link 1 + External Link + + + Anchor + Empty + + + ` + baseURL, _ := url.Parse("http://example.com/start/") + links, err := extractLinks(strings.NewReader(htmlContent), baseURL) + + assert.NoError(t, err) + assert.Contains(t, links, "http://example.com/page1") + assert.Contains(t, links, "http://other.com/page2") + assert.Contains(t, links, "http://example.com/start/images/logo.png") + assert.Contains(t, links, "http://example.com/js/app.js") + assert.Len(t, links, 4) +} + +func TestShouldCrawl(t *testing.T) { + baseURL, _ := url.Parse("http://example.com/dir/") + + tests := []struct { + target string + expected bool + }{ + {"http://example.com/dir/file.html", true}, + {"http://example.com/dir/subdir/file.html", true}, + {"http://example.com/otherdir/file.html", false}, + {"http://other.com/dir/file.html", false}, + } + + for _, tc := range tests { + targetURL, _ := url.Parse(tc.target) + assert.Equal(t, tc.expected, shouldCrawl(baseURL, targetURL), tc.target) + } +} diff --git a/download.go b/download.go index 1e8fdeb..4124fba 100644 --- a/download.go +++ b/download.go @@ -139,19 +139,19 @@ func WithReferer(referer string) DownloadOption { } func Download(ctx context.Context, c *DownloadConfig, opts ...DownloadOption) error { - partialDir := getPartialDirname(c.Dirname, c.Filename, c.Procs) - - // create download location - if err := os.MkdirAll(partialDir, 0755); err != nil { - return errors.Wrap(err, "failed to mkdir for download location") - } - c.makeRequestOption = &makeRequestOption{} for _, opt := range opts { opt(c) } + // Fallback to sequential download if ContentLength is unknown or only 1 proc is requested + if c.ContentLength <= 0 || c.Procs <= 1 { + return sequentialDownload(ctx, c) + } + + partialDir := getPartialDirname(c.Dirname, c.Filename, c.Procs) + tasks := assignTasks(&assignTasksConfig{ Procs: c.Procs, TaskSize: c.ContentLength / int64(c.Procs), @@ -281,3 +281,46 @@ func bindFiles(c *DownloadConfig, partialDir string) error { return nil } + +func sequentialDownload(ctx context.Context, c *DownloadConfig) error { + destPath := filepath.Join(c.Dirname, c.Filename) + output, err := os.Create(destPath) + if err != nil { + return errors.Wrapf(err, "failed to create: %s", destPath) + } + defer output.Close() + + req, err := http.NewRequest("GET", c.URLs[0], nil) + if err != nil { + return errors.Wrap(err, "failed to make request") + } + req = req.WithContext(ctx) + + // set useragent + req.Header.Set("User-Agent", c.makeRequestOption.useragent) + + // set referer + if c.makeRequestOption.referer != "" { + req.Header.Set("Referer", c.makeRequestOption.referer) + } + + resp, err := c.Client.Do(req) + if err != nil { + return errors.Wrap(err, "failed to get response") + } + defer resp.Body.Close() + + var bar *pb.ProgressBar + var rd io.Reader = resp.Body + if c.ContentLength > 0 { + bar = pb.Start64(c.ContentLength).SetWriter(stdout).Set(pb.Bytes, true) + defer bar.Finish() + rd = bar.NewProxyReader(resp.Body) + } + + if _, err := io.Copy(output, rd); err != nil { + return errors.Wrap(err, "failed to write response body") + } + + return nil +} diff --git a/go.mod b/go.mod index a895b7d..f14e498 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/Code-Hex/pget -go 1.21 +go 1.25.0 require ( github.com/Code-Hex/updater v0.0.0-20160712085121-c3f278672520 @@ -32,8 +32,8 @@ require ( github.com/rivo/uniseg v0.2.0 // indirect github.com/ulikunitz/xz v0.5.8 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect - golang.org/x/net v0.0.0-20161013035702-8b4af36cd21a // indirect - golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 // indirect - golang.org/x/term v0.0.0-20210317153231-de623e64d2a6 // indirect + golang.org/x/net v0.55.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/term v0.43.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index acf64cb..896d7fe 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofm github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= golang.org/x/net v0.0.0-20161013035702-8b4af36cd21a h1:YEFEcqrj8fWeC0px2Ha5IrK20xodii3wn+N+jzuFRKQ= golang.org/x/net v0.0.0-20161013035702-8b4af36cd21a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.0.0-20161004233620-1ae7c7b29e06 h1:pRVhPB331E/b1+A7Y9d/3ZkgE5LNxnP/q5ChiqPf79Q= golang.org/x/sync v0.0.0-20161004233620-1ae7c7b29e06/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -71,8 +73,12 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210319071255-635bc2c9138d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20210317153231-de623e64d2a6 h1:EC6+IGYTjPpRfv9a2b/6Puw0W+hLtAhkV1tPsXhutqs= golang.org/x/term v0.0.0-20210317153231-de623e64d2a6/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/option.go b/option.go index c96bb40..d6ef3df 100644 --- a/option.go +++ b/option.go @@ -16,9 +16,11 @@ type Options struct { Output string `short:"o" long:"output"` Timeout int `short:"t" long:"timeout" default:"10"` UserAgent string `short:"u" long:"user-agent"` - Referer string `short:"r" long:"referer"` + Referer string `long:"referer"` Update bool `long:"check-update"` Trace bool `long:"trace"` + Recursive bool `short:"r" long:"recursive"` + Level int `short:"l" long:"level" default:"5"` } func (opts *Options) parse(argv []string, version string) ([]string, error) { @@ -45,7 +47,9 @@ func (opts Options) usage(version string) []byte { -o, --output output file to -t, --timeout timeout of checking request in seconds (default 10s) -u, --user-agent identify as - -r, --referer identify as + --referer identify as + -r, --recursive recursive download + -l, --level maximum recursion depth (default 5) --check-update check if there is update available --trace display detail error messages `, version) @@ -58,7 +62,7 @@ func (opts Options) isupdate(version string) ([]byte, error) { if err != nil { return nil, err } - fmt.Fprintf(&buf, result+"\n") + fmt.Fprintf(&buf, "%s\n", result) return buf.Bytes(), nil } diff --git a/pget.go b/pget.go index b5e53a0..e8d8fb8 100644 --- a/pget.go +++ b/pget.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net/url" "os" "path/filepath" "runtime" @@ -26,6 +27,8 @@ type Pget struct { timeout int useragent string referer string + recursive bool + level int } // New for pget package @@ -43,6 +46,10 @@ func (pget *Pget) Run(ctx context.Context, version string, args []string) error return errTop(err) } + if pget.recursive { + return pget.RunRecursive(ctx, version) + } + // TODO(codehex): calc maxIdleConnsPerHost client := newDownloadClient(16) @@ -133,6 +140,9 @@ func (pget *Pget) Ready(version string, args []string) error { pget.referer = opts.Referer } + pget.recursive = opts.Recursive + pget.level = opts.Level + return nil } @@ -204,3 +214,118 @@ func (pget *Pget) parseURLs() error { return nil } + +type crawlTarget struct { + url *url.URL + depth int +} + +func (pget *Pget) RunRecursive(ctx context.Context, version string) error { + if len(pget.URLs) == 0 { + return errors.New("URL is required") + } + + client := newDownloadClient(16) + + // We only support one base URL for recursive download for now + startURLStr := pget.URLs[0] + baseURL, err := url.Parse(startURLStr) + if err != nil { + return errors.Wrap(err, "failed to parse start URL") + } + + queue := []crawlTarget{{url: baseURL, depth: 0}} + visited := make(map[string]bool) + + for len(queue) > 0 { + target := queue[0] + queue = queue[1:] + + uStr := target.url.String() + if visited[uStr] { + continue + } + visited[uStr] = true + + if target.depth > pget.level { + continue + } + + fmt.Fprintf(stdout, "Crawling: %s (depth: %d)\n", uStr, target.depth) + + // 1. Check if it's a file or HTML + checkTarget, err := Check(ctx, &CheckConfig{ + URLs: []string{uStr}, + Timeout: time.Duration(pget.timeout) * time.Second, + Client: client, + }) + if err != nil { + fmt.Fprintf(stdout, "Warning: failed to check %s: %v\n", uStr, err) + continue + } + + // 2. Prepare local path + // We want to mirror the structure. If pget.Output is set, we use it as the root. + root := pget.Output + if root == "" { + root = "." + } + + localPath := filepath.Join(root, target.url.Host, target.url.Path) + // If it's a directory-like URL, use index.html + if strings.HasSuffix(target.url.Path, "/") || target.url.Path == "" || !strings.Contains(filepath.Base(target.url.Path), ".") { + localPath = filepath.Join(localPath, "index.html") + } + + dir, filename := filepath.Split(localPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return errors.Wrapf(err, "failed to create directory %s", dir) + } + + // 3. Download + opts := []DownloadOption{ + WithUserAgent(pget.useragent, version), + WithReferer(pget.referer), + } + + err = Download(ctx, &DownloadConfig{ + Filename: filename, + Dirname: dir, + ContentLength: checkTarget.ContentLength, + Procs: pget.Procs, + URLs: []string{uStr}, + Client: client, + }, opts...) + if err != nil { + fmt.Fprintf(stdout, "Warning: failed to download %s: %v\n", uStr, err) + continue + } + + // 4. If it was an HTML file, parse for links + if target.depth < pget.level { + // Read the downloaded file to parse links + f, err := os.Open(filepath.Join(dir, filename)) + if err != nil { + continue + } + + links, err := extractLinks(f, target.url) + f.Close() + if err != nil { + continue + } + + for _, link := range links { + linkURL, err := url.Parse(link) + if err != nil { + continue + } + if shouldCrawl(baseURL, linkURL) { + queue = append(queue, crawlTarget{url: linkURL, depth: target.depth + 1}) + } + } + } + } + + return nil +} diff --git a/requests.go b/requests.go index 731b310..65bb696 100644 --- a/requests.go +++ b/requests.go @@ -119,12 +119,10 @@ func getMirrorInfo(ctx context.Context, client *http.Client, url string) (*mirro return nil, errors.Wrap(err, "failed to head request") } - if resp.Header.Get("Accept-Ranges") != "bytes" { - return nil, errors.New("does not support range request") - } - - if resp.ContentLength <= 0 { - return nil, errors.New("invalid content length") + // Some servers might return 0 or -1 for Content-Length for dynamic content + contentLength := resp.ContentLength + if contentLength < 0 { + contentLength = 0 } filename := "" @@ -139,14 +137,14 @@ func getMirrorInfo(ctx context.Context, client *http.Client, url string) (*mirro if isNotLastURL(_url, url) { return &mirrorInfo{ RetrievedURL: _url, - ContentLength: resp.ContentLength, + ContentLength: contentLength, Filename: filename, }, nil } return &mirrorInfo{ RetrievedURL: url, - ContentLength: resp.ContentLength, + ContentLength: contentLength, Filename: filename, }, nil } diff --git a/util.go b/util.go index 894c9e7..d98910e 100644 --- a/util.go +++ b/util.go @@ -29,10 +29,13 @@ func checkProgress(dirname string) (int64, error) { func subDirsize(dirname string) (int64, error) { var size int64 err := filepath.Walk(dirname, func(_ string, info os.FileInfo, err error) error { - if !info.IsDir() { + if err != nil { + return err + } + if info != nil && !info.IsDir() { size += info.Size() } - return err + return nil }) return size, err