From 1fc497fbdc79a43c62ac2e8eaf4827752dbeef8e Mon Sep 17 00:00:00 2001 From: Philipp Tanlak Date: Thu, 5 Oct 2023 14:53:37 +0200 Subject: Refactor codebase into modules --- modules/cache/cache.go | 78 +++++++++++++++++++++++++++++++ modules/cache/cache_test.go | 38 +++++++++++++++ modules/depth/depth.go | 15 ++++-- modules/depth/depth_test.go | 40 +++++++++------- modules/domainfilter/domainfilter.go | 32 ++++++++++--- modules/domainfilter/domainfilter_test.go | 69 +++++++++++++++++---------- modules/followlinks/followlinks.go | 60 ++++++++++++++++++++++-- modules/followlinks/followlinks_test.go | 26 +++++++---- modules/hook/hook.go | 78 +++++++++++++++++++++++++++++++ modules/jsonprint/jsonprint.go | 65 ++++++++++++++++++++++++++ modules/jsonprinter/jsonprinter.go | 57 ---------------------- modules/ratelimit/ratelimit.go | 35 +++++++++++--- modules/ratelimit/ratelimit_test.go | 18 ++++--- modules/starturl/starturl.go | 22 +++++++-- modules/starturl/starturl_test.go | 20 +++++--- modules/urlfilter/urlfilter.go | 35 ++++++++++---- modules/urlfilter/urlfilter_test.go | 55 ++++++++++++++-------- 17 files changed, 571 insertions(+), 172 deletions(-) create mode 100644 modules/cache/cache.go create mode 100644 modules/cache/cache_test.go create mode 100644 modules/hook/hook.go create mode 100644 modules/jsonprint/jsonprint.go delete mode 100644 modules/jsonprinter/jsonprinter.go (limited to 'modules') diff --git a/modules/cache/cache.go b/modules/cache/cache.go new file mode 100644 index 0000000..1a321be --- /dev/null +++ b/modules/cache/cache.go @@ -0,0 +1,78 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cache + +import ( + "bufio" + "bytes" + "net/http" + "net/http/httputil" + + "github.com/cornelk/hashmap" + "github.com/philippta/flyscrape" +) + +func init() { + flyscrape.RegisterModule(Module{}) +} + +type Module struct { + Cache string `json:"cache"` + + cache *hashmap.Map[string, []byte] +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "cache", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(flyscrape.Context) { + if m.disabled() { + return + } + if m.cache == nil { + m.cache = hashmap.New[string, []byte]() + } +} + +func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { + if m.disabled() { + return t + } + + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + key := cacheKey(r) + + if b, ok := m.cache.Get(key); ok { + if resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(b)), r); err == nil { + return resp, nil + } + } + + resp, err := t.RoundTrip(r) + if err != nil { + return resp, err + } + + encoded, err := httputil.DumpResponse(resp, true) + if err != nil { + return resp, err + } + + m.cache.Set(key, encoded) + return resp, nil + }) +} + +func (m *Module) disabled() bool { + return m.Cache == "" +} + +func cacheKey(r *http.Request) string { + return r.Method + " " + r.URL.String() +} diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go new file mode 100644 index 0000000..4565e00 --- /dev/null +++ b/modules/cache/cache_test.go @@ -0,0 +1,38 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cache_test + +import ( + "net/http" + "testing" + + "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/cache" + "github.com/philippta/flyscrape/modules/hook" + "github.com/philippta/flyscrape/modules/starturl" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + cachemod := &cache.Module{Cache: "memory"} + calls := 0 + + for i := 0; i < 2; i++ { + scraper := flyscrape.NewScraper() + scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + calls++ + return flyscrape.MockResponse(200, "foo") + }) + }, + }) + scraper.LoadModule(cachemod) + scraper.Run() + } + + require.Equal(t, 1, calls) +} diff --git a/modules/depth/depth.go b/modules/depth/depth.go index 0cfbc71..866f5ae 100644 --- a/modules/depth/depth.go +++ b/modules/depth/depth.go @@ -9,15 +9,22 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { Depth int `json:"depth"` } -func (m *Module) CanRequest(url string, depth int) bool { - return depth <= m.Depth +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "depth", + New: func() flyscrape.Module { return new(Module) }, + } } -var _ flyscrape.CanRequest = (*Module)(nil) +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + return r.Depth <= m.Depth +} + +var _ flyscrape.RequestValidator = (*Module)(nil) diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go index c9afd6f..10b67e9 100644 --- a/modules/depth/depth_test.go +++ b/modules/depth/depth_test.go @@ -6,36 +6,44 @@ package depth_test import ( "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/depth" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestDepth(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) scraper.LoadModule(&depth.Module{Depth: 2}) - - scraper.SetTransport(func(r *http.Request) (*http.Response, error) { - switch r.URL.String() { - case "http://www.example.com": - return flyscrape.MockResponse(200, `Google`) - case "http://www.google.com": - return flyscrape.MockResponse(200, `DuckDuckGo`) - case "http://www.duckduckgo.com": - return flyscrape.MockResponse(200, `Example`) - } - return flyscrape.MockResponse(200, "") - }) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "http://www.example.com": + return flyscrape.MockResponse(200, `Google`) + case "http://www.google.com": + return flyscrape.MockResponse(200, `DuckDuckGo`) + case "http://www.duckduckgo.com": + return flyscrape.MockResponse(200, `Example`) + } + return flyscrape.MockResponse(200, "") + }) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/domainfilter/domainfilter.go b/modules/domainfilter/domainfilter.go index ba9ebe6..e8691d3 100644 --- a/modules/domainfilter/domainfilter.go +++ b/modules/domainfilter/domainfilter.go @@ -10,23 +10,39 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { URL string `json:"url"` AllowedDomains []string `json:"allowedDomains"` BlockedDomains []string `json:"blockedDomains"` + + active bool } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "domainfilter", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.URL == "" { + return + } if u, err := url.Parse(m.URL); err == nil { m.AllowedDomains = append(m.AllowedDomains, u.Host()) } } -func (m *Module) CanRequest(rawurl string, depth int) bool { - u, err := url.Parse(rawurl) +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + if m.disabled() { + return true + } + + u, err := url.Parse(r.URL) if err != nil { return false } @@ -51,7 +67,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { return ok } +func (m *Module) disabled() bool { + return len(m.AllowedDomains) == 0 && len(m.BlockedDomains) == 0 +} + var ( - _ flyscrape.CanRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) + _ flyscrape.RequestValidator = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) ) diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go index 884a89f..a1c8401 100644 --- a/modules/domainfilter/domainfilter_test.go +++ b/modules/domainfilter/domainfilter_test.go @@ -5,16 +5,22 @@ package domainfilter_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/domainfilter" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestDomainfilterAllowed(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -22,14 +28,17 @@ func TestDomainfilterAllowed(t *testing.T) { URL: "http://www.example.com", AllowedDomains: []string{"www.google.com"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - Google - DuckDuckGo`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + Google + DuckDuckGo`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -40,6 +49,9 @@ func TestDomainfilterAllowed(t *testing.T) { } func TestDomainfilterAllowedAll(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -47,14 +59,17 @@ func TestDomainfilterAllowedAll(t *testing.T) { URL: "http://www.example.com", AllowedDomains: []string{"*"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - Google - DuckDuckGo`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + Google + DuckDuckGo`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -66,6 +81,9 @@ func TestDomainfilterAllowedAll(t *testing.T) { } func TestDomainfilterBlocked(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -74,14 +92,17 @@ func TestDomainfilterBlocked(t *testing.T) { AllowedDomains: []string{"*"}, BlockedDomains: []string{"www.google.com"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - Google - DuckDuckGo`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + Google + DuckDuckGo`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/followlinks/followlinks.go b/modules/followlinks/followlinks.go index 99d6cee..c53f167 100644 --- a/modules/followlinks/followlinks.go +++ b/modules/followlinks/followlinks.go @@ -5,19 +5,71 @@ package followlinks import ( + "net/url" + "strings" + + "github.com/PuerkitoBio/goquery" "github.com/philippta/flyscrape" ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct{} -func (m *Module) OnResponse(resp *flyscrape.Response) { - for _, link := range flyscrape.ParseLinks(string(resp.Body), resp.Request.URL) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "followlinks", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) ReceiveResponse(resp *flyscrape.Response) { + for _, link := range parseLinks(string(resp.Body), resp.Request.URL) { resp.Visit(link) } } -var _ flyscrape.OnResponse = (*Module)(nil) +func parseLinks(html string, origin string) []string { + var links []string + doc, err := goquery.NewDocumentFromReader(strings.NewReader(html)) + if err != nil { + return nil + } + + originurl, err := url.Parse(origin) + if err != nil { + return nil + } + + uniqueLinks := make(map[string]bool) + doc.Find("a").Each(func(i int, s *goquery.Selection) { + link, _ := s.Attr("href") + + parsedLink, err := originurl.Parse(link) + + if err != nil || !isValidLink(parsedLink) { + return + } + + absLink := parsedLink.String() + + if !uniqueLinks[absLink] { + links = append(links, absLink) + uniqueLinks[absLink] = true + } + }) + + return links +} + +func isValidLink(link *url.URL) bool { + if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" { + return false + } + + return true +} + +var _ flyscrape.ResponseReceiver = (*Module)(nil) diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go index 18c8ceb..0a628c3 100644 --- a/modules/followlinks/followlinks_test.go +++ b/modules/followlinks/followlinks_test.go @@ -5,27 +5,37 @@ package followlinks_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestFollowLinks(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) scraper.LoadModule(&followlinks.Module{}) - scraper.SetTransport(flyscrape.MockTransport(200, ` - Baz - Baz - Google`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + Baz + Baz + Google`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/hook/hook.go b/modules/hook/hook.go new file mode 100644 index 0000000..4484f47 --- /dev/null +++ b/modules/hook/hook.go @@ -0,0 +1,78 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hook + +import ( + "net/http" + + "github.com/philippta/flyscrape" +) + +type Module struct { + AdaptTransportFn func(http.RoundTripper) http.RoundTripper + ValidateRequestFn func(*flyscrape.Request) bool + BuildRequestFn func(*flyscrape.Request) + ReceiveResponseFn func(*flyscrape.Response) + ProvisionFn func(flyscrape.Context) + FinalizeFn func() +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "hook", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { + if m.AdaptTransportFn == nil { + return t + } + return m.AdaptTransportFn(t) +} + +func (m Module) ValidateRequest(r *flyscrape.Request) bool { + if m.ValidateRequestFn == nil { + return true + } + return m.ValidateRequestFn(r) +} + +func (m Module) BuildRequest(r *flyscrape.Request) { + if m.BuildRequestFn == nil { + return + } + m.BuildRequestFn(r) +} + +func (m Module) ReceiveResponse(r *flyscrape.Response) { + if m.ReceiveResponseFn == nil { + return + } + m.ReceiveResponseFn(r) +} + +func (m Module) Provision(ctx flyscrape.Context) { + if m.ProvisionFn == nil { + return + } + m.ProvisionFn(ctx) +} + +func (m Module) Finalize() { + if m.FinalizeFn == nil { + return + } + m.FinalizeFn() +} + +var ( + _ flyscrape.TransportAdapter = Module{} + _ flyscrape.RequestValidator = Module{} + _ flyscrape.RequestBuilder = Module{} + _ flyscrape.ResponseReceiver = Module{} + _ flyscrape.Provisioner = Module{} + _ flyscrape.Finalizer = Module{} +) diff --git a/modules/jsonprint/jsonprint.go b/modules/jsonprint/jsonprint.go new file mode 100644 index 0000000..29d3375 --- /dev/null +++ b/modules/jsonprint/jsonprint.go @@ -0,0 +1,65 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package jsonprint + +import ( + "fmt" + "time" + + "github.com/philippta/flyscrape" +) + +func init() { + flyscrape.RegisterModule(Module{}) +} + +type Module struct { + once bool +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "jsonprint", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) ReceiveResponse(resp *flyscrape.Response) { + if resp.Error == nil && resp.Data == nil { + return + } + + if !m.once { + fmt.Println("[") + m.once = true + } else { + fmt.Println(",") + } + + o := output{ + URL: resp.Request.URL, + Data: resp.Data, + Error: resp.Error, + Timestamp: time.Now(), + } + + fmt.Print(flyscrape.Prettify(o, " ")) +} + +func (m *Module) Finalize() { + fmt.Println("\n]") +} + +type output struct { + URL string `json:"url,omitempty"` + Data any `json:"data,omitempty"` + Error error `json:"error,omitempty"` + Timestamp time.Time `json:"timestamp,omitempty"` +} + +var ( + _ flyscrape.ResponseReceiver = (*Module)(nil) + _ flyscrape.Finalizer = (*Module)(nil) +) diff --git a/modules/jsonprinter/jsonprinter.go b/modules/jsonprinter/jsonprinter.go deleted file mode 100644 index 3026f29..0000000 --- a/modules/jsonprinter/jsonprinter.go +++ /dev/null @@ -1,57 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package jsonprinter - -import ( - "fmt" - "time" - - "github.com/philippta/flyscrape" -) - -func init() { - flyscrape.RegisterModule(new(Module)) -} - -type Module struct { - first bool -} - -func (m *Module) OnResponse(resp *flyscrape.Response) { - if resp.Error == nil && resp.Data == nil { - return - } - - if m.first { - fmt.Println("[") - } else { - fmt.Println(",") - } - - o := output{ - URL: resp.Request.URL, - Data: resp.Data, - Error: resp.Error, - Timestamp: time.Now(), - } - - fmt.Print(flyscrape.PrettyPrint(o, " ")) -} - -func (m *Module) OnComplete() { - fmt.Println("\n]") -} - -type output struct { - URL string `json:"url,omitempty"` - Data any `json:"data,omitempty"` - Error error `json:"error,omitempty"` - Timestamp time.Time `json:"timestamp,omitempty"` -} - -var ( - _ flyscrape.OnResponse = (*Module)(nil) - _ flyscrape.OnComplete = (*Module)(nil) -) diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go index be622f6..9588db3 100644 --- a/modules/ratelimit/ratelimit.go +++ b/modules/ratelimit/ratelimit.go @@ -11,7 +11,7 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { @@ -21,7 +21,18 @@ type Module struct { semaphore chan struct{} } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "ratelimit", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.disabled() { + return + } + rate := time.Duration(float64(time.Second) / m.Rate) m.ticker = time.NewTicker(rate) @@ -34,16 +45,26 @@ func (m *Module) OnLoad(v flyscrape.Visitor) { }() } -func (m *Module) OnRequest(_ *flyscrape.Request) { +func (m *Module) BuildRequest(_ *flyscrape.Request) { + if m.disabled() { + return + } <-m.semaphore } -func (m *Module) OnComplete() { +func (m *Module) Finalize() { + if m.disabled() { + return + } m.ticker.Stop() } +func (m *Module) disabled() bool { + return m.Rate == 0 +} + var ( - _ flyscrape.OnRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) - _ flyscrape.OnComplete = (*Module)(nil) + _ flyscrape.RequestBuilder = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) + _ flyscrape.Finalizer = (*Module)(nil) ) diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go index 5e91f8f..ffd061c 100644 --- a/modules/ratelimit/ratelimit_test.go +++ b/modules/ratelimit/ratelimit_test.go @@ -5,33 +5,37 @@ package ratelimit_test import ( + "net/http" "testing" "time" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/ratelimit" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestRatelimit(t *testing.T) { + var times []time.Time + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) scraper.LoadModule(&ratelimit.Module{ Rate: 100, }) - - scraper.SetTransport(flyscrape.MockTransport(200, `foo`)) - - var times []time.Time - scraper.OnRequest(func(req *flyscrape.Request) { - times = append(times, time.Now()) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, `foo`) + }, + BuildRequestFn: func(r *flyscrape.Request) { + times = append(times, time.Now()) + }, }) start := time.Now() - scraper.Run() first := times[0].Add(-10 * time.Millisecond) diff --git a/modules/starturl/starturl.go b/modules/starturl/starturl.go index 109d28f..9e3ec31 100644 --- a/modules/starturl/starturl.go +++ b/modules/starturl/starturl.go @@ -9,15 +9,29 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { URL string `json:"url"` } -func (m *Module) OnLoad(v flyscrape.Visitor) { - v.Visit(m.URL) +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "starturl", + New: func() flyscrape.Module { return new(Module) }, + } } -var _ flyscrape.OnLoad = (*Module)(nil) +func (m *Module) Provision(ctx flyscrape.Context) { + if m.disabled() { + return + } + ctx.Visit(m.URL) +} + +func (m *Module) disabled() bool { + return m.URL == "" +} + +var _ flyscrape.Provisioner = (*Module)(nil) diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go index 6fab776..86e4ad7 100644 --- a/modules/starturl/starturl_test.go +++ b/modules/starturl/starturl_test.go @@ -5,23 +5,29 @@ package starturl_test import ( + "net/http" "testing" "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestStartURL(t *testing.T) { - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.SetTransport(flyscrape.MockTransport(200, "")) - var url string var depth int - scraper.OnRequest(func(req *flyscrape.Request) { - url = req.URL - depth = req.Depth + + scraper := flyscrape.NewScraper() + scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, "") + }, + BuildRequestFn: func(r *flyscrape.Request) { + url = r.URL + depth = r.Depth + }, }) scraper.Run() diff --git a/modules/urlfilter/urlfilter.go b/modules/urlfilter/urlfilter.go index 00a4bd2..1297c35 100644 --- a/modules/urlfilter/urlfilter.go +++ b/modules/urlfilter/urlfilter.go @@ -11,7 +11,7 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { @@ -23,7 +23,18 @@ type Module struct { blockedURLsRE []*regexp.Regexp } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "urlfilter", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.disabled() { + return + } + for _, pat := range m.AllowedURLs { re, err := regexp.Compile(pat) if err != nil { @@ -41,9 +52,13 @@ func (m *Module) OnLoad(v flyscrape.Visitor) { } } -func (m *Module) CanRequest(rawurl string, depth int) bool { +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + if m.disabled() { + return true + } + // allow root url - if rawurl == m.URL { + if r.URL == m.URL { return true } @@ -58,14 +73,14 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { } for _, re := range m.allowedURLsRE { - if re.MatchString(rawurl) { + if re.MatchString(r.URL) { ok = true break } } for _, re := range m.blockedURLsRE { - if re.MatchString(rawurl) { + if re.MatchString(r.URL) { ok = false break } @@ -74,7 +89,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { return ok } +func (m *Module) disabled() bool { + return len(m.AllowedURLs) == 0 && len(m.BlockedURLs) == 0 +} + var ( - _ flyscrape.CanRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) + _ flyscrape.RequestValidator = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) ) diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go index e383a32..9ebb8a5 100644 --- a/modules/urlfilter/urlfilter_test.go +++ b/modules/urlfilter/urlfilter_test.go @@ -5,16 +5,22 @@ package urlfilter_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/philippta/flyscrape/modules/urlfilter" "github.com/stretchr/testify/require" ) func TestURLFilterAllowed(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) scraper.LoadModule(&followlinks.Module{}) @@ -22,16 +28,19 @@ func TestURLFilterAllowed(t *testing.T) { URL: "http://www.example.com/", AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - 123 - ABC - bar - barz`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + 123 + ABC + bar + barz`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -43,6 +52,9 @@ func TestURLFilterAllowed(t *testing.T) { } func TestURLFilterBlocked(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) scraper.LoadModule(&followlinks.Module{}) @@ -50,16 +62,19 @@ func TestURLFilterBlocked(t *testing.T) { URL: "http://www.example.com/", BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - 123 - ABC - bar - barz`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + 123 + ABC + bar + barz`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() -- cgit v1.2.3