summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Philipp Tanlak <philipp.tanlak@gmail.com> 2023-10-05 14:53:37 +0200
committer: Philipp Tanlak <philipp.tanlak@gmail.com> 2023-10-05 14:53:37 +0200
commit: 1fc497fbdc79a43c62ac2e8eaf4827752dbeef8e (patch)
tree: 67738e213ef97f249bdfa0f1bddda0839192cb77
parent: bd9e7f7acfd855d4685aa4544169c0e29cdbf205 (diff)
Refactor codebase into modules
-rw-r--r--cmd/flyscrape/dev.go49
-rw-r--r--cmd/flyscrape/main.go5
-rw-r--r--cmd/flyscrape/run.go9
-rw-r--r--fetch.go89
-rw-r--r--fetch_test.go68
-rw-r--r--go.mod2
-rw-r--r--js.go7
-rw-r--r--mock.go6
-rw-r--r--module.go91
-rw-r--r--modules/cache/cache.go78
-rw-r--r--modules/cache/cache_test.go38
-rw-r--r--modules/depth/depth.go15
-rw-r--r--modules/depth/depth_test.go40
-rw-r--r--modules/domainfilter/domainfilter.go32
-rw-r--r--modules/domainfilter/domainfilter_test.go69
-rw-r--r--modules/followlinks/followlinks.go60
-rw-r--r--modules/followlinks/followlinks_test.go26
-rw-r--r--modules/hook/hook.go78
-rw-r--r--modules/jsonprint/jsonprint.go (renamed from modules/jsonprinter/jsonprinter.go)26
-rw-r--r--modules/ratelimit/ratelimit.go35
-rw-r--r--modules/ratelimit/ratelimit_test.go18
-rw-r--r--modules/starturl/starturl.go22
-rw-r--r--modules/starturl/starturl_test.go20
-rw-r--r--modules/urlfilter/urlfilter.go35
-rw-r--r--modules/urlfilter/urlfilter_test.go55
-rw-r--r--scrape.go215
-rw-r--r--utils.go15
27 files changed, 723 insertions(+), 480 deletions(-)
diff --git a/cmd/flyscrape/dev.go b/cmd/flyscrape/dev.go
index 95c627e..fba3fba 100644
--- a/cmd/flyscrape/dev.go
+++ b/cmd/flyscrape/dev.go
@@ -17,7 +17,6 @@ type DevCommand struct{}
func (c *DevCommand) Run(args []string) error {
fs := flag.NewFlagSet("flyscrape-dev", flag.ContinueOnError)
- proxy := fs.String("proxy", "", "proxy")
fs.Usage = c.Usage
if err := fs.Parse(args); err != nil {
@@ -28,50 +27,25 @@ func (c *DevCommand) Run(args []string) error {
return fmt.Errorf("too many arguments")
}
- var fetch flyscrape.FetchFunc
- if *proxy != "" {
- fetch = flyscrape.ProxiedFetch(*proxy)
- } else {
- fetch = flyscrape.Fetch()
- }
-
- fetch = flyscrape.CachedFetch(fetch)
script := fs.Arg(0)
err := flyscrape.Watch(script, func(s string) error {
cfg, scrape, err := flyscrape.Compile(s)
if err != nil {
- screen.Clear()
- screen.MoveTopLeft()
-
- if errs, ok := err.(interface{ Unwrap() []error }); ok {
- for _, err := range errs.Unwrap() {
- log.Printf("%s:%v\n", script, err)
- }
- } else {
- log.Println(err)
- }
-
- // ignore compilation errors
+ printCompileErr(script, err)
return nil
}
scraper := flyscrape.NewScraper()
scraper.ScrapeFunc = scrape
+
flyscrape.LoadModules(scraper, cfg)
+ scraper.DisableModule("followlinks")
+ screen.Clear()
+ screen.MoveTopLeft()
scraper.Run()
- scraper.OnResponse(func(resp *flyscrape.Response) {
- screen.Clear()
- screen.MoveTopLeft()
- if resp.Error != nil {
- log.Println(resp.Error)
- return
- }
- fmt.Println(flyscrape.PrettyPrint(resp.Data, ""))
- })
-
return nil
})
if err != nil && err != flyscrape.StopWatch {
@@ -97,3 +71,16 @@ Examples:
$ flyscrape dev example.js
`[1:])
}
+
+func printCompileErr(script string, err error) {
+ screen.Clear()
+ screen.MoveTopLeft()
+
+ if errs, ok := err.(interface{ Unwrap() []error }); ok {
+ for _, err := range errs.Unwrap() {
+ log.Printf("%s:%v\n", script, err)
+ }
+ } else {
+ log.Println(err)
+ }
+}
diff --git a/cmd/flyscrape/main.go b/cmd/flyscrape/main.go
index bac411e..81d0e2b 100644
--- a/cmd/flyscrape/main.go
+++ b/cmd/flyscrape/main.go
@@ -12,10 +12,11 @@ import (
"os"
"strings"
+ _ "github.com/philippta/flyscrape/modules/cache"
_ "github.com/philippta/flyscrape/modules/depth"
_ "github.com/philippta/flyscrape/modules/domainfilter"
_ "github.com/philippta/flyscrape/modules/followlinks"
- _ "github.com/philippta/flyscrape/modules/jsonprinter"
+ _ "github.com/philippta/flyscrape/modules/jsonprint"
_ "github.com/philippta/flyscrape/modules/ratelimit"
_ "github.com/philippta/flyscrape/modules/starturl"
_ "github.com/philippta/flyscrape/modules/urlfilter"
@@ -66,7 +67,7 @@ Usage:
flyscrape <command> [arguments]
Commands:
-
+
new creates a sample scraping script
run runs a scraping script
dev watches and re-runs a scraping script
diff --git a/cmd/flyscrape/run.go b/cmd/flyscrape/run.go
index 4580e6d..b467abe 100644
--- a/cmd/flyscrape/run.go
+++ b/cmd/flyscrape/run.go
@@ -12,6 +12,7 @@ import (
"time"
"github.com/philippta/flyscrape"
+ "github.com/philippta/flyscrape/modules/hook"
)
type RunCommand struct{}
@@ -41,14 +42,18 @@ func (c *RunCommand) Run(args []string) error {
scraper := flyscrape.NewScraper()
scraper.ScrapeFunc = scrape
+
flyscrape.LoadModules(scraper, cfg)
count := 0
start := time.Now()
- scraper.OnResponse(func(resp *flyscrape.Response) {
- count++
+ scraper.LoadModule(hook.Module{
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ count++
+ },
})
+
scraper.Run()
log.Printf("Scraped %d websites in %v\n", count, time.Since(start))
diff --git a/fetch.go b/fetch.go
deleted file mode 100644
index d969a74..0000000
--- a/fetch.go
+++ /dev/null
@@ -1,89 +0,0 @@
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package flyscrape
-
-import (
- "crypto/tls"
- "io"
- "net/http"
- "net/url"
-
- "github.com/cornelk/hashmap"
-)
-
-const userAgent = "flyscrape/0.1"
-
-func ProxiedFetch(proxyURL string) FetchFunc {
- pu, err := url.Parse(proxyURL)
- if err != nil {
- panic("invalid proxy url")
- }
-
- client := http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyURL(pu),
- TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
- },
- }
-
- return func(url string) (string, error) {
- resp, err := client.Get(url)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", err
- }
-
- html := string(body)
- return html, nil
- }
-}
-
-func CachedFetch(fetch FetchFunc) FetchFunc {
- cache := hashmap.New[string, string]()
-
- return func(url string) (string, error) {
- if html, ok := cache.Get(url); ok {
- return html, nil
- }
-
- html, err := fetch(url)
- if err != nil {
- return "", err
- }
-
- cache.Set(url, html)
- return html, nil
- }
-}
-
-func Fetch() FetchFunc {
- return func(url string) (string, error) {
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return "", err
- }
-
- req.Header.Set("User-Agent", userAgent)
-
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", err
- }
-
- html := string(body)
- return html, nil
- }
-}
diff --git a/fetch_test.go b/fetch_test.go
deleted file mode 100644
index b32ac0f..0000000
--- a/fetch_test.go
+++ /dev/null
@@ -1,68 +0,0 @@
-package flyscrape_test
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/philippta/flyscrape"
- "github.com/stretchr/testify/require"
-)
-
-func TestFetchFetch(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("foobar"))
- }))
-
- fetch := flyscrape.Fetch()
-
- html, err := fetch(srv.URL)
- require.NoError(t, err)
- require.Equal(t, html, "foobar")
-}
-
-func TestFetchCachedFetch(t *testing.T) {
- numcalled := 0
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- numcalled++
- w.Write([]byte("foobar"))
- }))
-
- fetch := flyscrape.CachedFetch(flyscrape.Fetch())
-
- html, err := fetch(srv.URL)
- require.NoError(t, err)
- require.Equal(t, html, "foobar")
-
- html, err = fetch(srv.URL)
- require.NoError(t, err)
- require.Equal(t, html, "foobar")
-
- require.Equal(t, 1, numcalled)
-}
-
-func TestFetchProxiedFetch(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, r.URL.String(), "http://example.com/foo")
- w.Write([]byte("foobar"))
- }))
-
- fetch := flyscrape.ProxiedFetch(srv.URL)
-
- html, err := fetch("http://example.com/foo")
- require.NoError(t, err)
- require.Equal(t, html, "foobar")
-}
-
-func TestFetchUserAgent(t *testing.T) {
- var userAgent string
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- userAgent = r.Header.Get("User-Agent")
- }))
-
- fetch := flyscrape.Fetch()
-
- _, err := fetch(srv.URL)
- require.NoError(t, err)
- require.Equal(t, "flyscrape/0.1", userAgent)
-}
diff --git a/go.mod b/go.mod
index cf7a7a5..4341386 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/philippta/flyscrape
-go 1.20
+go 1.21
require (
github.com/PuerkitoBio/goquery v1.8.1
diff --git a/js.go b/js.go
index 5343754..7e1f7c0 100644
--- a/js.go
+++ b/js.go
@@ -18,6 +18,13 @@ import (
type Config []byte
+type ScrapeParams struct {
+ HTML string
+ URL string
+}
+
+type ScrapeFunc func(ScrapeParams) (any, error)
+
type TransformError struct {
Line int
Column int
diff --git a/mock.go b/mock.go
index 44b8837..147ca82 100644
--- a/mock.go
+++ b/mock.go
@@ -1,3 +1,7 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
package flyscrape
import (
@@ -7,7 +11,7 @@ import (
"strings"
)
-func MockTransport(statusCode int, html string) func(*http.Request) (*http.Response, error) {
+func MockTransport(statusCode int, html string) RoundTripFunc {
return func(*http.Request) (*http.Response, error) {
return MockResponse(statusCode, html)
}
diff --git a/module.go b/module.go
index 1839b76..0465808 100644
--- a/module.go
+++ b/module.go
@@ -1,44 +1,101 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
package flyscrape
import (
"encoding/json"
"net/http"
+ "sync"
)
-type Module any
+type Module interface {
+ ModuleInfo() ModuleInfo
+}
+
+type ModuleInfo struct {
+ ID string
+ New func() Module
+}
-type Transport interface {
- Transport(*http.Request) (*http.Response, error)
+type TransportAdapter interface {
+ AdaptTransport(http.RoundTripper) http.RoundTripper
}
-type CanRequest interface {
- CanRequest(url string, depth int) bool
+type RequestValidator interface {
+ ValidateRequest(*Request) bool
}
-type OnRequest interface {
- OnRequest(*Request)
+type RequestBuilder interface {
+ BuildRequest(*Request)
}
-type OnResponse interface {
- OnResponse(*Response)
+
+type ResponseReceiver interface {
+ ReceiveResponse(*Response)
}
-type OnLoad interface {
- OnLoad(Visitor)
+type Provisioner interface {
+ Provision(Context)
}
-type OnComplete interface {
- OnComplete()
+type Finalizer interface {
+ Finalize()
}
func RegisterModule(mod Module) {
- globalModules = append(globalModules, mod)
+ modulesMu.Lock()
+ defer modulesMu.Unlock()
+
+ id := mod.ModuleInfo().ID
+ if _, ok := modules[id]; ok {
+ panic("module with id: " + id + " already registered")
+ }
+ modules[mod.ModuleInfo().ID] = mod
}
func LoadModules(s *Scraper, cfg Config) {
- for _, mod := range globalModules {
- json.Unmarshal(cfg, mod)
+ modulesMu.RLock()
+ defer modulesMu.RUnlock()
+
+ loaded := map[string]struct{}{}
+
+ // load standard modules in order
+ for _, id := range moduleOrder {
+ mod := modules[id].ModuleInfo().New()
+ if err := json.Unmarshal(cfg, mod); err != nil {
+ panic("failed to decode config: " + err.Error())
+ }
+ s.LoadModule(mod)
+ loaded[id] = struct{}{}
+ }
+
+ // load custom modules
+ for id := range modules {
+ if _, ok := loaded[id]; ok {
+ continue
+ }
+ mod := modules[id].ModuleInfo().New()
+ if err := json.Unmarshal(cfg, mod); err != nil {
+ panic("failed to decode config: " + err.Error())
+ }
s.LoadModule(mod)
+ loaded[id] = struct{}{}
}
}
-var globalModules = []Module{}
+var (
+ modules = map[string]Module{}
+ modulesMu sync.RWMutex
+
+ moduleOrder = []string{
+ "cache",
+ "starturl",
+ "followlinks",
+ "depth",
+ "domainfilter",
+ "urlfilter",
+ "ratelimit",
+ "jsonprint",
+ }
+)
diff --git a/modules/cache/cache.go b/modules/cache/cache.go
new file mode 100644
index 0000000..1a321be
--- /dev/null
+++ b/modules/cache/cache.go
@@ -0,0 +1,78 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cache
+
+import (
+ "bufio"
+ "bytes"
+ "net/http"
+ "net/http/httputil"
+
+ "github.com/cornelk/hashmap"
+ "github.com/philippta/flyscrape"
+)
+
+func init() {
+ flyscrape.RegisterModule(Module{})
+}
+
+type Module struct {
+ Cache string `json:"cache"`
+
+ cache *hashmap.Map[string, []byte]
+}
+
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "cache",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(flyscrape.Context) {
+ if m.disabled() {
+ return
+ }
+ if m.cache == nil {
+ m.cache = hashmap.New[string, []byte]()
+ }
+}
+
+func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
+ if m.disabled() {
+ return t
+ }
+
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ key := cacheKey(r)
+
+ if b, ok := m.cache.Get(key); ok {
+ if resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(b)), r); err == nil {
+ return resp, nil
+ }
+ }
+
+ resp, err := t.RoundTrip(r)
+ if err != nil {
+ return resp, err
+ }
+
+ encoded, err := httputil.DumpResponse(resp, true)
+ if err != nil {
+ return resp, err
+ }
+
+ m.cache.Set(key, encoded)
+ return resp, nil
+ })
+}
+
+func (m *Module) disabled() bool {
+ return m.Cache == ""
+}
+
+func cacheKey(r *http.Request) string {
+ return r.Method + " " + r.URL.String()
+}
diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go
new file mode 100644
index 0000000..4565e00
--- /dev/null
+++ b/modules/cache/cache_test.go
@@ -0,0 +1,38 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cache_test
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/philippta/flyscrape"
+ "github.com/philippta/flyscrape/modules/cache"
+ "github.com/philippta/flyscrape/modules/hook"
+ "github.com/philippta/flyscrape/modules/starturl"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCache(t *testing.T) {
+ cachemod := &cache.Module{Cache: "memory"}
+ calls := 0
+
+ for i := 0; i < 2; i++ {
+ scraper := flyscrape.NewScraper()
+ scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ calls++
+ return flyscrape.MockResponse(200, "foo")
+ })
+ },
+ })
+ scraper.LoadModule(cachemod)
+ scraper.Run()
+ }
+
+ require.Equal(t, 1, calls)
+}
diff --git a/modules/depth/depth.go b/modules/depth/depth.go
index 0cfbc71..866f5ae 100644
--- a/modules/depth/depth.go
+++ b/modules/depth/depth.go
@@ -9,15 +9,22 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
Depth int `json:"depth"`
}
-func (m *Module) CanRequest(url string, depth int) bool {
- return depth <= m.Depth
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "depth",
+ New: func() flyscrape.Module { return new(Module) },
+ }
}
-var _ flyscrape.CanRequest = (*Module)(nil)
+func (m *Module) ValidateRequest(r *flyscrape.Request) bool {
+ return r.Depth <= m.Depth
+}
+
+var _ flyscrape.RequestValidator = (*Module)(nil)
diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go
index c9afd6f..10b67e9 100644
--- a/modules/depth/depth_test.go
+++ b/modules/depth/depth_test.go
@@ -6,36 +6,44 @@ package depth_test
import (
"net/http"
+ "sync"
"testing"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/depth"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestDepth(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
scraper.LoadModule(&depth.Module{Depth: 2})
-
- scraper.SetTransport(func(r *http.Request) (*http.Response, error) {
- switch r.URL.String() {
- case "http://www.example.com":
- return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`)
- case "http://www.google.com":
- return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
- case "http://www.duckduckgo.com":
- return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`)
- }
- return flyscrape.MockResponse(200, "")
- })
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ switch r.URL.String() {
+ case "http://www.example.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`)
+ case "http://www.google.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ case "http://www.duckduckgo.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`)
+ }
+ return flyscrape.MockResponse(200, "")
+ })
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
diff --git a/modules/domainfilter/domainfilter.go b/modules/domainfilter/domainfilter.go
index ba9ebe6..e8691d3 100644
--- a/modules/domainfilter/domainfilter.go
+++ b/modules/domainfilter/domainfilter.go
@@ -10,23 +10,39 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
URL string `json:"url"`
AllowedDomains []string `json:"allowedDomains"`
BlockedDomains []string `json:"blockedDomains"`
+
+ active bool
}
-func (m *Module) OnLoad(v flyscrape.Visitor) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "domainfilter",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(v flyscrape.Context) {
+ if m.URL == "" {
+ return
+ }
if u, err := url.Parse(m.URL); err == nil {
		m.AllowedDomains = append(m.AllowedDomains, u.Host)
}
}
-func (m *Module) CanRequest(rawurl string, depth int) bool {
- u, err := url.Parse(rawurl)
+func (m *Module) ValidateRequest(r *flyscrape.Request) bool {
+ if m.disabled() {
+ return true
+ }
+
+ u, err := url.Parse(r.URL)
if err != nil {
return false
}
@@ -51,7 +67,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool {
return ok
}
+func (m *Module) disabled() bool {
+ return len(m.AllowedDomains) == 0 && len(m.BlockedDomains) == 0
+}
+
var (
- _ flyscrape.CanRequest = (*Module)(nil)
- _ flyscrape.OnLoad = (*Module)(nil)
+ _ flyscrape.RequestValidator = (*Module)(nil)
+ _ flyscrape.Provisioner = (*Module)(nil)
)
diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go
index 884a89f..a1c8401 100644
--- a/modules/domainfilter/domainfilter_test.go
+++ b/modules/domainfilter/domainfilter_test.go
@@ -5,16 +5,22 @@
package domainfilter_test
import (
+ "net/http"
+ "sync"
"testing"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/domainfilter"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestDomainfilterAllowed(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -22,14 +28,17 @@ func TestDomainfilterAllowed(t *testing.T) {
URL: "http://www.example.com",
AllowedDomains: []string{"www.google.com"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
@@ -40,6 +49,9 @@ func TestDomainfilterAllowed(t *testing.T) {
}
func TestDomainfilterAllowedAll(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -47,14 +59,17 @@ func TestDomainfilterAllowedAll(t *testing.T) {
URL: "http://www.example.com",
AllowedDomains: []string{"*"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
@@ -66,6 +81,9 @@ func TestDomainfilterAllowedAll(t *testing.T) {
}
func TestDomainfilterBlocked(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -74,14 +92,17 @@ func TestDomainfilterBlocked(t *testing.T) {
AllowedDomains: []string{"*"},
BlockedDomains: []string{"www.google.com"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
diff --git a/modules/followlinks/followlinks.go b/modules/followlinks/followlinks.go
index 99d6cee..c53f167 100644
--- a/modules/followlinks/followlinks.go
+++ b/modules/followlinks/followlinks.go
@@ -5,19 +5,71 @@
package followlinks
import (
+ "net/url"
+ "strings"
+
+ "github.com/PuerkitoBio/goquery"
"github.com/philippta/flyscrape"
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct{}
-func (m *Module) OnResponse(resp *flyscrape.Response) {
- for _, link := range flyscrape.ParseLinks(string(resp.Body), resp.Request.URL) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "followlinks",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) ReceiveResponse(resp *flyscrape.Response) {
+ for _, link := range parseLinks(string(resp.Body), resp.Request.URL) {
resp.Visit(link)
}
}
-var _ flyscrape.OnResponse = (*Module)(nil)
+func parseLinks(html string, origin string) []string {
+ var links []string
+ doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
+ if err != nil {
+ return nil
+ }
+
+ originurl, err := url.Parse(origin)
+ if err != nil {
+ return nil
+ }
+
+ uniqueLinks := make(map[string]bool)
+ doc.Find("a").Each(func(i int, s *goquery.Selection) {
+ link, _ := s.Attr("href")
+
+ parsedLink, err := originurl.Parse(link)
+
+ if err != nil || !isValidLink(parsedLink) {
+ return
+ }
+
+ absLink := parsedLink.String()
+
+ if !uniqueLinks[absLink] {
+ links = append(links, absLink)
+ uniqueLinks[absLink] = true
+ }
+ })
+
+ return links
+}
+
+func isValidLink(link *url.URL) bool {
+ if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" {
+ return false
+ }
+
+ return true
+}
+
+var _ flyscrape.ResponseReceiver = (*Module)(nil)
diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go
index 18c8ceb..0a628c3 100644
--- a/modules/followlinks/followlinks_test.go
+++ b/modules/followlinks/followlinks_test.go
@@ -5,27 +5,37 @@
package followlinks_test
import (
+ "net/http"
+ "sync"
"testing"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestFollowLinks(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
scraper.LoadModule(&followlinks.Module{})
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="/baz">Baz</a>
- <a href="baz">Baz</a>
- <a href="http://www.google.com">Google</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="/baz">Baz</a>
+ <a href="baz">Baz</a>
+ <a href="http://www.google.com">Google</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
diff --git a/modules/hook/hook.go b/modules/hook/hook.go
new file mode 100644
index 0000000..4484f47
--- /dev/null
+++ b/modules/hook/hook.go
@@ -0,0 +1,78 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package hook
+
+import (
+ "net/http"
+
+ "github.com/philippta/flyscrape"
+)
+
+type Module struct {
+ AdaptTransportFn func(http.RoundTripper) http.RoundTripper
+ ValidateRequestFn func(*flyscrape.Request) bool
+ BuildRequestFn func(*flyscrape.Request)
+ ReceiveResponseFn func(*flyscrape.Response)
+ ProvisionFn func(flyscrape.Context)
+ FinalizeFn func()
+}
+
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "hook",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
+ if m.AdaptTransportFn == nil {
+ return t
+ }
+ return m.AdaptTransportFn(t)
+}
+
+func (m Module) ValidateRequest(r *flyscrape.Request) bool {
+ if m.ValidateRequestFn == nil {
+ return true
+ }
+ return m.ValidateRequestFn(r)
+}
+
+func (m Module) BuildRequest(r *flyscrape.Request) {
+ if m.BuildRequestFn == nil {
+ return
+ }
+ m.BuildRequestFn(r)
+}
+
+func (m Module) ReceiveResponse(r *flyscrape.Response) {
+ if m.ReceiveResponseFn == nil {
+ return
+ }
+ m.ReceiveResponseFn(r)
+}
+
+func (m Module) Provision(ctx flyscrape.Context) {
+ if m.ProvisionFn == nil {
+ return
+ }
+ m.ProvisionFn(ctx)
+}
+
+func (m Module) Finalize() {
+ if m.FinalizeFn == nil {
+ return
+ }
+ m.FinalizeFn()
+}
+
+var (
+ _ flyscrape.TransportAdapter = Module{}
+ _ flyscrape.RequestValidator = Module{}
+ _ flyscrape.RequestBuilder = Module{}
+ _ flyscrape.ResponseReceiver = Module{}
+ _ flyscrape.Provisioner = Module{}
+ _ flyscrape.Finalizer = Module{}
+)
diff --git a/modules/jsonprinter/jsonprinter.go b/modules/jsonprint/jsonprint.go
index 3026f29..29d3375 100644
--- a/modules/jsonprinter/jsonprinter.go
+++ b/modules/jsonprint/jsonprint.go
@@ -2,7 +2,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-package jsonprinter
+package jsonprint
import (
"fmt"
@@ -12,20 +12,28 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
- first bool
+ once bool
}
-func (m *Module) OnResponse(resp *flyscrape.Response) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "jsonprint",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) ReceiveResponse(resp *flyscrape.Response) {
if resp.Error == nil && resp.Data == nil {
return
}
- if m.first {
+ if !m.once {
fmt.Println("[")
+ m.once = true
} else {
fmt.Println(",")
}
@@ -37,10 +45,10 @@ func (m *Module) OnResponse(resp *flyscrape.Response) {
Timestamp: time.Now(),
}
- fmt.Print(flyscrape.PrettyPrint(o, " "))
+ fmt.Print(flyscrape.Prettify(o, " "))
}
-func (m *Module) OnComplete() {
+func (m *Module) Finalize() {
fmt.Println("\n]")
}
@@ -52,6 +60,6 @@ type output struct {
}
var (
- _ flyscrape.OnResponse = (*Module)(nil)
- _ flyscrape.OnComplete = (*Module)(nil)
+ _ flyscrape.ResponseReceiver = (*Module)(nil)
+ _ flyscrape.Finalizer = (*Module)(nil)
)
diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go
index be622f6..9588db3 100644
--- a/modules/ratelimit/ratelimit.go
+++ b/modules/ratelimit/ratelimit.go
@@ -11,7 +11,7 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
@@ -21,7 +21,18 @@ type Module struct {
semaphore chan struct{}
}
-func (m *Module) OnLoad(v flyscrape.Visitor) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "ratelimit",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(v flyscrape.Context) {
+ if m.disabled() {
+ return
+ }
+
rate := time.Duration(float64(time.Second) / m.Rate)
m.ticker = time.NewTicker(rate)
@@ -34,16 +45,26 @@ func (m *Module) OnLoad(v flyscrape.Visitor) {
}()
}
-func (m *Module) OnRequest(_ *flyscrape.Request) {
+func (m *Module) BuildRequest(_ *flyscrape.Request) {
+ if m.disabled() {
+ return
+ }
<-m.semaphore
}
-func (m *Module) OnComplete() {
+func (m *Module) Finalize() {
+ if m.disabled() {
+ return
+ }
m.ticker.Stop()
}
+func (m *Module) disabled() bool {
+ return m.Rate == 0
+}
+
var (
- _ flyscrape.OnRequest = (*Module)(nil)
- _ flyscrape.OnLoad = (*Module)(nil)
- _ flyscrape.OnComplete = (*Module)(nil)
+ _ flyscrape.RequestBuilder = (*Module)(nil)
+ _ flyscrape.Provisioner = (*Module)(nil)
+ _ flyscrape.Finalizer = (*Module)(nil)
)
diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go
index 5e91f8f..ffd061c 100644
--- a/modules/ratelimit/ratelimit_test.go
+++ b/modules/ratelimit/ratelimit_test.go
@@ -5,33 +5,37 @@
package ratelimit_test
import (
+ "net/http"
"testing"
"time"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/ratelimit"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestRatelimit(t *testing.T) {
+ var times []time.Time
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
scraper.LoadModule(&ratelimit.Module{
Rate: 100,
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `<a href="foo">foo</a>`))
-
- var times []time.Time
- scraper.OnRequest(func(req *flyscrape.Request) {
- times = append(times, time.Now())
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `<a href="foo">foo</a>`)
+ },
+ BuildRequestFn: func(r *flyscrape.Request) {
+ times = append(times, time.Now())
+ },
})
start := time.Now()
-
scraper.Run()
first := times[0].Add(-10 * time.Millisecond)
diff --git a/modules/starturl/starturl.go b/modules/starturl/starturl.go
index 109d28f..9e3ec31 100644
--- a/modules/starturl/starturl.go
+++ b/modules/starturl/starturl.go
@@ -9,15 +9,29 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
URL string `json:"url"`
}
-func (m *Module) OnLoad(v flyscrape.Visitor) {
- v.Visit(m.URL)
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "starturl",
+ New: func() flyscrape.Module { return new(Module) },
+ }
}
-var _ flyscrape.OnLoad = (*Module)(nil)
+func (m *Module) Provision(ctx flyscrape.Context) {
+ if m.disabled() {
+ return
+ }
+ ctx.Visit(m.URL)
+}
+
+func (m *Module) disabled() bool {
+ return m.URL == ""
+}
+
+var _ flyscrape.Provisioner = (*Module)(nil)
diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go
index 6fab776..86e4ad7 100644
--- a/modules/starturl/starturl_test.go
+++ b/modules/starturl/starturl_test.go
@@ -5,23 +5,29 @@
package starturl_test
import (
+ "net/http"
"testing"
"github.com/philippta/flyscrape"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestStartURL(t *testing.T) {
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.SetTransport(flyscrape.MockTransport(200, ""))
-
var url string
var depth int
- scraper.OnRequest(func(req *flyscrape.Request) {
- url = req.URL
- depth = req.Depth
+
+ scraper := flyscrape.NewScraper()
+ scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, "")
+ },
+ BuildRequestFn: func(r *flyscrape.Request) {
+ url = r.URL
+ depth = r.Depth
+ },
})
scraper.Run()
diff --git a/modules/urlfilter/urlfilter.go b/modules/urlfilter/urlfilter.go
index 00a4bd2..1297c35 100644
--- a/modules/urlfilter/urlfilter.go
+++ b/modules/urlfilter/urlfilter.go
@@ -11,7 +11,7 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
@@ -23,7 +23,18 @@ type Module struct {
blockedURLsRE []*regexp.Regexp
}
-func (m *Module) OnLoad(v flyscrape.Visitor) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "urlfilter",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(v flyscrape.Context) {
+ if m.disabled() {
+ return
+ }
+
for _, pat := range m.AllowedURLs {
re, err := regexp.Compile(pat)
if err != nil {
@@ -41,9 +52,13 @@ func (m *Module) OnLoad(v flyscrape.Visitor) {
}
}
-func (m *Module) CanRequest(rawurl string, depth int) bool {
+func (m *Module) ValidateRequest(r *flyscrape.Request) bool {
+ if m.disabled() {
+ return true
+ }
+
// allow root url
- if rawurl == m.URL {
+ if r.URL == m.URL {
return true
}
@@ -58,14 +73,14 @@ func (m *Module) CanRequest(rawurl string, depth int) bool {
}
for _, re := range m.allowedURLsRE {
- if re.MatchString(rawurl) {
+ if re.MatchString(r.URL) {
ok = true
break
}
}
for _, re := range m.blockedURLsRE {
- if re.MatchString(rawurl) {
+ if re.MatchString(r.URL) {
ok = false
break
}
@@ -74,7 +89,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool {
return ok
}
+func (m *Module) disabled() bool {
+ return len(m.AllowedURLs) == 0 && len(m.BlockedURLs) == 0
+}
+
var (
- _ flyscrape.CanRequest = (*Module)(nil)
- _ flyscrape.OnLoad = (*Module)(nil)
+ _ flyscrape.RequestValidator = (*Module)(nil)
+ _ flyscrape.Provisioner = (*Module)(nil)
)
diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go
index e383a32..9ebb8a5 100644
--- a/modules/urlfilter/urlfilter_test.go
+++ b/modules/urlfilter/urlfilter_test.go
@@ -5,16 +5,22 @@
package urlfilter_test
import (
+ "net/http"
+ "sync"
"testing"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/philippta/flyscrape/modules/urlfilter"
"github.com/stretchr/testify/require"
)
func TestURLFilterAllowed(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
scraper.LoadModule(&followlinks.Module{})
@@ -22,16 +28,19 @@ func TestURLFilterAllowed(t *testing.T) {
URL: "http://www.example.com/",
AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="foo?id=123">123</a>
- <a href="foo?id=ABC">ABC</a>
- <a href="/bar">bar</a>
- <a href="/barz">barz</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="foo?id=123">123</a>
+ <a href="foo?id=ABC">ABC</a>
+ <a href="/bar">bar</a>
+ <a href="/barz">barz</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
@@ -43,6 +52,9 @@ func TestURLFilterAllowed(t *testing.T) {
}
func TestURLFilterBlocked(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
scraper.LoadModule(&followlinks.Module{})
@@ -50,16 +62,19 @@ func TestURLFilterBlocked(t *testing.T) {
URL: "http://www.example.com/",
BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="foo?id=123">123</a>
- <a href="foo?id=ABC">ABC</a>
- <a href="/bar">bar</a>
- <a href="/barz">barz</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="foo?id=123">123</a>
+ <a href="foo?id=ABC">ABC</a>
+ <a href="/bar">bar</a>
+ <a href="/barz">barz</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
diff --git a/scrape.go b/scrape.go
index 4186247..c1257a9 100644
--- a/scrape.go
+++ b/scrape.go
@@ -9,28 +9,19 @@ import (
"log"
"net/http"
"net/http/cookiejar"
- "net/url"
- "strings"
+ "slices"
"sync"
- gourl "net/url"
-
- "github.com/PuerkitoBio/goquery"
"github.com/cornelk/hashmap"
)
-type ScrapeParams struct {
- HTML string
- URL string
-}
-
-type ScrapeFunc func(ScrapeParams) (any, error)
-
type FetchFunc func(url string) (string, error)
-type Visitor interface {
+type Context interface {
Visit(url string)
MarkVisited(url string)
+ MarkUnvisited(url string)
+ DisableModule(id string)
}
type Request struct {
@@ -57,62 +48,39 @@ type target struct {
depth int
}
-type Scraper struct {
- ScrapeFunc ScrapeFunc
-
- cfg Config
- wg sync.WaitGroup
- jobs chan target
- visited *hashmap.Map[string, struct{}]
- modules *hashmap.Map[string, Module]
- cookieJar *cookiejar.Jar
-
- canRequestHandlers []func(url string, depth int) bool
- onRequestHandlers []func(*Request)
- onResponseHandlers []func(*Response)
- onCompleteHandlers []func()
- transport func(*http.Request) (*http.Response, error)
-}
-
func NewScraper() *Scraper {
- jar, _ := cookiejar.New(nil)
- s := &Scraper{
+ return &Scraper{
jobs: make(chan target, 1024),
visited: hashmap.New[string, struct{}](),
- modules: hashmap.New[string, Module](),
- transport: func(r *http.Request) (*http.Response, error) {
- r.Header.Set("User-Agent", "flyscrape/0.1")
- return http.DefaultClient.Do(r)
- },
- cookieJar: jar,
}
- return s
}
-func (s *Scraper) LoadModule(mod Module) {
- if v, ok := mod.(Transport); ok {
- s.SetTransport(v.Transport)
- }
+type Scraper struct {
+ ScrapeFunc ScrapeFunc
- if v, ok := mod.(CanRequest); ok {
- s.CanRequest(v.CanRequest)
- }
+ wg sync.WaitGroup
+ jobs chan target
+ visited *hashmap.Map[string, struct{}]
- if v, ok := mod.(OnRequest); ok {
- s.OnRequest(v.OnRequest)
- }
+ modules []Module
+ moduleIDs []string
+ client *http.Client
+}
- if v, ok := mod.(OnResponse); ok {
- s.OnResponse(v.OnResponse)
- }
+func (s *Scraper) LoadModule(mod Module) {
+ id := mod.ModuleInfo().ID
- if v, ok := mod.(OnLoad); ok {
- v.OnLoad(s)
- }
+ s.modules = append(s.modules, mod)
+ s.moduleIDs = append(s.moduleIDs, id)
+}
- if v, ok := mod.(OnComplete); ok {
- s.OnComplete(v.OnComplete)
+func (s *Scraper) DisableModule(id string) {
+ idx := slices.Index(s.moduleIDs, id)
+ if idx == -1 {
+ return
}
+ s.modules = slices.Delete(s.modules, idx, idx+1)
+ s.moduleIDs = slices.Delete(s.moduleIDs, idx, idx+1)
}
func (s *Scraper) Visit(url string) {
@@ -123,49 +91,47 @@ func (s *Scraper) MarkVisited(url string) {
s.visited.Insert(url, struct{}{})
}
-func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
- s.transport = f
-}
-
-func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
- s.canRequestHandlers = append(s.canRequestHandlers, f)
-}
-
-func (s *Scraper) OnRequest(f func(req *Request)) {
- s.onRequestHandlers = append(s.onRequestHandlers, f)
-}
-
-func (s *Scraper) OnResponse(f func(resp *Response)) {
- s.onResponseHandlers = append(s.onResponseHandlers, f)
-}
-
-func (s *Scraper) OnComplete(f func()) {
- s.onCompleteHandlers = append(s.onCompleteHandlers, f)
+func (s *Scraper) MarkUnvisited(url string) {
+ s.visited.Del(url)
}
func (s *Scraper) Run() {
- go s.worker()
+ for _, mod := range s.modules {
+ if v, ok := mod.(Provisioner); ok {
+ v.Provision(s)
+ }
+ }
+
+ s.initClient()
+ go s.scrape()
s.wg.Wait()
close(s.jobs)
- for _, handler := range s.onCompleteHandlers {
- handler()
+ for _, mod := range s.modules {
+ if v, ok := mod.(Finalizer); ok {
+ v.Finalize()
+ }
}
}
-func (s *Scraper) worker() {
- for job := range s.jobs {
- go func(job target) {
- defer s.wg.Done()
+func (s *Scraper) initClient() {
+ jar, _ := cookiejar.New(nil)
+ s.client = &http.Client{Jar: jar}
- for _, handler := range s.canRequestHandlers {
- if !handler(job.url, job.depth) {
- return
- }
- }
+ for _, mod := range s.modules {
+ if v, ok := mod.(TransportAdapter); ok {
+ s.client.Transport = v.AdaptTransport(s.client.Transport)
+ }
+ }
+}
+func (s *Scraper) scrape() {
+ for job := range s.jobs {
+ job := job
+ go func() {
s.process(job.url, job.depth)
- }(job)
+ s.wg.Done()
+ }()
}
}
@@ -173,8 +139,9 @@ func (s *Scraper) process(url string, depth int) {
request := &Request{
Method: http.MethodGet,
URL: url,
- Headers: http.Header{},
- Cookies: s.cookieJar,
+ Headers: defaultHeaders(),
+ Cookies: s.client.Jar,
+ Depth: depth,
}
response := &Response{
@@ -184,11 +151,11 @@ func (s *Scraper) process(url string, depth int) {
},
}
- defer func() {
- for _, handler := range s.onResponseHandlers {
- handler(response)
+ for _, mod := range s.modules {
+ if v, ok := mod.(RequestBuilder); ok {
+ v.BuildRequest(request)
}
- }()
+ }
req, err := http.NewRequest(request.Method, request.URL, nil)
if err != nil {
@@ -197,11 +164,23 @@ func (s *Scraper) process(url string, depth int) {
}
req.Header = request.Headers
- for _, handler := range s.onRequestHandlers {
- handler(request)
+ for _, mod := range s.modules {
+ if v, ok := mod.(RequestValidator); ok {
+ if !v.ValidateRequest(request) {
+ return
+ }
+ }
}
- resp, err := s.transport(req)
+ defer func() {
+ for _, mod := range s.modules {
+ if v, ok := mod.(ResponseReceiver); ok {
+ v.ReceiveResponse(response)
+ }
+ }
+ }()
+
+ resp, err := s.client.Do(req)
if err != nil {
response.Error = err
return
@@ -241,43 +220,9 @@ func (s *Scraper) enqueueJob(url string, depth int) {
}
}
-func ParseLinks(html string, origin string) []string {
- var links []string
- doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
- if err != nil {
- return nil
- }
-
- originurl, err := url.Parse(origin)
- if err != nil {
- return nil
- }
-
- uniqueLinks := make(map[string]bool)
- doc.Find("a").Each(func(i int, s *goquery.Selection) {
- link, _ := s.Attr("href")
-
- parsedLink, err := originurl.Parse(link)
-
- if err != nil || !isValidLink(parsedLink) {
- return
- }
-
- absLink := parsedLink.String()
-
- if !uniqueLinks[absLink] {
- links = append(links, absLink)
- uniqueLinks[absLink] = true
- }
- })
-
- return links
-}
-
-func isValidLink(link *gourl.URL) bool {
- if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" {
- return false
- }
+func defaultHeaders() http.Header {
+ h := http.Header{}
+ h.Set("User-Agent", "flyscrape/0.1")
- return true
+ return h
}
diff --git a/utils.go b/utils.go
index 73efa4a..faa4937 100644
--- a/utils.go
+++ b/utils.go
@@ -7,10 +7,11 @@ package flyscrape
import (
"bytes"
"encoding/json"
+ "net/http"
"strings"
)
-func PrettyPrint(v any, prefix string) string {
+func Prettify(v any, prefix string) string {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
@@ -19,14 +20,8 @@ func PrettyPrint(v any, prefix string) string {
return prefix + strings.TrimSuffix(buf.String(), "\n")
}
-func Print(v any, prefix string) string {
- var buf bytes.Buffer
- enc := json.NewEncoder(&buf)
- enc.SetEscapeHTML(false)
- enc.Encode(v)
- return prefix + strings.TrimSuffix(buf.String(), "\n")
-}
+type RoundTripFunc func(*http.Request) (*http.Response, error)
-func ParseConfig(cfg Config, v any) {
- json.Unmarshal(cfg, v)
+func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+ return f(r)
}