diff options
| author | Philipp Tanlak <philipp.tanlak@gmail.com> | 2023-10-05 14:53:37 +0200 |
|---|---|---|
| committer | Philipp Tanlak <philipp.tanlak@gmail.com> | 2023-10-05 14:53:37 +0200 |
| commit | 1fc497fbdc79a43c62ac2e8eaf4827752dbeef8e (patch) | |
| tree | 67738e213ef97f249bdfa0f1bddda0839192cb77 | |
| parent | bd9e7f7acfd855d4685aa4544169c0e29cdbf205 (diff) | |
Refactor codebase into modules
27 files changed, 723 insertions, 480 deletions
diff --git a/cmd/flyscrape/dev.go b/cmd/flyscrape/dev.go index 95c627e..fba3fba 100644 --- a/cmd/flyscrape/dev.go +++ b/cmd/flyscrape/dev.go @@ -17,7 +17,6 @@ type DevCommand struct{} func (c *DevCommand) Run(args []string) error { fs := flag.NewFlagSet("flyscrape-dev", flag.ContinueOnError) - proxy := fs.String("proxy", "", "proxy") fs.Usage = c.Usage if err := fs.Parse(args); err != nil { @@ -28,50 +27,25 @@ func (c *DevCommand) Run(args []string) error { return fmt.Errorf("too many arguments") } - var fetch flyscrape.FetchFunc - if *proxy != "" { - fetch = flyscrape.ProxiedFetch(*proxy) - } else { - fetch = flyscrape.Fetch() - } - - fetch = flyscrape.CachedFetch(fetch) script := fs.Arg(0) err := flyscrape.Watch(script, func(s string) error { cfg, scrape, err := flyscrape.Compile(s) if err != nil { - screen.Clear() - screen.MoveTopLeft() - - if errs, ok := err.(interface{ Unwrap() []error }); ok { - for _, err := range errs.Unwrap() { - log.Printf("%s:%v\n", script, err) - } - } else { - log.Println(err) - } - - // ignore compilation errors + printCompileErr(script, err) return nil } scraper := flyscrape.NewScraper() scraper.ScrapeFunc = scrape + flyscrape.LoadModules(scraper, cfg) + scraper.DisableModule("followlinks") + screen.Clear() + screen.MoveTopLeft() scraper.Run() - scraper.OnResponse(func(resp *flyscrape.Response) { - screen.Clear() - screen.MoveTopLeft() - if resp.Error != nil { - log.Println(resp.Error) - return - } - fmt.Println(flyscrape.PrettyPrint(resp.Data, "")) - }) - return nil }) if err != nil && err != flyscrape.StopWatch { @@ -97,3 +71,16 @@ Examples: $ flyscrape dev example.js `[1:]) } + +func printCompileErr(script string, err error) { + screen.Clear() + screen.MoveTopLeft() + + if errs, ok := err.(interface{ Unwrap() []error }); ok { + for _, err := range errs.Unwrap() { + log.Printf("%s:%v\n", script, err) + } + } else { + log.Println(err) + } +} diff --git a/cmd/flyscrape/main.go b/cmd/flyscrape/main.go index bac411e..81d0e2b 100644 --- a/cmd/flyscrape/main.go +++ b/cmd/flyscrape/main.go @@ -12,10 +12,11 @@ import ( "os" "strings" + _ "github.com/philippta/flyscrape/modules/cache" _ "github.com/philippta/flyscrape/modules/depth" _ "github.com/philippta/flyscrape/modules/domainfilter" _ "github.com/philippta/flyscrape/modules/followlinks" - _ "github.com/philippta/flyscrape/modules/jsonprinter" + _ "github.com/philippta/flyscrape/modules/jsonprint" _ "github.com/philippta/flyscrape/modules/ratelimit" _ "github.com/philippta/flyscrape/modules/starturl" _ "github.com/philippta/flyscrape/modules/urlfilter" @@ -66,7 +67,7 @@ Usage: flyscrape <command> [arguments] Commands: - + new creates a sample scraping script run runs a scraping script dev watches and re-runs a scraping script diff --git a/cmd/flyscrape/run.go b/cmd/flyscrape/run.go index 4580e6d..b467abe 100644 --- a/cmd/flyscrape/run.go +++ b/cmd/flyscrape/run.go @@ -12,6 +12,7 @@ import ( "time" "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/hook" ) type RunCommand struct{} @@ -41,14 +42,18 @@ func (c *RunCommand) Run(args []string) error { scraper := flyscrape.NewScraper() scraper.ScrapeFunc = scrape + flyscrape.LoadModules(scraper, cfg) count := 0 start := time.Now() - scraper.OnResponse(func(resp *flyscrape.Response) { - count++ + scraper.LoadModule(hook.Module{ + ReceiveResponseFn: func(r *flyscrape.Response) { + count++ + }, }) + scraper.Run() log.Printf("Scraped %d websites in %v\n", count, time.Since(start)) diff --git a/fetch.go b/fetch.go deleted file mode 100644 index d969a74..0000000 --- a/fetch.go +++ /dev/null @@ -1,89 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package flyscrape - -import ( - "crypto/tls" - "io" - "net/http" - "net/url" - - "github.com/cornelk/hashmap" -) - -const userAgent = "flyscrape/0.1" - -func ProxiedFetch(proxyURL string) FetchFunc { - pu, err := url.Parse(proxyURL) - if err != nil { - panic("invalid proxy url") - } - - client := http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(pu), - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - } - - return func(url string) (string, error) { - resp, err := client.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - html := string(body) - return html, nil - } -} - -func CachedFetch(fetch FetchFunc) FetchFunc { - cache := hashmap.New[string, string]() - - return func(url string) (string, error) { - if html, ok := cache.Get(url); ok { - return html, nil - } - - html, err := fetch(url) - if err != nil { - return "", err - } - - cache.Set(url, html) - return html, nil - } -} - -func Fetch() FetchFunc { - return func(url string) (string, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return "", err - } - - req.Header.Set("User-Agent", userAgent) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - html := string(body) - return html, nil - } -} diff --git a/fetch_test.go b/fetch_test.go deleted file mode 100644 index b32ac0f..0000000 --- a/fetch_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package flyscrape_test - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/philippta/flyscrape" - "github.com/stretchr/testify/require" -) - -func TestFetchFetch(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("foobar")) - })) - - fetch := flyscrape.Fetch() - - html, err := fetch(srv.URL) - require.NoError(t, err) - require.Equal(t, html, "foobar") -} - -func TestFetchCachedFetch(t *testing.T) { - numcalled := 0 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - numcalled++ - w.Write([]byte("foobar")) - })) - - fetch := flyscrape.CachedFetch(flyscrape.Fetch()) - - html, err := fetch(srv.URL) - require.NoError(t, err) - require.Equal(t, html, "foobar") - - html, err = fetch(srv.URL) - require.NoError(t, err) - require.Equal(t, html, "foobar") - - require.Equal(t, 1, numcalled) -} - -func TestFetchProxiedFetch(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, r.URL.String(), "http://example.com/foo") - w.Write([]byte("foobar")) - })) - - fetch := flyscrape.ProxiedFetch(srv.URL) - - html, err := fetch("http://example.com/foo") - require.NoError(t, err) - require.Equal(t, html, "foobar") -} - -func TestFetchUserAgent(t *testing.T) { - var userAgent string - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userAgent = r.Header.Get("User-Agent") - })) - - fetch := flyscrape.Fetch() - - _, err := fetch(srv.URL) - require.NoError(t, err) - require.Equal(t, "flyscrape/0.1", userAgent) -} @@ -1,6 +1,6 @@ module github.com/philippta/flyscrape -go 1.20 +go 1.21 require ( github.com/PuerkitoBio/goquery v1.8.1 @@ -18,6 +18,13 @@ import ( type Config []byte +type ScrapeParams struct { + HTML string + URL string +} + +type ScrapeFunc func(ScrapeParams) (any, error) + type TransformError struct { Line int Column int @@ -1,3 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package flyscrape import ( @@ -7,7 +11,7 @@ import ( "strings" ) -func MockTransport(statusCode int, html string) func(*http.Request) (*http.Response, error) { +func MockTransport(statusCode int, html string) RoundTripFunc { return func(*http.Request) (*http.Response, error) { return MockResponse(statusCode, html) } @@ -1,44 +1,101 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package flyscrape import ( "encoding/json" "net/http" + "sync" ) -type Module any +type Module interface { + ModuleInfo() ModuleInfo +} + +type ModuleInfo struct { + ID string + New func() Module +} -type Transport interface { - Transport(*http.Request) (*http.Response, error) +type TransportAdapter interface { + AdaptTransport(http.RoundTripper) http.RoundTripper } -type CanRequest interface { - CanRequest(url string, depth int) bool +type RequestValidator interface { + ValidateRequest(*Request) bool } -type OnRequest interface { - OnRequest(*Request) +type RequestBuilder interface { + BuildRequest(*Request) } -type OnResponse interface { - OnResponse(*Response) + +type ResponseReceiver interface { + ReceiveResponse(*Response) } -type OnLoad interface { - OnLoad(Visitor) +type Provisioner interface { + Provision(Context) } -type OnComplete interface { - OnComplete() +type Finalizer interface { + Finalize() } func RegisterModule(mod Module) { - globalModules = append(globalModules, mod) + modulesMu.Lock() + defer modulesMu.Unlock() + + id := mod.ModuleInfo().ID + if _, ok := modules[id]; ok { + panic("module with id: " + id + " already registered") + } + modules[mod.ModuleInfo().ID] = mod } func LoadModules(s *Scraper, cfg Config) { - for _, mod := range globalModules { - json.Unmarshal(cfg, mod) + modulesMu.RLock() + defer modulesMu.RUnlock() + + loaded := map[string]struct{}{} + + // load standard modules in order + for _, id := range moduleOrder { + mod := modules[id].ModuleInfo().New() + if err := json.Unmarshal(cfg, mod); err != nil { + panic("failed to decode config: " + err.Error()) + } + s.LoadModule(mod) + loaded[id] = struct{}{} + } + + // load custom modules + for id := range modules { + if _, ok := loaded[id]; ok { + continue + } + mod := modules[id].ModuleInfo().New() + if err := json.Unmarshal(cfg, mod); err != nil { + panic("failed to decode config: " + err.Error()) + } s.LoadModule(mod) + loaded[id] = struct{}{} } } -var globalModules = []Module{} +var ( + modules = map[string]Module{} + modulesMu sync.RWMutex + + moduleOrder = []string{ + "cache", + "starturl", + "followlinks", + "depth", + "domainfilter", + "urlfilter", + "ratelimit", + "jsonprint", + } +) diff --git a/modules/cache/cache.go b/modules/cache/cache.go new file mode 100644 index 0000000..1a321be --- /dev/null +++ b/modules/cache/cache.go @@ -0,0 +1,78 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cache + +import ( + "bufio" + "bytes" + "net/http" + "net/http/httputil" + + "github.com/cornelk/hashmap" + "github.com/philippta/flyscrape" +) + +func init() { + flyscrape.RegisterModule(Module{}) +} + +type Module struct { + Cache string `json:"cache"` + + cache *hashmap.Map[string, []byte] +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "cache", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(flyscrape.Context) { + if m.disabled() { + return + } + if m.cache == nil { + m.cache = hashmap.New[string, []byte]() + } +} + +func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { + if m.disabled() { + return t + } + + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + key := cacheKey(r) + + if b, ok := m.cache.Get(key); ok { + if resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(b)), r); err == nil { + return resp, nil + } + } + + resp, err := t.RoundTrip(r) + if err != nil { + return resp, err + } + + encoded, err := httputil.DumpResponse(resp, true) + if err != nil { + return resp, err + } + + m.cache.Set(key, encoded) + return resp, nil + }) +} + +func (m *Module) disabled() bool { + return m.Cache == "" +} + +func cacheKey(r *http.Request) string { + return r.Method + " " + r.URL.String() +} diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go new file mode 100644 index 0000000..4565e00 --- /dev/null +++ b/modules/cache/cache_test.go @@ -0,0 +1,38 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cache_test + +import ( + "net/http" + "testing" + + "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/cache" + "github.com/philippta/flyscrape/modules/hook" + "github.com/philippta/flyscrape/modules/starturl" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + cachemod := &cache.Module{Cache: "memory"} + calls := 0 + + for i := 0; i < 2; i++ { + scraper := flyscrape.NewScraper() + scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + calls++ + return flyscrape.MockResponse(200, "foo") + }) + }, + }) + scraper.LoadModule(cachemod) + scraper.Run() + } + + require.Equal(t, 1, calls) +} diff --git a/modules/depth/depth.go b/modules/depth/depth.go index 0cfbc71..866f5ae 100644 --- a/modules/depth/depth.go +++ b/modules/depth/depth.go @@ -9,15 +9,22 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { Depth int `json:"depth"` } -func (m *Module) CanRequest(url string, depth int) bool { - return depth <= m.Depth +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "depth", + New: func() flyscrape.Module { return new(Module) }, + } } -var _ flyscrape.CanRequest = (*Module)(nil) +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + return r.Depth <= m.Depth +} + +var _ flyscrape.RequestValidator = (*Module)(nil) diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go index c9afd6f..10b67e9 100644 --- a/modules/depth/depth_test.go +++ b/modules/depth/depth_test.go @@ -6,36 +6,44 @@ package depth_test import ( "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/depth" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestDepth(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) scraper.LoadModule(&depth.Module{Depth: 2}) - - scraper.SetTransport(func(r *http.Request) (*http.Response, error) { - switch r.URL.String() { - case "http://www.example.com": - return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`) - case "http://www.google.com": - return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`) - case "http://www.duckduckgo.com": - return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`) - } - return flyscrape.MockResponse(200, "") - }) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "http://www.example.com": + return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`) + case "http://www.google.com": + return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + case "http://www.duckduckgo.com": + return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`) + } + return flyscrape.MockResponse(200, "") + }) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/domainfilter/domainfilter.go b/modules/domainfilter/domainfilter.go index ba9ebe6..e8691d3 100644 --- a/modules/domainfilter/domainfilter.go +++ b/modules/domainfilter/domainfilter.go @@ -10,23 +10,39 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { URL string `json:"url"` AllowedDomains []string `json:"allowedDomains"` BlockedDomains []string `json:"blockedDomains"` + + active bool } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "domainfilter", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.URL == "" { + return + } if u, err := url.Parse(m.URL); err == nil { m.AllowedDomains = append(m.AllowedDomains, u.Host()) } } -func (m *Module) CanRequest(rawurl string, depth int) bool { - u, err := url.Parse(rawurl) +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + if m.disabled() { + return true + } + + u, err := url.Parse(r.URL) if err != nil { return false } @@ -51,7 +67,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { return ok } +func (m *Module) disabled() bool { + return len(m.AllowedDomains) == 0 && len(m.BlockedDomains) == 0 +} + var ( - _ flyscrape.CanRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) + _ flyscrape.RequestValidator = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) ) diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go index 884a89f..a1c8401 100644 --- a/modules/domainfilter/domainfilter_test.go +++ b/modules/domainfilter/domainfilter_test.go @@ -5,16 +5,22 @@ package domainfilter_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/domainfilter" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestDomainfilterAllowed(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -22,14 +28,17 @@ func TestDomainfilterAllowed(t *testing.T) { URL: "http://www.example.com", AllowedDomains: []string{"www.google.com"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="http://www.google.com">Google</a> - <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="http://www.google.com">Google</a> + <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -40,6 +49,9 @@ func TestDomainfilterAllowed(t *testing.T) { } func TestDomainfilterAllowedAll(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -47,14 +59,17 @@ func TestDomainfilterAllowedAll(t *testing.T) { URL: "http://www.example.com", AllowedDomains: []string{"*"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="http://www.google.com">Google</a> - <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="http://www.google.com">Google</a> + <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -66,6 +81,9 @@ func TestDomainfilterAllowedAll(t *testing.T) { } func TestDomainfilterBlocked(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) @@ -74,14 +92,17 @@ func TestDomainfilterBlocked(t *testing.T) { AllowedDomains: []string{"*"}, BlockedDomains: []string{"www.google.com"}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="http://www.google.com">Google</a> - <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="http://www.google.com">Google</a> + <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/followlinks/followlinks.go b/modules/followlinks/followlinks.go index 99d6cee..c53f167 100644 --- a/modules/followlinks/followlinks.go +++ b/modules/followlinks/followlinks.go @@ -5,19 +5,71 @@ package followlinks import ( + "net/url" + "strings" + + "github.com/PuerkitoBio/goquery" "github.com/philippta/flyscrape" ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct{} -func (m *Module) OnResponse(resp *flyscrape.Response) { - for _, link := range flyscrape.ParseLinks(string(resp.Body), resp.Request.URL) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "followlinks", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) ReceiveResponse(resp *flyscrape.Response) { + for _, link := range parseLinks(string(resp.Body), resp.Request.URL) { resp.Visit(link) } } -var _ flyscrape.OnResponse = (*Module)(nil) +func parseLinks(html string, origin string) []string { + var links []string + doc, err := goquery.NewDocumentFromReader(strings.NewReader(html)) + if err != nil { + return nil + } + + originurl, err := url.Parse(origin) + if err != nil { + return nil + } + + uniqueLinks := make(map[string]bool) + doc.Find("a").Each(func(i int, s *goquery.Selection) { + link, _ := s.Attr("href") + + parsedLink, err := originurl.Parse(link) + + if err != nil || !isValidLink(parsedLink) { + return + } + + absLink := parsedLink.String() + + if !uniqueLinks[absLink] { + links = append(links, absLink) + uniqueLinks[absLink] = true + } + }) + + return links +} + +func isValidLink(link *url.URL) bool { + if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" { + return false + } + + return true +} + +var _ flyscrape.ResponseReceiver = (*Module)(nil) diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go index 18c8ceb..0a628c3 100644 --- a/modules/followlinks/followlinks_test.go +++ b/modules/followlinks/followlinks_test.go @@ -5,27 +5,37 @@ package followlinks_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestFollowLinks(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) scraper.LoadModule(&followlinks.Module{}) - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="/baz">Baz</a> - <a href="baz">Baz</a> - <a href="http://www.google.com">Google</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="/baz">Baz</a> + <a href="baz">Baz</a> + <a href="http://www.google.com">Google</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() diff --git a/modules/hook/hook.go b/modules/hook/hook.go new file mode 100644 index 0000000..4484f47 --- /dev/null +++ b/modules/hook/hook.go @@ -0,0 +1,78 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hook + +import ( + "net/http" + + "github.com/philippta/flyscrape" +) + +type Module struct { + AdaptTransportFn func(http.RoundTripper) http.RoundTripper + ValidateRequestFn func(*flyscrape.Request) bool + BuildRequestFn func(*flyscrape.Request) + ReceiveResponseFn func(*flyscrape.Response) + ProvisionFn func(flyscrape.Context) + FinalizeFn func() +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "hook", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { + if m.AdaptTransportFn == nil { + return t + } + return m.AdaptTransportFn(t) +} + +func (m Module) ValidateRequest(r *flyscrape.Request) bool { + if m.ValidateRequestFn == nil { + return true + } + return m.ValidateRequestFn(r) +} + +func (m Module) BuildRequest(r *flyscrape.Request) { + if m.BuildRequestFn == nil { + return + } + m.BuildRequestFn(r) +} + +func (m Module) ReceiveResponse(r *flyscrape.Response) { + if m.ReceiveResponseFn == nil { + return + } + m.ReceiveResponseFn(r) +} + +func (m Module) Provision(ctx flyscrape.Context) { + if m.ProvisionFn == nil { + return + } + m.ProvisionFn(ctx) +} + +func (m Module) Finalize() { + if m.FinalizeFn == nil { + return + } + m.FinalizeFn() +} + +var ( + _ flyscrape.TransportAdapter = Module{} + _ flyscrape.RequestValidator = Module{} + _ flyscrape.RequestBuilder = Module{} + _ flyscrape.ResponseReceiver = Module{} + _ flyscrape.Provisioner = Module{} + _ flyscrape.Finalizer = Module{} +) diff --git a/modules/jsonprinter/jsonprinter.go b/modules/jsonprint/jsonprint.go index 3026f29..29d3375 100644 --- a/modules/jsonprinter/jsonprinter.go +++ b/modules/jsonprint/jsonprint.go @@ -2,7 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package jsonprinter +package jsonprint import ( "fmt" @@ -12,20 +12,28 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { - first bool + once bool } -func (m *Module) OnResponse(resp *flyscrape.Response) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "jsonprint", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) ReceiveResponse(resp *flyscrape.Response) { if resp.Error == nil && resp.Data == nil { return } - if m.first { + if !m.once { fmt.Println("[") + m.once = true } else { fmt.Println(",") } @@ -37,10 +45,10 @@ func (m *Module) OnResponse(resp *flyscrape.Response) { Timestamp: time.Now(), } - fmt.Print(flyscrape.PrettyPrint(o, " ")) + fmt.Print(flyscrape.Prettify(o, " ")) } -func (m *Module) OnComplete() { +func (m *Module) Finalize() { fmt.Println("\n]") } @@ -52,6 +60,6 @@ type output struct { } var ( - _ flyscrape.OnResponse = (*Module)(nil) - _ flyscrape.OnComplete = (*Module)(nil) + _ flyscrape.ResponseReceiver = (*Module)(nil) + _ flyscrape.Finalizer = (*Module)(nil) ) diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go index be622f6..9588db3 100644 --- a/modules/ratelimit/ratelimit.go +++ b/modules/ratelimit/ratelimit.go @@ -11,7 +11,7 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { @@ -21,7 +21,18 @@ type Module struct { semaphore chan struct{} } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "ratelimit", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.disabled() { + return + } + rate := time.Duration(float64(time.Second) / m.Rate) m.ticker = time.NewTicker(rate) @@ -34,16 +45,26 @@ func (m *Module) OnLoad(v flyscrape.Visitor) { }() } -func (m *Module) OnRequest(_ *flyscrape.Request) { +func (m *Module) BuildRequest(_ *flyscrape.Request) { + if m.disabled() { + return + } <-m.semaphore } -func (m *Module) OnComplete() { +func (m *Module) Finalize() { + if m.disabled() { + return + } m.ticker.Stop() } +func (m *Module) disabled() bool { + return m.Rate == 0 +} + var ( - _ flyscrape.OnRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) - _ flyscrape.OnComplete = (*Module)(nil) + _ flyscrape.RequestBuilder = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) + _ flyscrape.Finalizer = (*Module)(nil) ) diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go index 5e91f8f..ffd061c 100644 --- a/modules/ratelimit/ratelimit_test.go +++ b/modules/ratelimit/ratelimit_test.go @@ -5,33 +5,37 @@ package ratelimit_test import ( + "net/http" "testing" "time" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/ratelimit" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestRatelimit(t *testing.T) { + var times []time.Time + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) scraper.LoadModule(&followlinks.Module{}) scraper.LoadModule(&ratelimit.Module{ Rate: 100, }) - - scraper.SetTransport(flyscrape.MockTransport(200, `<a href="foo">foo</a>`)) - - var times []time.Time - scraper.OnRequest(func(req *flyscrape.Request) { - times = append(times, time.Now()) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, `<a href="foo">foo</a>`) + }, + BuildRequestFn: func(r *flyscrape.Request) { + times = append(times, time.Now()) + }, }) start := time.Now() - scraper.Run() first := times[0].Add(-10 * time.Millisecond) diff --git a/modules/starturl/starturl.go b/modules/starturl/starturl.go index 109d28f..9e3ec31 100644 --- a/modules/starturl/starturl.go +++ b/modules/starturl/starturl.go @@ -9,15 +9,29 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { URL string `json:"url"` } -func (m *Module) OnLoad(v flyscrape.Visitor) { - v.Visit(m.URL) +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "starturl", + New: func() flyscrape.Module { return new(Module) }, + } } -var _ flyscrape.OnLoad = (*Module)(nil) +func (m *Module) Provision(ctx flyscrape.Context) { + if m.disabled() { + return + } + ctx.Visit(m.URL) +} + +func (m *Module) disabled() bool { + return m.URL == "" +} + +var _ flyscrape.Provisioner = (*Module)(nil) diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go index 6fab776..86e4ad7 100644 --- a/modules/starturl/starturl_test.go +++ b/modules/starturl/starturl_test.go @@ -5,23 +5,29 @@ package starturl_test import ( + "net/http" "testing" "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/stretchr/testify/require" ) func TestStartURL(t *testing.T) { - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.SetTransport(flyscrape.MockTransport(200, "")) - var url string var depth int - scraper.OnRequest(func(req *flyscrape.Request) { - url = req.URL - depth = req.Depth + + scraper := flyscrape.NewScraper() + scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, "") + }, + BuildRequestFn: func(r *flyscrape.Request) { + url = r.URL + depth = r.Depth + }, }) scraper.Run() diff --git a/modules/urlfilter/urlfilter.go b/modules/urlfilter/urlfilter.go index 00a4bd2..1297c35 100644 --- a/modules/urlfilter/urlfilter.go +++ b/modules/urlfilter/urlfilter.go @@ -11,7 +11,7 @@ import ( ) func init() { - flyscrape.RegisterModule(new(Module)) + flyscrape.RegisterModule(Module{}) } type Module struct { @@ -23,7 +23,18 @@ type Module struct { blockedURLsRE []*regexp.Regexp } -func (m *Module) OnLoad(v flyscrape.Visitor) { +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "urlfilter", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(v flyscrape.Context) { + if m.disabled() { + return + } + for _, pat := range m.AllowedURLs { re, err := regexp.Compile(pat) if err != nil { @@ -41,9 +52,13 @@ func (m *Module) OnLoad(v flyscrape.Visitor) { } } -func (m *Module) CanRequest(rawurl string, depth int) bool { +func (m *Module) ValidateRequest(r *flyscrape.Request) bool { + if m.disabled() { + return true + } + // allow root url - if rawurl == m.URL { + if r.URL == m.URL { return true } @@ -58,14 +73,14 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { } for _, re := range m.allowedURLsRE { - if re.MatchString(rawurl) { + if re.MatchString(r.URL) { ok = true break } } for _, re := range m.blockedURLsRE { - if re.MatchString(rawurl) { + if re.MatchString(r.URL) { ok = false break } @@ -74,7 +89,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool { return ok } +func (m *Module) disabled() bool { + return len(m.AllowedURLs) == 0 && len(m.BlockedURLs) == 0 +} + var ( - _ flyscrape.CanRequest = (*Module)(nil) - _ flyscrape.OnLoad = (*Module)(nil) + _ flyscrape.RequestValidator = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) ) diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go index e383a32..9ebb8a5 100644 --- a/modules/urlfilter/urlfilter_test.go +++ b/modules/urlfilter/urlfilter_test.go @@ -5,16 +5,22 @@ package urlfilter_test import ( + "net/http" + "sync" "testing" "github.com/philippta/flyscrape" "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" "github.com/philippta/flyscrape/modules/starturl" "github.com/philippta/flyscrape/modules/urlfilter" "github.com/stretchr/testify/require" ) func TestURLFilterAllowed(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) scraper.LoadModule(&followlinks.Module{}) @@ -22,16 +28,19 @@ func TestURLFilterAllowed(t *testing.T) { URL: "http://www.example.com/", AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="foo?id=123">123</a> - <a href="foo?id=ABC">ABC</a> - <a href="/bar">bar</a> - <a href="/barz">barz</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="foo?id=123">123</a> + <a href="foo?id=ABC">ABC</a> + <a href="/bar">bar</a> + <a href="/barz">barz</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -43,6 +52,9 @@ func TestURLFilterAllowed(t *testing.T) { } func TestURLFilterBlocked(t *testing.T) { + var urls []string + var mu sync.Mutex + scraper := flyscrape.NewScraper() scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) scraper.LoadModule(&followlinks.Module{}) @@ -50,16 +62,19 @@ func TestURLFilterBlocked(t *testing.T) { URL: "http://www.example.com/", BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`}, }) - - scraper.SetTransport(flyscrape.MockTransport(200, ` - <a href="foo?id=123">123</a> - <a href="foo?id=ABC">ABC</a> - <a href="/bar">bar</a> - <a href="/barz">barz</a>`)) - - var urls []string - scraper.OnRequest(func(req *flyscrape.Request) { - urls = append(urls, req.URL) + scraper.LoadModule(hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` + <a href="foo?id=123">123</a> + <a href="foo?id=ABC">ABC</a> + <a href="/bar">bar</a> + <a href="/barz">barz</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }) scraper.Run() @@ -9,28 +9,19 @@ import ( "log" "net/http" "net/http/cookiejar" - "net/url" - "strings" + "slices" "sync" - gourl "net/url" - - "github.com/PuerkitoBio/goquery" "github.com/cornelk/hashmap" ) -type ScrapeParams struct { - HTML string - URL string -} - -type ScrapeFunc func(ScrapeParams) (any, error) - type FetchFunc func(url string) (string, error) -type Visitor interface { +type Context interface { Visit(url string) MarkVisited(url string) + MarkUnvisited(url string) + DisableModule(id string) } type Request struct { @@ -57,62 +48,39 @@ type target struct { depth int } -type Scraper struct { - ScrapeFunc ScrapeFunc - - cfg Config - wg sync.WaitGroup - jobs chan target - visited *hashmap.Map[string, struct{}] - modules *hashmap.Map[string, Module] - cookieJar *cookiejar.Jar - - canRequestHandlers []func(url string, depth int) bool - onRequestHandlers []func(*Request) - onResponseHandlers []func(*Response) - onCompleteHandlers []func() - transport func(*http.Request) (*http.Response, error) -} - func NewScraper() *Scraper { - jar, _ := cookiejar.New(nil) - s := &Scraper{ + return &Scraper{ jobs: make(chan target, 1024), visited: hashmap.New[string, struct{}](), - modules: hashmap.New[string, Module](), - transport: func(r *http.Request) (*http.Response, error) { - r.Header.Set("User-Agent", "flyscrape/0.1") - return http.DefaultClient.Do(r) - }, - cookieJar: jar, } - return s } -func (s *Scraper) LoadModule(mod Module) { - if v, ok := mod.(Transport); ok { - s.SetTransport(v.Transport) - } +type Scraper struct { + ScrapeFunc ScrapeFunc - if v, ok := mod.(CanRequest); ok { - s.CanRequest(v.CanRequest) - } + wg sync.WaitGroup + jobs chan target + visited *hashmap.Map[string, struct{}] - if v, ok := mod.(OnRequest); ok { - s.OnRequest(v.OnRequest) - } + modules []Module + moduleIDs []string + client *http.Client +} - if v, ok := mod.(OnResponse); ok { - s.OnResponse(v.OnResponse) - } +func (s *Scraper) LoadModule(mod Module) { + id := mod.ModuleInfo().ID - if v, ok := mod.(OnLoad); ok { - v.OnLoad(s) - } + s.modules = append(s.modules, mod) + s.moduleIDs = append(s.moduleIDs, id) +} - if v, ok := mod.(OnComplete); ok { - s.OnComplete(v.OnComplete) +func (s *Scraper) DisableModule(id string) { + idx := slices.Index(s.moduleIDs, id) + if idx == -1 { + return } + s.modules = slices.Delete(s.modules, idx, idx+1) + s.moduleIDs = slices.Delete(s.moduleIDs, idx, idx+1) } func (s *Scraper) Visit(url string) { @@ -123,49 +91,47 @@ func (s *Scraper) MarkVisited(url string) { s.visited.Insert(url, struct{}{}) } -func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) { - s.transport = f -} - -func (s *Scraper) CanRequest(f func(url string, depth int) bool) { - s.canRequestHandlers = append(s.canRequestHandlers, f) -} - -func (s *Scraper) OnRequest(f func(req *Request)) { - s.onRequestHandlers = append(s.onRequestHandlers, f) -} - -func (s *Scraper) OnResponse(f func(resp *Response)) { - s.onResponseHandlers = append(s.onResponseHandlers, f) -} - -func (s *Scraper) OnComplete(f func()) { - s.onCompleteHandlers = append(s.onCompleteHandlers, f) +func (s *Scraper) MarkUnvisited(url string) { + s.visited.Del(url) } func (s *Scraper) Run() { - go s.worker() + for _, mod := range s.modules { + if v, ok := mod.(Provisioner); ok { + v.Provision(s) + } + } + + s.initClient() + go s.scrape() s.wg.Wait() close(s.jobs) - for _, handler := range s.onCompleteHandlers { - handler() + for _, mod := range s.modules { + if v, ok := mod.(Finalizer); ok { + v.Finalize() + } } } -func (s *Scraper) worker() { - for job := range s.jobs { - go func(job target) { - defer s.wg.Done() +func (s *Scraper) initClient() { + jar, _ := cookiejar.New(nil) + s.client = &http.Client{Jar: jar} - for _, handler := range s.canRequestHandlers { - if !handler(job.url, job.depth) { - return - } - } + for _, mod := range s.modules { + if v, ok := mod.(TransportAdapter); ok { + s.client.Transport = v.AdaptTransport(s.client.Transport) + } + } +} +func (s *Scraper) scrape() { + for job := range s.jobs { + job := job + go func() { s.process(job.url, job.depth) - }(job) + s.wg.Done() + }() } } @@ -173,8 +139,9 @@ func (s *Scraper) process(url string, depth int) { request := &Request{ Method: http.MethodGet, URL: url, - Headers: http.Header{}, - Cookies: s.cookieJar, + Headers: defaultHeaders(), + Cookies: s.client.Jar, + Depth: depth, } response := &Response{ @@ -184,11 +151,11 @@ func (s *Scraper) process(url string, depth int) { }, } - defer func() { - for _, handler := range s.onResponseHandlers { - handler(response) + for _, mod := range s.modules { + if v, ok := mod.(RequestBuilder); ok { + v.BuildRequest(request) } - }() + } req, err := http.NewRequest(request.Method, request.URL, nil) if err != nil { @@ -197,11 +164,23 @@ func (s *Scraper) process(url string, depth int) { } req.Header = request.Headers - for _, handler := range s.onRequestHandlers { - handler(request) + for _, mod := range s.modules { + if v, ok := mod.(RequestValidator); ok { + if !v.ValidateRequest(request) { + return + } + } } - resp, err := s.transport(req) + defer func() { + for _, mod := range s.modules { + if v, ok := mod.(ResponseReceiver); ok { + v.ReceiveResponse(response) + } + } + }() + + resp, err := s.client.Do(req) if err != nil { response.Error = err return @@ -241,43 +220,9 @@ func (s *Scraper) enqueueJob(url string, depth int) { } } -func ParseLinks(html string, origin string) []string { - var links []string - doc, err := goquery.NewDocumentFromReader(strings.NewReader(html)) - if err != nil { - return nil - } - - originurl, err := url.Parse(origin) - if err != nil { - return nil - } - - uniqueLinks := make(map[string]bool) - doc.Find("a").Each(func(i int, s *goquery.Selection) { - link, _ := s.Attr("href") - - parsedLink, err := originurl.Parse(link) - - if err != nil || !isValidLink(parsedLink) { - return - } - - absLink := parsedLink.String() - - if !uniqueLinks[absLink] { - links = append(links, absLink) - uniqueLinks[absLink] = true - } - }) - - return links -} - -func isValidLink(link *gourl.URL) bool { - if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" { - return false - } +func defaultHeaders() http.Header { + h := http.Header{} + h.Set("User-Agent", "flyscrape/0.1") - return true + return h } @@ -7,10 +7,11 @@ package flyscrape import ( "bytes" "encoding/json" + "net/http" "strings" ) -func PrettyPrint(v any, prefix string) string { +func Prettify(v any, prefix string) string { var buf bytes.Buffer enc := json.NewEncoder(&buf) enc.SetEscapeHTML(false) @@ -19,14 +20,8 @@ func PrettyPrint(v any, prefix string) string { return prefix + strings.TrimSuffix(buf.String(), "\n") } -func Print(v any, prefix string) string { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(false) - enc.Encode(v) - return prefix + strings.TrimSuffix(buf.String(), "\n") -} +type RoundTripFunc func(*http.Request) (*http.Response, error) -func ParseConfig(cfg Config, v any) { - json.Unmarshal(cfg, v) +func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) } |