summary | refs | log | tree | commit | diff
path: root/modules/domainfilter
diff options
context:
space:
mode:
Diffstat (limited to 'modules/domainfilter')
-rw-r--r-- modules/domainfilter/domainfilter.go 32
-rw-r--r-- modules/domainfilter/domainfilter_test.go 69
2 files changed, 71 insertions, 30 deletions
diff --git a/modules/domainfilter/domainfilter.go b/modules/domainfilter/domainfilter.go
index ba9ebe6..e8691d3 100644
--- a/modules/domainfilter/domainfilter.go
+++ b/modules/domainfilter/domainfilter.go
@@ -10,23 +10,39 @@ import (
)
func init() {
- flyscrape.RegisterModule(new(Module))
+ flyscrape.RegisterModule(Module{})
}
type Module struct {
URL string `json:"url"`
AllowedDomains []string `json:"allowedDomains"`
BlockedDomains []string `json:"blockedDomains"`
+
+ active bool
}
-func (m *Module) OnLoad(v flyscrape.Visitor) {
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "domainfilter",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(v flyscrape.Context) {
+ if m.URL == "" {
+ return
+ }
if u, err := url.Parse(m.URL); err == nil {
m.AllowedDomains = append(m.AllowedDomains, u.Host())
}
}
-func (m *Module) CanRequest(rawurl string, depth int) bool {
- u, err := url.Parse(rawurl)
+func (m *Module) ValidateRequest(r *flyscrape.Request) bool {
+ if m.disabled() {
+ return true
+ }
+
+ u, err := url.Parse(r.URL)
if err != nil {
return false
}
@@ -51,7 +67,11 @@ func (m *Module) CanRequest(rawurl string, depth int) bool {
return ok
}
+func (m *Module) disabled() bool {
+ return len(m.AllowedDomains) == 0 && len(m.BlockedDomains) == 0
+}
+
var (
- _ flyscrape.CanRequest = (*Module)(nil)
- _ flyscrape.OnLoad = (*Module)(nil)
+ _ flyscrape.RequestValidator = (*Module)(nil)
+ _ flyscrape.Provisioner = (*Module)(nil)
)
diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go
index 884a89f..a1c8401 100644
--- a/modules/domainfilter/domainfilter_test.go
+++ b/modules/domainfilter/domainfilter_test.go
@@ -5,16 +5,22 @@
package domainfilter_test
import (
+ "net/http"
+ "sync"
"testing"
"github.com/philippta/flyscrape"
"github.com/philippta/flyscrape/modules/domainfilter"
"github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
"github.com/philippta/flyscrape/modules/starturl"
"github.com/stretchr/testify/require"
)
func TestDomainfilterAllowed(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -22,14 +28,17 @@ func TestDomainfilterAllowed(t *testing.T) {
URL: "http://www.example.com",
AllowedDomains: []string{"www.google.com"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
@@ -40,6 +49,9 @@ func TestDomainfilterAllowed(t *testing.T) {
}
func TestDomainfilterAllowedAll(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -47,14 +59,17 @@ func TestDomainfilterAllowedAll(t *testing.T) {
URL: "http://www.example.com",
AllowedDomains: []string{"*"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()
@@ -66,6 +81,9 @@ func TestDomainfilterAllowedAll(t *testing.T) {
}
func TestDomainfilterBlocked(t *testing.T) {
+ var urls []string
+ var mu sync.Mutex
+
scraper := flyscrape.NewScraper()
scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
scraper.LoadModule(&followlinks.Module{})
@@ -74,14 +92,17 @@ func TestDomainfilterBlocked(t *testing.T) {
AllowedDomains: []string{"*"},
BlockedDomains: []string{"www.google.com"},
})
-
- scraper.SetTransport(flyscrape.MockTransport(200, `
- <a href="http://www.google.com">Google</a>
- <a href="http://www.duckduckgo.com">DuckDuckGo</a>`))
-
- var urls []string
- scraper.OnRequest(func(req *flyscrape.Request) {
- urls = append(urls, req.URL)
+ scraper.LoadModule(hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
+ <a href="http://www.google.com">Google</a>
+ <a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
})
scraper.Run()