From 08df9258a532b653c243e077e82491dbe62ad854 Mon Sep 17 00:00:00 2001
From: Philipp Tanlak
Date: Sat, 23 Sep 2023 17:41:57 +0200
Subject: refactor scraper into modules

---
 cmd/flyscrape/dev.go                      |  27 ++-
 cmd/flyscrape/main.go                     |   8 +
 cmd/flyscrape/run.go                      |  37 +---
 js.go                                     |  27 +--
 js_test.go                                |  26 +--
 mock.go                                   |  22 +++
 module.go                                 |  65 +++++++
 modules/depth/depth.go                    |  30 +++
 modules/depth/depth_test.go               |  47 +++++
 modules/domainfilter/domainfilter.go      |  62 ++++++
 modules/domainfilter/domainfilter_test.go |  92 +++++++++
 modules/followlinks/followlinks.go        |  30 +++
 modules/followlinks/followlinks_test.go   |  39 ++++
 modules/jsonprinter/jsonprinter.go        |  47 +++++
 modules/jsonprinter/jsonprinter_test.go   |  47 +++++
 modules/ratelimit/ratelimit.go            |  54 ++++++
 modules/ratelimit/ratelimit_test.go       |  45 +++++
 modules/starturl/starturl.go              |  30 +++
 modules/starturl/starturl_test.go         |  31 +++
 modules/urlfilter/urlfilter.go            |  85 ++++++++
 modules/urlfilter/urlfilter_test.go       |  71 +++++++
 scrape.go                                 | 310 +++++++++++++-----------
 scrape_test.go                            | 266 -------------------------
 utils.go                                  |   4 +
 24 files changed, 994 insertions(+), 508 deletions(-)
 create mode 100644 mock.go
 create mode 100644 module.go
 create mode 100644 modules/depth/depth.go
 create mode 100644 modules/depth/depth_test.go
 create mode 100644 modules/domainfilter/domainfilter.go
 create mode 100644 modules/domainfilter/domainfilter_test.go
 create mode 100644 modules/followlinks/followlinks.go
 create mode 100644 modules/followlinks/followlinks_test.go
 create mode 100644 modules/jsonprinter/jsonprinter.go
 create mode 100644 modules/jsonprinter/jsonprinter_test.go
 create mode 100644 modules/ratelimit/ratelimit.go
 create mode 100644 modules/ratelimit/ratelimit_test.go
 create mode 100644 modules/starturl/starturl.go
 create mode 100644 modules/starturl/starturl_test.go
 create mode 100644 modules/urlfilter/urlfilter.go
 create mode 100644 modules/urlfilter/urlfilter_test.go
 delete mode 100644 scrape_test.go

diff --git a/cmd/flyscrape/dev.go b/cmd/flyscrape/dev.go
index 85ac1a1..169e6d3 100644
--- a/cmd/flyscrape/dev.go
+++ b/cmd/flyscrape/dev.go
@@ -56,23 +56,22 @@ func (c *DevCommand) Run(args []string) error {
 			return nil
 		}
 
-		opts.Depth = 0
-		scr := flyscrape.Scraper{
-			ScrapeOptions: opts,
-			ScrapeFunc:    scrape,
-			FetchFunc:     fetch,
-		}
+		scraper := flyscrape.NewScraper()
+		scraper.ScrapeFunc = scrape
+		flyscrape.LoadModules(scraper, opts)
 
-		result := <-scr.Scrape()
-		screen.Clear()
-		screen.MoveTopLeft()
+		scraper.OnResponse(func(resp *flyscrape.Response) {
+			screen.Clear()
+			screen.MoveTopLeft()
 
-		if result.Error != nil {
-			log.Println(result.Error)
-			return nil
-		}
+			if resp.Error != nil {
+				log.Println(resp.Error)
+				return
+			}
+
+			fmt.Println(flyscrape.PrettyPrint(resp.ScrapeResult, ""))
+		})
+
+		scraper.Run()
 
-		fmt.Println(flyscrape.PrettyPrint(result, ""))
 		return nil
 	})
 	if err != nil && err != flyscrape.StopWatch {
diff --git a/cmd/flyscrape/main.go b/cmd/flyscrape/main.go
index 4e448bb..bac411e 100644
--- a/cmd/flyscrape/main.go
+++ b/cmd/flyscrape/main.go
@@ -11,6 +11,14 @@ import (
 	"log"
 	"os"
 	"strings"
+
+	_ "github.com/philippta/flyscrape/modules/depth"
+	_ "github.com/philippta/flyscrape/modules/domainfilter"
+	_ "github.com/philippta/flyscrape/modules/followlinks"
+	_ "github.com/philippta/flyscrape/modules/jsonprinter"
+	_ "github.com/philippta/flyscrape/modules/ratelimit"
+	_ "github.com/philippta/flyscrape/modules/starturl"
+	_ "github.com/philippta/flyscrape/modules/urlfilter"
 )
 
 func main() {
diff --git a/cmd/flyscrape/run.go b/cmd/flyscrape/run.go
index 987d0e0..22f41fd 100644
--- a/cmd/flyscrape/run.go
+++ b/cmd/flyscrape/run.go
@@ -18,8 +18,6 @@ type RunCommand struct{}
 func (c *RunCommand) Run(args []string) error {
 	fs := flag.NewFlagSet("flyscrape-run", flag.ContinueOnError)
-	noPrettyPrint := fs.Bool("no-pretty-print", false, "no-pretty-print")
-	proxy := fs.String("proxy", "", "proxy")
 	fs.Usage = c.Usage
 
 	if err := fs.Parse(args); err != nil {
@@ -41,32 +39,17 @@ func (c *RunCommand) Run(args []string) error {
 		return fmt.Errorf("failed to compile script: %w", err)
 	}
 
-	svc := flyscrape.Scraper{
-		ScrapeOptions: opts,
-		ScrapeFunc:    scrape,
-	}
-	if *proxy != "" {
-		svc.FetchFunc = flyscrape.ProxiedFetch(*proxy)
-	}
+	scraper := flyscrape.NewScraper()
+	scraper.ScrapeFunc = scrape
+	flyscrape.LoadModules(scraper, opts)
 
 	count := 0
 	start := time.Now()
-	for result := range svc.Scrape() {
-		if count > 0 {
-			fmt.Println(",")
-		}
-		if count == 0 {
-			fmt.Println("[")
-		}
-		if *noPrettyPrint {
-			fmt.Print(flyscrape.Print(result, "  "))
-		} else {
-			fmt.Print(flyscrape.PrettyPrint(result, "  "))
-		}
+	scraper.OnResponse(func(resp *flyscrape.Response) {
 		count++
-	}
-	fmt.Println("\n]")
+	})
+	scraper.Run()
 
 	log.Printf("Scraped %d websites in %v\n", count, time.Since(start))
 	return nil
@@ -80,18 +63,10 @@ Usage:
 
     flyscrape run SCRIPT
 
-Arguments:
-
-    -no-pretty-print
-        Disables pretty printing of scrape results.
-
 Examples:
 
     # Run the script.
    $ flyscrape run example.js
-
-    # Run the script with pretty printing disabled.
-    $ flyscrape run -no-pretty-print example.js
 `[1:])
 }
diff --git a/js.go b/js.go
index 5a20ed7..ce0efc1 100644
--- a/js.go
+++ b/js.go
@@ -16,6 +16,8 @@ import (
 	v8 "rogchap.com/v8go"
 )
 
+type Options []byte
+
 type TransformError struct {
 	Line   int
 	Column int
@@ -26,10 +28,10 @@ func (err TransformError) Error() string {
 	return fmt.Sprintf("%d:%d: %s", err.Line, err.Column, err.Text)
 }
 
-func Compile(src string) (ScrapeOptions, ScrapeFunc, error) {
+func Compile(src string) (Options, ScrapeFunc, error) {
 	src, err := build(src)
 	if err != nil {
-		return ScrapeOptions{}, nil, err
+		return nil, nil, err
 	}
 	return vm(src)
 }
@@ -56,31 +58,30 @@ func build(src string) (string, error) {
 	return string(res.Code), nil
 }
 
-func vm(src string) (ScrapeOptions, ScrapeFunc, error) {
+func vm(src string) (Options, ScrapeFunc, error) {
 	ctx := v8.NewContext()
 	ctx.RunScript("var module = {}", "main.js")
 
 	if _, err := ctx.RunScript(removeIIFE(js.Flyscrape), "main.js"); err != nil {
-		return ScrapeOptions{}, nil, fmt.Errorf("running flyscrape bundle: %w", err)
+		return nil, nil, fmt.Errorf("running flyscrape bundle: %w", err)
 	}
 
 	if _, err := ctx.RunScript(`const require = () => require_flyscrape();`, "main.js"); err != nil {
-		return ScrapeOptions{}, nil, fmt.Errorf("creating require function: %w", err)
+		return nil, nil, fmt.Errorf("creating require function: %w", err)
 	}
 
 	if _, err := ctx.RunScript(removeIIFE(src), "main.js"); err != nil {
-		return ScrapeOptions{}, nil, fmt.Errorf("running user script: %w", err)
+		return nil, nil, fmt.Errorf("running user script: %w", err)
 	}
 
-	var opts ScrapeOptions
-	optsJSON, err := ctx.RunScript("JSON.stringify(options)", "main.js")
+	cfg, err := ctx.RunScript("JSON.stringify(options)", "main.js")
 	if err != nil {
-		return ScrapeOptions{}, nil, fmt.Errorf("reading options: %w", err)
+		return nil, nil, fmt.Errorf("reading options: %w", err)
 	}
-	if err := json.Unmarshal([]byte(optsJSON.String()), &opts); err != nil {
-		return ScrapeOptions{}, nil, fmt.Errorf("decoding options json: %w", err)
+	if !cfg.IsString() {
+		return nil, nil, fmt.Errorf("options is not a string")
 	}
 
 	scrape := func(params ScrapeParams) (any, error) {
-		suffix := randSeq(10)
+		suffix := randSeq(16)
 		ctx.Global().Set("html_"+suffix, params.HTML)
 		ctx.Global().Set("url_"+suffix, params.URL)
 
 		data, err := ctx.RunScript(fmt.Sprintf(`JSON.stringify(stdin_default({html: html_%s, url: url_%s}))`, suffix, suffix), "main.js")
@@ -96,7 +97,7 @@ func vm(src string) (ScrapeOptions, ScrapeFunc, error) {
 		return obj, nil
 	}
 
-	return opts, scrape, nil
+	return Options(cfg.String()), scrape, nil
 }
 
 func randSeq(n int) string {
diff --git a/js_test.go b/js_test.go
index d1010b9..7496c68 100644
--- a/js_test.go
+++ b/js_test.go
@@ -5,6 +5,7 @@
 package flyscrape_test
 
 import (
+	"encoding/json"
 	"testing"
 
 	"github.com/philippta/flyscrape"
@@ -81,24 +82,25 @@ func TestJSOptions(t *testing.T) {
         url: 'http://localhost/',
         depth: 5,
         allowedDomains: ['example.com'],
-        blockedDomains: ['google.com'],
-        allowedURLs: ['/foo'],
-        blockedURLs: ['/bar'],
-        proxy: 'http://proxy/',
-        rate: 1,
     }
     export default function() {}
     `
-	opts, _, err := flyscrape.Compile(js)
+	rawOpts, _, err := flyscrape.Compile(js)
 	require.NoError(t, err)
-	require.Equal(t, flyscrape.ScrapeOptions{
+
+	type options struct {
+		URL            string   `json:"url"`
+		Depth          int      `json:"depth"`
+		AllowedDomains []string `json:"allowedDomains"`
+	}
+
+	var opts options
+	err = json.Unmarshal(rawOpts, &opts)
+	require.NoError(t, err)
+
+	require.Equal(t, options{
 		URL:            "http://localhost/",
 		Depth:          5,
 		AllowedDomains: []string{"example.com"},
-		BlockedDomains: []string{"google.com"},
-		AllowedURLs:    []string{"/foo"},
-		BlockedURLs:    []string{"/bar"},
-		Proxy:          "http://proxy/",
-		Rate:           1,
 	}, opts)
 }
diff --git a/mock.go b/mock.go
new file mode 100644
index 0000000..44b8837
--- /dev/null
+++ b/mock.go
@@ -0,0 +1,22 @@
+package flyscrape
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+)
+
+func MockTransport(statusCode int, html string) func(*http.Request) (*http.Response, error) {
+	return func(*http.Request) (*http.Response, error) {
+		return MockResponse(statusCode, html)
+	}
+}
+
+func MockResponse(statusCode int, html string) (*http.Response, error) {
+	return &http.Response{
+		StatusCode: statusCode,
+		Status:     fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
+		Body:       io.NopCloser(strings.NewReader(html)),
+	}, nil
+}
diff --git a/module.go b/module.go
new file mode 100644
index 0000000..bc90c02
--- /dev/null
+++ b/module.go
@@ -0,0 +1,65 @@
+package flyscrape
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"sync"
+)
+
+type Module interface {
+	ID() string
+}
+
+type Transport interface {
+	Transport(*http.Request) (*http.Response, error)
+}
+
+type CanRequest interface {
+	CanRequest(url string, depth int) bool
+}
+
+type OnRequest interface {
+	OnRequest(*Request)
+}
+
+type OnResponse interface {
+	OnResponse(*Response)
+}
+
+type OnLoad interface {
+	OnLoad(Visitor)
+}
+
+type OnComplete interface {
+	OnComplete()
+}
+
+func RegisterModule(m Module) {
+	id := m.ID()
+	if id == "" {
+		panic("module id is missing")
+	}
+
+	globalModulesMu.Lock()
+	defer globalModulesMu.Unlock()
+
+	if _, ok := globalModules[id]; ok {
+		panic(fmt.Sprintf("module %s already registered", id))
+	}
+	globalModules[id] = m
+}
+
+func LoadModules(s *Scraper, opts Options) {
+	globalModulesMu.RLock()
+	defer globalModulesMu.RUnlock()
+
+	for _, mod := range globalModules {
+		json.Unmarshal(opts, mod)
+		s.LoadModule(mod)
+	}
+}
+
+var (
+	globalModules   = map[string]Module{}
+	globalModulesMu sync.RWMutex
+)
diff --git a/modules/depth/depth.go b/modules/depth/depth.go
new file mode 100644
index 0000000..5efedc8
--- /dev/null
+++ b/modules/depth/depth.go
@@ -0,0 +1,30 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package depth
+
+import (
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	Depth int `json:"depth"`
+}
+
+func (m *Module) ID() string {
+	return "depth"
+}
+
+func (m *Module) CanRequest(url string, depth int) bool {
+	return depth <= m.Depth
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.CanRequest = (*Module)(nil)
+)
diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go
new file mode 100644
index 0000000..309e628
--- /dev/null
+++ b/modules/depth/depth_test.go
@@ -0,0 +1,47 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package depth_test
+
+import (
+	"net/http"
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/depth"
+	"github.com/philippta/flyscrape/modules/followlinks"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDepth(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&depth.Module{Depth: 2})
+
+	scraper.SetTransport(func(r *http.Request) (*http.Response, error) {
+		switch r.URL.String() {
+		case "http://www.example.com/":
+			return flyscrape.MockResponse(200, `<a href="http://www.google.com/">Google</a>`)
+		case "http://www.google.com/":
+			return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`)
+		case "http://www.duckduckgo.com/":
+			return flyscrape.MockResponse(200, `<a href="http://www.example.com/">Example</a>`)
+		}
+		return flyscrape.MockResponse(200, "")
+	})
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 3)
+	require.Contains(t, urls, "http://www.example.com/")
+	require.Contains(t, urls, "http://www.google.com/")
+	require.Contains(t, urls, "http://www.duckduckgo.com/")
+}
diff --git a/modules/domainfilter/domainfilter.go b/modules/domainfilter/domainfilter.go
new file mode 100644
index 0000000..b892882
--- /dev/null
+++ b/modules/domainfilter/domainfilter.go
@@ -0,0 +1,62 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package domainfilter
+
+import (
+	"github.com/nlnwa/whatwg-url/url"
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	URL            string   `json:"url"`
+	AllowedDomains []string `json:"allowedDomains"`
+	BlockedDomains []string `json:"blockedDomains"`
+}
+
+func (m *Module) ID() string {
+	return "domainfilter"
+}
+
+func (m *Module) OnLoad(v flyscrape.Visitor) {
+	if u, err := url.Parse(m.URL); err == nil {
+		m.AllowedDomains = append(m.AllowedDomains, u.Host())
+	}
+}
+
+func (m *Module) CanRequest(rawurl string, depth int) bool {
+	u, err := url.Parse(rawurl)
+	if err != nil {
+		return false
+	}
+
+	host := u.Host()
+	ok := false
+
+	for _, domain := range m.AllowedDomains {
+		if domain == "*" || host == domain {
+			ok = true
+			break
+		}
+	}
+
+	for _, domain := range m.BlockedDomains {
+		if host == domain {
+			ok = false
+			break
+		}
+	}
+
+	return ok
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.CanRequest = (*Module)(nil)
+	_ flyscrape.OnLoad     = (*Module)(nil)
+)
diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go
new file mode 100644
index 0000000..97bdc9c
--- /dev/null
+++ b/modules/domainfilter/domainfilter_test.go
@@ -0,0 +1,92 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package domainfilter_test
+
+import (
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/domainfilter"
+	"github.com/philippta/flyscrape/modules/followlinks"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDomainfilterAllowed(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&domainfilter.Module{
+		URL:            "http://www.example.com",
+		AllowedDomains: []string{"www.google.com"},
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="http://www.google.com/">Google</a>
+		<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 2)
+	require.Contains(t, urls, "http://www.example.com")
+	require.Contains(t, urls, "http://www.google.com/")
+}
+
+func TestDomainfilterAllowedAll(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&domainfilter.Module{
+		URL:            "http://www.example.com",
+		AllowedDomains: []string{"*"},
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="http://www.google.com/">Google</a>
+		<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 3)
+	require.Contains(t, urls, "http://www.example.com")
+	require.Contains(t, urls, "http://www.duckduckgo.com/")
+	require.Contains(t, urls, "http://www.google.com/")
+}
+
+func TestDomainfilterBlocked(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&domainfilter.Module{
+		URL:            "http://www.example.com",
+		AllowedDomains: []string{"*"},
+		BlockedDomains: []string{"www.google.com"},
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="http://www.google.com/">Google</a>
+		<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 2)
+	require.Contains(t, urls, "http://www.example.com")
+	require.Contains(t, urls, "http://www.duckduckgo.com/")
+}
diff --git a/modules/followlinks/followlinks.go b/modules/followlinks/followlinks.go
new file mode 100644
index 0000000..dde0e90
--- /dev/null
+++ b/modules/followlinks/followlinks.go
@@ -0,0 +1,30 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package followlinks
+
+import (
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct{}
+
+func (m *Module) ID() string {
+	return "followlinks"
+}
+
+func (m *Module) OnResponse(resp *flyscrape.Response) {
+	for _, link := range flyscrape.ParseLinks(resp.HTML, resp.URL) {
+		resp.Visit(link)
+	}
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.OnResponse = (*Module)(nil)
+)
diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go
new file mode 100644
index 0000000..03c3a6b
--- /dev/null
+++ b/modules/followlinks/followlinks_test.go
@@ -0,0 +1,39 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package followlinks_test
+
+import (
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/followlinks"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestFollowLinks(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
+	scraper.LoadModule(&followlinks.Module{})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="/baz">Baz</a>
+		<a href="baz">Baz</a>
+		<a href="http://www.google.com/">Google</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 5)
+	require.Contains(t, urls, "http://www.example.com/baz")
+	require.Contains(t, urls, "http://www.example.com/foo/bar")
+	require.Contains(t, urls, "http://www.example.com/foo/baz")
+	require.Contains(t, urls, "http://www.google.com/")
+	require.Contains(t, urls, "http://www.google.com/baz")
+}
diff --git a/modules/jsonprinter/jsonprinter.go b/modules/jsonprinter/jsonprinter.go
new file mode 100644
index 0000000..3936277
--- /dev/null
+++ b/modules/jsonprinter/jsonprinter.go
@@ -0,0 +1,47 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package jsonprinter
+
+import (
+	"fmt"
+
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	printed bool
+}
+
+func (m *Module) ID() string {
+	return "jsonprinter"
+}
+
+func (m *Module) OnResponse(resp *flyscrape.Response) {
+	if resp.Error == nil && resp.Data == nil {
+		return
+	}
+
+	if !m.printed {
+		fmt.Println("[")
+		m.printed = true
+	} else {
+		fmt.Println(",")
+	}
+
+	fmt.Print(flyscrape.PrettyPrint(resp.ScrapeResult, "  "))
+}
+
+func (m *Module) OnComplete() {
+	if m.printed {
+		fmt.Println("\n]")
+	}
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.OnResponse = (*Module)(nil)
+	_ flyscrape.OnComplete = (*Module)(nil)
+)
diff --git a/modules/jsonprinter/jsonprinter_test.go b/modules/jsonprinter/jsonprinter_test.go
new file mode 100644
index 0000000..29cc438
--- /dev/null
+++ b/modules/jsonprinter/jsonprinter_test.go
@@ -0,0 +1,47 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package jsonprinter_test
+
+import (
+	"io"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/jsonprinter"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestJSONPrinter(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.ScrapeFunc = func(params flyscrape.ScrapeParams) (any, error) {
+		return "foobar", nil
+	}
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
+	scraper.LoadModule(&jsonprinter.Module{})
+	scraper.SetTransport(flyscrape.MockTransport(200, ""))
+
+	stdout := os.Stdout
+	r, w, err := os.Pipe()
+	require.NoError(t, err)
+	os.Stdout = w
+
+	scraper.Run()
+
+	w.Close()
+	os.Stdout = stdout
+
+	out, err := io.ReadAll(r)
+	require.NoError(t, err)
+
+	output := strings.TrimSpace(string(out))
+	require.True(t, strings.HasPrefix(output, "["))
+	require.True(t, strings.HasSuffix(output, "]"))
+	require.Contains(t, output, `"foobar"`)
+}
diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go
new file mode 100644
index 0000000..b02f5d5
--- /dev/null
+++ b/modules/ratelimit/ratelimit.go
@@ -0,0 +1,54 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package ratelimit
+
+import (
+	"time"
+
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	Rate float64 `json:"rate"`
+
+	ticker    *time.Ticker
+	semaphore chan struct{}
+}
+
+func (m *Module) ID() string {
+	return "ratelimit"
+}
+
+func (m *Module) OnLoad(v flyscrape.Visitor) {
+	rate := time.Duration(float64(time.Second) / m.Rate)
+
+	m.ticker = time.NewTicker(rate)
+	m.semaphore = make(chan struct{}, 1)
+
+	go func() {
+		for range m.ticker.C {
+			m.semaphore <- struct{}{}
+		}
+	}()
+}
+
+func (m *Module) OnRequest(_ *flyscrape.Request) {
+	<-m.semaphore
+}
+
+func (m *Module) OnComplete() {
+	m.ticker.Stop()
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.OnRequest  = (*Module)(nil)
+	_ flyscrape.OnLoad     = (*Module)(nil)
+	_ flyscrape.OnComplete = (*Module)(nil)
+)
diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go
new file mode 100644
index 0000000..c166371
--- /dev/null
+++ b/modules/ratelimit/ratelimit_test.go
@@ -0,0 +1,45 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package ratelimit_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/followlinks"
+	"github.com/philippta/flyscrape/modules/ratelimit"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRatelimit(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&ratelimit.Module{
+		Rate: 100,
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `<a href="foo">foo</a>`))
+
+	var times []time.Time
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		times = append(times, time.Now())
+	})
+
+	start := time.Now()
+
+	scraper.Run()
+
+	first := times[0].Add(-10 * time.Millisecond)
+	second := times[1].Add(-20 * time.Millisecond)
+
+	require.Less(t, first.Sub(start), 2*time.Millisecond)
+	require.Less(t, second.Sub(start), 2*time.Millisecond)
+
+	require.Less(t, start.Sub(first), 2*time.Millisecond)
+	require.Less(t, start.Sub(second), 2*time.Millisecond)
+}
diff --git a/modules/starturl/starturl.go b/modules/starturl/starturl.go
new file mode 100644
index 0000000..b2e6c47
--- /dev/null
+++ b/modules/starturl/starturl.go
@@ -0,0 +1,30 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package starturl
+
+import (
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	URL string `json:"url"`
+}
+
+func (m *Module) ID() string {
+	return "starturl"
+}
+
+func (m *Module) OnLoad(v flyscrape.Visitor) {
+	v.Visit(m.URL)
+}
+
+var (
+	_ flyscrape.Module = (*Module)(nil)
+	_ flyscrape.OnLoad = (*Module)(nil)
+)
diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go
new file mode 100644
index 0000000..647e197
--- /dev/null
+++ b/modules/starturl/starturl_test.go
@@ -0,0 +1,31 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package starturl_test
+
+import (
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/stretchr/testify/require"
+)
+
+func TestStartURL(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
+	scraper.SetTransport(flyscrape.MockTransport(200, ""))
+
+	var url string
+	var depth int
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		url = req.URL
+		depth = req.Depth
+	})
+
+	scraper.Run()
+
+	require.Equal(t, "http://www.example.com/foo/bar", url)
+	require.Equal(t, 0, depth)
+}
diff --git a/modules/urlfilter/urlfilter.go b/modules/urlfilter/urlfilter.go
new file mode 100644
index 0000000..14576f0
--- /dev/null
+++ b/modules/urlfilter/urlfilter.go
@@ -0,0 +1,85 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package urlfilter
+
+import (
+	"regexp"
+
+	"github.com/philippta/flyscrape"
+)
+
+func init() {
+	flyscrape.RegisterModule(new(Module))
+}
+
+type Module struct {
+	URL         string   `json:"url"`
+	AllowedURLs []string `json:"allowedURLs"`
+	BlockedURLs []string `json:"blockedURLs"`
+
+	allowedURLsRE []*regexp.Regexp
+	blockedURLsRE []*regexp.Regexp
+}
+
+func (m *Module) ID() string {
+	return "urlfilter"
+}
+
+func (m *Module) OnLoad(v flyscrape.Visitor) {
+	for _, pat := range m.AllowedURLs {
+		re, err := regexp.Compile(pat)
+		if err != nil {
+			continue
+		}
+		m.allowedURLsRE = append(m.allowedURLsRE, re)
+	}
+
+	for _, pat := range m.BlockedURLs {
+		re, err := regexp.Compile(pat)
+		if err != nil {
+			continue
+		}
+		m.blockedURLsRE = append(m.blockedURLsRE, re)
+	}
+}
+
+func (m *Module) CanRequest(rawurl string, depth int) bool {
+	// allow root url
+	if rawurl == m.URL {
+		return true
+	}
+
+	// allow if no filter is set
+	if len(m.allowedURLsRE) == 0 && len(m.blockedURLsRE) == 0 {
+		return true
+	}
+
+	ok := false
+	if len(m.allowedURLsRE) == 0 {
+		ok = true
+	}
+
+	for _, re := range m.allowedURLsRE {
+		if re.MatchString(rawurl) {
+			ok = true
+			break
+		}
+	}
+
+	for _, re := range m.blockedURLsRE {
+		if re.MatchString(rawurl) {
+			ok = false
+			break
+		}
+	}
+
+	return ok
+}
+
+var (
+	_ flyscrape.Module     = (*Module)(nil)
+	_ flyscrape.CanRequest = (*Module)(nil)
+	_ flyscrape.OnLoad     = (*Module)(nil)
+)
diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go
new file mode 100644
index 0000000..e383a32
--- /dev/null
+++ b/modules/urlfilter/urlfilter_test.go
@@ -0,0 +1,71 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package urlfilter_test
+
+import (
+	"testing"
+
+	"github.com/philippta/flyscrape"
+	"github.com/philippta/flyscrape/modules/followlinks"
+	"github.com/philippta/flyscrape/modules/starturl"
+	"github.com/philippta/flyscrape/modules/urlfilter"
+	"github.com/stretchr/testify/require"
+)
+
+func TestURLFilterAllowed(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&urlfilter.Module{
+		URL:         "http://www.example.com/",
+		AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`},
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="/foo?id=123">123</a>
+		<a href="/foo?id=ABC">ABC</a>
+		<a href="/bar">bar</a>
+		<a href="/barz">barz</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 3)
+	require.Contains(t, urls, "http://www.example.com/")
+	require.Contains(t, urls, "http://www.example.com/foo?id=123")
+	require.Contains(t, urls, "http://www.example.com/bar")
+}
+
+func TestURLFilterBlocked(t *testing.T) {
+	scraper := flyscrape.NewScraper()
+	scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
+	scraper.LoadModule(&followlinks.Module{})
+	scraper.LoadModule(&urlfilter.Module{
+		URL:         "http://www.example.com/",
+		BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`},
+	})
+
+	scraper.SetTransport(flyscrape.MockTransport(200, `
+		<a href="/foo?id=123">123</a>
+		<a href="/foo?id=ABC">ABC</a>
+		<a href="/bar">bar</a>
+		<a href="/barz">barz</a>`))
+
+	var urls []string
+	scraper.OnRequest(func(req *flyscrape.Request) {
+		urls = append(urls, req.URL)
+	})
+
+	scraper.Run()
+
+	require.Len(t, urls, 3)
+	require.Contains(t, urls, "http://www.example.com/")
+	require.Contains(t, urls, "http://www.example.com/foo?id=ABC")
+	require.Contains(t, urls, "http://www.example.com/barz")
+}
diff --git a/scrape.go b/scrape.go
index 8b6ce11..42b3c10 100644
--- a/scrape.go
+++ b/scrape.go
@@ -5,8 +5,9 @@
 package flyscrape
 
 import (
+	"io"
 	"log"
-	"regexp"
+	"net/http"
 	"strings"
 	"sync"
 	"time"
@@ -21,17 +22,6 @@ type ScrapeParams struct {
 	URL  string
 }
 
-type ScrapeOptions struct {
-	URL            string   `json:"url"`
-	AllowedDomains []string `json:"allowedDomains"`
-	BlockedDomains []string `json:"blockedDomains"`
-	AllowedURLs    []string `json:"allowedURLs"`
-	BlockedURLs    []string `json:"blockedURLs"`
-	Proxy          string   `json:"proxy"`
-	Depth          int      `json:"depth"`
-	Rate           float64  `json:"rate"`
-}
-
 type ScrapeResult struct {
 	URL  string `json:"url"`
 	Data any    `json:"data,omitempty"`
@@ -48,204 +38,203 @@ type ScrapeFunc func(ScrapeParams) (any, error)
 
 type FetchFunc func(url string) (string, error)
 
-type target struct {
-	url   string
-	depth int
+type Visitor interface {
+	Visit(url string)
+	MarkVisited(url string)
 }
 
+type (
+	Request struct {
+		URL   string
+		Depth int
+	}
+
+	Response struct {
+		ScrapeResult
+		HTML  string
+		Visit func(url string)
+	}
+
+	target struct {
+		url   string
+		depth int
+	}
+)
+
 type Scraper struct {
-	ScrapeOptions ScrapeOptions
-	ScrapeFunc    ScrapeFunc
-	FetchFunc     FetchFunc
-
-	visited       *hashmap.Map[string, struct{}]
-	wg            *sync.WaitGroup
-	jobs          chan target
-	results       chan ScrapeResult
-	allowedURLsRE []*regexp.Regexp
-	blockedURLsRE []*regexp.Regexp
+	ScrapeFunc ScrapeFunc
+
+	opts    Options
+	wg      sync.WaitGroup
+	jobs    chan target
+	visited *hashmap.Map[string, struct{}]
+	modules *hashmap.Map[string, Module]
+
+	canRequestHandlers []func(url string, depth int) bool
+	onRequestHandlers  []func(*Request)
+	onResponseHandlers []func(*Response)
+	onCompleteHandlers []func()
+	transport          func(*http.Request) (*http.Response, error)
 }
 
-func (s *Scraper) init() {
-	s.visited = hashmap.New[string, struct{}]()
-	s.wg = &sync.WaitGroup{}
-	s.jobs = make(chan target, 1024)
-	s.results = make(chan ScrapeResult)
+func NewScraper() *Scraper {
+	s := &Scraper{
+		jobs:    make(chan target, 1024),
+		visited: hashmap.New[string, struct{}](),
+		modules: hashmap.New[string, Module](),
+		transport: func(r *http.Request) (*http.Response, error) {
+			r.Header.Set("User-Agent", "flyscrape/0.1")
+			return http.DefaultClient.Do(r)
+		},
+	}
+	return s
+}
 
-	if s.FetchFunc == nil {
-		s.FetchFunc = Fetch()
+func (s *Scraper) LoadModule(mod Module) {
+	if v, ok := mod.(Transport); ok {
+		s.SetTransport(v.Transport)
 	}
-	if s.ScrapeOptions.Proxy != "" {
-		s.FetchFunc = ProxiedFetch(s.ScrapeOptions.Proxy)
+
+	if v, ok := mod.(CanRequest); ok {
+		s.CanRequest(v.CanRequest)
 	}
 
-	if s.ScrapeOptions.Rate == 0 {
-		s.ScrapeOptions.Rate = 100
+	if v, ok := mod.(OnRequest); ok {
+		s.OnRequest(v.OnRequest)
 	}
 
-	if u, err := url.Parse(s.ScrapeOptions.URL); err == nil {
-		s.ScrapeOptions.AllowedDomains = append(s.ScrapeOptions.AllowedDomains, u.Host())
+	if v, ok := mod.(OnResponse); ok {
+		s.OnResponse(v.OnResponse)
 	}
 
-	for _, pat := range s.ScrapeOptions.AllowedURLs {
-		re, err := regexp.Compile(pat)
-		if err != nil {
-			continue
-		}
-		s.allowedURLsRE = append(s.allowedURLsRE, re)
+	if v, ok := mod.(OnLoad); ok {
+		v.OnLoad(s)
 	}
 
-	for _, pat := range s.ScrapeOptions.BlockedURLs {
-		re, err := regexp.Compile(pat)
-		if err != nil {
-			continue
-		}
-		s.blockedURLsRE = append(s.blockedURLsRE, re)
+	if v, ok := mod.(OnComplete); ok {
+		s.OnComplete(v.OnComplete)
 	}
 }
 
-func (s *Scraper) Scrape() <-chan ScrapeResult {
-	s.init()
-	s.enqueueJob(s.ScrapeOptions.URL, s.ScrapeOptions.Depth)
+func (s *Scraper) Visit(url string) {
+	s.enqueueJob(url, 0)
+}
 
-	go s.worker()
-	go s.waitClose()
+func (s *Scraper) MarkVisited(url string) {
+	s.visited.Insert(url, struct{}{})
+}
 
-	return s.results
+func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
+	s.transport = f
 }
 
-func (s *Scraper) worker() {
-	var (
-		rate      = time.Duration(float64(time.Second) / s.ScrapeOptions.Rate)
-		leakyjobs = leakychan(s.jobs, rate)
-	)
+func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
+	s.canRequestHandlers = append(s.canRequestHandlers, f)
+}
 
-	for job := range leakyjobs {
-		go func(job target) {
-			defer s.wg.Done()
+func (s *Scraper) OnRequest(f func(req *Request)) {
+	s.onRequestHandlers = append(s.onRequestHandlers, f)
+}
 
-			res := s.process(job)
-			if !res.omit() {
-				s.results <- res
-			}
+func (s *Scraper) OnResponse(f func(resp *Response)) {
+	s.onResponseHandlers = append(s.onResponseHandlers, f)
+}
 
-			if job.depth <= 0 {
-				return
-			}
+func (s *Scraper) OnComplete(f func()) {
+	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
+}
 
-			for _, l := range res.Links {
-				if _, ok := s.visited.Get(l); ok {
-					continue
-				}
+func (s *Scraper) Run() {
+	go s.worker()
+	s.wg.Wait()
+	close(s.jobs)
+
+	for _, handler := range s.onCompleteHandlers {
+		handler()
+	}
+}
 
-				allowed := s.isDomainAllowed(l) && s.isURLAllowed(l)
-				if !allowed {
-					continue
+func (s *Scraper) worker() {
+	for job := range s.jobs {
+		go func(job target) {
+			defer s.wg.Done()
+
+			for _, handler := range s.canRequestHandlers {
+				if !handler(job.url, job.depth) {
+					return
 				}
+			}
 
-				s.enqueueJob(l, job.depth-1)
+			res, html := s.process(job)
+			for _, handler := range s.onResponseHandlers {
+				handler(&Response{
+					ScrapeResult: res,
+					HTML:         html,
+					Visit: func(url string) {
+						s.enqueueJob(url, job.depth+1)
+					},
+				})
 			}
 		}(job)
 	}
 }
 
-func (s *Scraper) process(job target) (res ScrapeResult) {
+func (s *Scraper) process(job target) (res ScrapeResult, html string) {
 	res.URL = job.url
 	res.Timestamp = time.Now()
 
-	html, err := s.FetchFunc(job.url)
+	req, err := http.NewRequest(http.MethodGet, job.url, nil)
 	if err != nil {
 		res.Error = err
 		return
 	}
 
-	res.Links = links(html, job.url)
-	res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
+	for _, handler := range s.onRequestHandlers {
+		handler(&Request{URL: job.url, Depth: job.depth})
+	}
+
+	resp, err := s.transport(req)
 	if err != nil {
 		res.Error = err
 		return
 	}
+	defer resp.Body.Close()
 
-	return
-}
-
-func (s *Scraper) enqueueJob(url string, depth int) {
-	s.wg.Add(1)
-	select {
-	case s.jobs <- target{url: url, depth: depth}:
-		s.visited.Set(url, struct{}{})
-	default:
-		log.Println("queue is full, can't add url:", url)
-		s.wg.Done()
-	}
-}
-
-func (s *Scraper) isDomainAllowed(rawurl string) bool {
-	u, err := url.Parse(rawurl)
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return false
+		res.Error = err
+		return
 	}
 
-	host := u.Host()
-	ok := false
+	html = string(body)
 
-	for _, domain := range s.ScrapeOptions.AllowedDomains {
-		if domain == "*" || host == domain {
-			ok = true
-			break
-		}
-	}
-
-	for _, domain := range s.ScrapeOptions.BlockedDomains {
-		if host == domain {
-			ok = false
-			break
+	if s.ScrapeFunc != nil {
+		res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
+		if err != nil {
+			res.Error = err
+			return
 		}
 	}
 
-	return ok
+	return
 }
 
-func (s *Scraper) isURLAllowed(rawurl string) bool {
-	// allow root url
-	if rawurl == s.ScrapeOptions.URL {
-		return true
-	}
-
-	// allow if no filter is set
-	if len(s.allowedURLsRE) == 0 && len(s.blockedURLsRE) == 0 {
-		return true
-	}
-
-	ok := false
-	if len(s.allowedURLsRE) == 0 {
-		ok = true
-	}
-
-	for _, re := range s.allowedURLsRE {
-		if re.MatchString(rawurl) {
-			ok = true
-			break
-		}
+func (s *Scraper) enqueueJob(url string, depth int) {
+	if _, ok := s.visited.Get(url); ok {
+		return
 	}
 
-	for _, re := range s.blockedURLsRE {
-		if re.MatchString(rawurl) {
-			ok = false
-			break
-		}
+	s.wg.Add(1)
+	select {
+	case s.jobs <- target{url: url, depth: depth}:
+		s.MarkVisited(url)
+	default:
+		log.Println("queue is full, can't add url:", url)
+		s.wg.Done()
	}
-
-	return ok
 }
 
-func (s *Scraper) waitClose() {
-	s.wg.Wait()
-	close(s.jobs)
-	close(s.results)
-}
-
-func links(html string, origin string) []string {
+func ParseLinks(html string, origin string) []string {
 	var links []string
 	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
 	if err != nil {
@@ -285,26 +274,3 @@ func isValidLink(link *url.Url) bool {
 
 	return true
 }
-
-func leakychan[T any](in chan T, rate time.Duration) chan T {
-	ticker := time.NewTicker(rate)
-	sem := make(chan struct{}, 1)
-	c := make(chan T)
-
-	go func() {
-		for range ticker.C {
-			sem <- struct{}{}
-		}
-	}()
-
-	go func() {
-		for v := range in {
-			<-sem
-			c <- v
-		}
-		ticker.Stop()
-		close(c)
-	}()
-
-	return c
-}
diff --git a/scrape_test.go b/scrape_test.go
deleted file mode 100644
index 6f7174e..0000000
--- a/scrape_test.go
+++ /dev/null
@@ -1,266 +0,0 @@
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package flyscrape_test
-
-import (
-	"net/http"
-	"net/http/httptest"
-	"testing"
-	"time"
-
-	"github.com/philippta/flyscrape"
-	"github.com/stretchr/testify/require"
-)
-
-func TestScrapeFollowLinks(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:            "http://www.example.com/foo/bar",
-			Depth:          1,
-			AllowedDomains: []string{"www.google.com"},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="/baz">Baz</a>
-				<a href="baz">Baz</a>
-				<a href="http://www.google.com/">Google</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 4)
-	require.Contains(t, urls, "http://www.example.com/baz")
-	require.Contains(t, urls, "http://www.example.com/foo/bar")
-	require.Contains(t, urls, "http://www.example.com/foo/baz")
-	require.Contains(t, urls, "http://www.google.com/")
-}
-
-func TestScrapeDepth(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:            "http://www.example.com/",
-			Depth:          2,
-			AllowedDomains: []string{"*"},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			switch url {
-			case "http://www.example.com/":
-				return `<a href="http://www.google.com/">Google</a>`, nil
-			case "http://www.google.com/":
-				return `<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`, nil
-			case "http://www.duckduckgo.com/":
-				return `<a href="http://www.example.com/">Example</a>`, nil
-			}
-			return "", nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 3)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.google.com/")
-	require.Contains(t, urls, "http://www.duckduckgo.com/")
-}
-
-func TestScrapeAllowedDomains(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:            "http://www.example.com/",
-			Depth:          1,
-			AllowedDomains: []string{"www.google.com"},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="http://www.google.com/">Google</a>
-				<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 2)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.google.com/")
-}
-
-func TestScrapeAllowedDomainsAll(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:            "http://www.example.com/",
-			Depth:          1,
-			AllowedDomains: []string{"*"},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="http://www.google.com/">Google</a>
-				<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 3)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.duckduckgo.com/")
-	require.Contains(t, urls, "http://www.google.com/")
-}
-
-func TestScrapeBlockedDomains(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:            "http://www.example.com/",
-			Depth:          1,
-			AllowedDomains: []string{"*"},
-			BlockedDomains: []string{"www.google.com"},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="http://www.google.com/">Google</a>
-				<a href="http://www.duckduckgo.com/">DuckDuckGo</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 2)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.duckduckgo.com/")
-}
-
-func TestScrapeAllowedURLs(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:         "http://www.example.com/",
-			Depth:       1,
-			AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="/foo?id=123">123</a>
-				<a href="/foo?id=ABC">ABC</a>
-				<a href="/bar">bar</a>
-				<a href="/barz">barz</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 3)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.example.com/foo?id=123")
-	require.Contains(t, urls, "http://www.example.com/bar")
-}
-
-func TestScrapeBlockedURLs(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:         "http://www.example.com/",
-			Depth:       1,
-			BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`},
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="/foo?id=123">123</a>
-				<a href="/foo?id=ABC">ABC</a>
-				<a href="/bar">bar</a>
-				<a href="/barz">barz</a>`, nil
-		},
-	}
-
-	urls := make(map[string]struct{})
-	for res := range scr.Scrape() {
-		urls[res.URL] = struct{}{}
-	}
-
-	require.Len(t, urls, 3)
-	require.Contains(t, urls, "http://www.example.com/")
-	require.Contains(t, urls, "http://www.example.com/foo?id=ABC")
-	require.Contains(t, urls, "http://www.example.com/barz")
-}
-
-func TestScrapeRate(t *testing.T) {
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:   "http://www.example.com/",
-			Depth: 1,
-			Rate:  100, // every 10ms
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-		FetchFunc: func(url string) (string, error) {
-			return `<a href="foo">foo</a>`, nil
-		},
-	}
-
-	res := scr.Scrape()
-
-	start := time.Now()
-	<-res
-	first := time.Now().Add(-10 * time.Millisecond)
-	<-res
-	second := time.Now().Add(-20 * time.Millisecond)
-
-	require.Less(t, first.Sub(start), 2*time.Millisecond)
-	require.Less(t, second.Sub(start), 2*time.Millisecond)
-}
-
-func TestScrapeProxy(t *testing.T) {
-	proxyCalled := false
-	proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		proxyCalled = true
-		w.Write([]byte(`<a href="http://www.google.com/">Google</a>`))
-	}))
-
-	scr := flyscrape.Scraper{
-		ScrapeOptions: flyscrape.ScrapeOptions{
-			URL:   "http://www.example.com/",
-			Proxy: proxy.URL,
-		},
-		ScrapeFunc: func(params flyscrape.ScrapeParams) (any, error) {
-			return "foobar", nil
-		},
-	}
-
-	res := <-scr.Scrape()
-
-	require.True(t, proxyCalled)
-	require.Equal(t, "http://www.example.com/", res.URL)
-}
diff --git a/utils.go b/utils.go
index 89ce3f1..8b52e76 100644
--- a/utils.go
+++ b/utils.go
@@ -26,3 +26,7 @@ func Print(v any, prefix string) string {
 	enc.Encode(v)
 	return prefix + strings.TrimSuffix(buf.String(), "\n")
 }
+
+func ParseOptions(opts Options, v any) {
+	json.Unmarshal(opts, v)
+}
-- 
cgit v1.2.3
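
For anyone building on this refactor, here is a minimal sketch of a third-party module written against the interfaces introduced above. The useragent package and its userAgent option key are hypothetical examples, not part of this patch; only RegisterModule, the Transport hook picked up by Scraper.LoadModule, and the options-JSON unmarshalling done by LoadModules are taken from the code in the diff.

// Hypothetical module (not in this patch): sets a custom User-Agent by
// replacing the scraper's transport. LoadModules unmarshals the script's
// options JSON into every registered module, so the UserAgent field would
// be populated from a `userAgent` key in the script options, if present.
package useragent

import (
	"net/http"

	"github.com/philippta/flyscrape"
)

func init() {
	flyscrape.RegisterModule(new(Module))
}

type Module struct {
	UserAgent string `json:"userAgent"` // hypothetical option key
}

func (m *Module) ID() string {
	return "useragent"
}

// Transport satisfies the flyscrape.Transport interface; Scraper.LoadModule
// detects it and installs it via SetTransport, replacing the default transport.
func (m *Module) Transport(r *http.Request) (*http.Response, error) {
	if m.UserAgent != "" {
		r.Header.Set("User-Agent", m.UserAgent)
	}
	return http.DefaultClient.Do(r)
}

var (
	_ flyscrape.Module    = (*Module)(nil)
	_ flyscrape.Transport = (*Module)(nil)
)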