summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- cmd/flyscrape/dev.go 95
-rw-r--r-- cmd/flyscrape/new.go 4
-rw-r--r-- cmd/flyscrape/run.go 39
-rw-r--r-- flyscrape.go 131
-rw-r--r-- js.go 145
-rw-r--r-- js_test.go 162
-rw-r--r-- mock.go 26
-rw-r--r-- module.go 9
-rw-r--r-- modules/depth/depth_test.go 49
-rw-r--r-- modules/domainfilter/domainfilter_test.go 107
-rw-r--r-- modules/followlinks/followlinks_test.go 124
-rw-r--r-- modules/proxy/proxy_test.go 14
-rw-r--r-- modules/ratelimit/ratelimit_test.go 27
-rw-r--r-- modules/starturl/starturl_test.go 23
-rw-r--r-- modules/urlfilter/urlfilter_test.go 70
-rw-r--r-- scrape.go 70
-rw-r--r-- utils.go 16
17 files changed, 660 insertions, 451 deletions
diff --git a/cmd/flyscrape/dev.go b/cmd/flyscrape/dev.go
index 9ddb3bf..84a436b 100644
--- a/cmd/flyscrape/dev.go
+++ b/cmd/flyscrape/dev.go
@@ -5,16 +5,9 @@
package main
import (
- "encoding/json"
"flag"
"fmt"
- "log"
- "os"
- "os/signal"
- "path/filepath"
- "syscall"
- "github.com/inancgumus/screen"
"github.com/philippta/flyscrape"
)
@@ -27,46 +20,13 @@ func (c *DevCommand) Run(args []string) error {
if err := fs.Parse(args); err != nil {
return err
} else if fs.NArg() == 0 || fs.Arg(0) == "" {
- return fmt.Errorf("script path required")
+ c.Usage()
+ return flag.ErrHelp
} else if fs.NArg() > 1 {
return fmt.Errorf("too many arguments")
}
- script := fs.Arg(0)
- cachefile, err := newCacheFile()
- if err != nil {
- return fmt.Errorf("failed to create cache file: %w", err)
- }
-
- trapsignal(func() { os.RemoveAll(cachefile) })
-
- err = flyscrape.Watch(script, func(s string) error {
- cfg, scrape, err := flyscrape.Compile(s)
- if err != nil {
- printCompileErr(script, err)
- return nil
- }
-
- cfg = updateCfg(cfg, "depth", 0)
- cfg = updateCfg(cfg, "cache", "file:"+cachefile)
-
- scraper := flyscrape.NewScraper()
- scraper.ScrapeFunc = scrape
- scraper.Script = script
-
- flyscrape.LoadModules(scraper, cfg)
-
- screen.Clear()
- screen.MoveTopLeft()
- scraper.Run()
-
- return nil
- })
- if err != nil && err != flyscrape.StopWatch {
- return fmt.Errorf("failed to watch script %q: %w", script, err)
- }
-
- return nil
+ return flyscrape.Dev(fs.Arg(0))
}
func (c *DevCommand) Usage() {
@@ -78,58 +38,9 @@ Usage:
flyscrape dev SCRIPT
-
Examples:
# Run and watch script.
$ flyscrape dev example.js
`[1:])
}
-
-func printCompileErr(script string, err error) {
- screen.Clear()
- screen.MoveTopLeft()
-
- if errs, ok := err.(interface{ Unwrap() []error }); ok {
- for _, err := range errs.Unwrap() {
- log.Printf("%s:%v\n", script, err)
- }
- } else {
- log.Println(err)
- }
-}
-
-func updateCfg(cfg flyscrape.Config, key string, value any) flyscrape.Config {
- var m map[string]any
- if err := json.Unmarshal(cfg, &m); err != nil {
- return cfg
- }
-
- m[key] = value
-
- b, err := json.Marshal(m)
- if err != nil {
- return cfg
- }
-
- return b
-}
-
-func newCacheFile() (string, error) {
- cachedir, err := os.MkdirTemp("", "flyscrape-cache")
- if err != nil {
- return "", err
- }
- return filepath.Join(cachedir, "dev.cache"), nil
-}
-
-func trapsignal(f func()) {
- sig := make(chan os.Signal, 2)
- signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
-
- go func() {
- <-sig
- f()
- os.Exit(0)
- }()
-}
diff --git a/cmd/flyscrape/new.go b/cmd/flyscrape/new.go
index 8c0d6c4..4ab248e 100644
--- a/cmd/flyscrape/new.go
+++ b/cmd/flyscrape/new.go
@@ -21,7 +21,8 @@ func (c *NewCommand) Run(args []string) error {
if err := fs.Parse(args); err != nil {
return err
} else if fs.NArg() == 0 || fs.Arg(0) == "" {
- return fmt.Errorf("script path required")
+ c.Usage()
+ return flag.ErrHelp
} else if fs.NArg() > 1 {
return fmt.Errorf("too many arguments")
}
@@ -47,7 +48,6 @@ Usage:
flyscrape new SCRIPT
-
Examples:
# Create a new scraping script.
diff --git a/cmd/flyscrape/run.go b/cmd/flyscrape/run.go
index 039574b..7a8930a 100644
--- a/cmd/flyscrape/run.go
+++ b/cmd/flyscrape/run.go
@@ -7,12 +7,8 @@ package main
import (
"flag"
"fmt"
- "log"
- "os"
- "time"
"github.com/philippta/flyscrape"
- "github.com/philippta/flyscrape/modules/hook"
)
type RunCommand struct{}
@@ -24,41 +20,13 @@ func (c *RunCommand) Run(args []string) error {
if err := fs.Parse(args); err != nil {
return err
} else if fs.NArg() == 0 || fs.Arg(0) == "" {
- return fmt.Errorf("script path required")
+ c.Usage()
+ return flag.ErrHelp
} else if fs.NArg() > 1 {
return fmt.Errorf("too many arguments")
}
- script := fs.Arg(0)
- src, err := os.ReadFile(script)
- if err != nil {
- return fmt.Errorf("failed to read script %q: %w", script, err)
- }
-
- cfg, scrape, err := flyscrape.Compile(string(src))
- if err != nil {
- return fmt.Errorf("failed to compile script: %w", err)
- }
-
- scraper := flyscrape.NewScraper()
- scraper.ScrapeFunc = scrape
- scraper.Script = script
-
- flyscrape.LoadModules(scraper, cfg)
-
- count := 0
- start := time.Now()
-
- scraper.LoadModule(hook.Module{
- ReceiveResponseFn: func(r *flyscrape.Response) {
- count++
- },
- })
-
- scraper.Run()
-
- log.Printf("Scraped %d websites in %v\n", count, time.Since(start))
- return nil
+ return flyscrape.Run(fs.Arg(0))
}
func (c *RunCommand) Usage() {
@@ -69,7 +37,6 @@ Usage:
flyscrape run SCRIPT
-
Examples:
# Run the script.
diff --git a/flyscrape.go b/flyscrape.go
new file mode 100644
index 0000000..bb4ee30
--- /dev/null
+++ b/flyscrape.go
@@ -0,0 +1,131 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package flyscrape
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "syscall"
+
+ "github.com/inancgumus/screen"
+)
+
+func Run(file string) error {
+ src, err := os.ReadFile(file)
+ if err != nil {
+ return fmt.Errorf("failed to read script %q: %w", file, err)
+ }
+
+ client := &http.Client{}
+
+ script, err := Compile(string(src), nil)
+ if err != nil {
+ return fmt.Errorf("failed to compile script: %w", err)
+ }
+
+ scraper := NewScraper()
+ scraper.ScrapeFunc = script.Scrape
+ scraper.Script = file
+ scraper.Client = client
+ scraper.Modules = LoadModules(script.Config())
+
+ scraper.Run()
+ return nil
+}
+
+func Dev(file string) error {
+ cachefile, err := newCacheFile()
+ if err != nil {
+ return fmt.Errorf("failed to create cache file: %w", err)
+ }
+
+ trapsignal(func() {
+ os.RemoveAll(cachefile)
+ })
+
+ fn := func(s string) error {
+ client := &http.Client{}
+
+ script, err := Compile(s, nil)
+ if err != nil {
+ printCompileErr(file, err)
+ return nil
+ }
+
+ cfg := script.Config()
+ cfg = updateCfg(cfg, "depth", 0)
+ cfg = updateCfg(cfg, "cache", "file:"+cachefile)
+
+ scraper := NewScraper()
+ scraper.ScrapeFunc = script.Scrape
+ scraper.Script = file
+ scraper.Client = client
+ scraper.Modules = LoadModules(cfg)
+
+ screen.Clear()
+ screen.MoveTopLeft()
+ scraper.Run()
+
+ return nil
+ }
+
+ if err := Watch(file, fn); err != nil && err != StopWatch {
+ return fmt.Errorf("failed to watch script %q: %w", file, err)
+ }
+ return nil
+}
+
+func printCompileErr(script string, err error) {
+ screen.Clear()
+ screen.MoveTopLeft()
+
+ if errs, ok := err.(interface{ Unwrap() []error }); ok {
+ for _, err := range errs.Unwrap() {
+ log.Printf("%s:%v\n", script, err)
+ }
+ } else {
+ log.Println(err)
+ }
+}
+
+func updateCfg(cfg Config, key string, value any) Config {
+ var m map[string]any
+ if err := json.Unmarshal(cfg, &m); err != nil {
+ return cfg
+ }
+
+ m[key] = value
+
+ b, err := json.Marshal(m)
+ if err != nil {
+ return cfg
+ }
+
+ return b
+}
+
+func newCacheFile() (string, error) {
+ cachedir, err := os.MkdirTemp("", "flyscrape-cache")
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(cachedir, "dev.cache"), nil
+}
+
+func trapsignal(f func()) {
+ sig := make(chan os.Signal, 2)
+ signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
+
+ go func() {
+ <-sig
+ f()
+ os.Exit(0)
+ }()
+}
diff --git a/js.go b/js.go
index d36f98a..7b5630b 100644
--- a/js.go
+++ b/js.go
@@ -11,10 +11,8 @@ import (
"fmt"
"log"
"net/url"
- "strconv"
"strings"
"sync"
- "sync/atomic"
"github.com/PuerkitoBio/goquery"
"github.com/dop251/goja"
@@ -45,19 +43,32 @@ func (err TransformError) Error() string {
return fmt.Sprintf("%d:%d: %s", err.Line, err.Column, err.Text)
}
-func Compile(src string) (Config, ScrapeFunc, error) {
+type Exports map[string]any
+
+func (e Exports) Config() []byte {
+ b, _ := json.Marshal(e["config"])
+ return b
+}
+
+func (e Exports) Scrape(p ScrapeParams) (any, error) {
+ fn := e["__scrape"].(ScrapeFunc)
+ return fn(p)
+}
+
+type Imports map[string]map[string]any
+
+func Compile(src string, imports Imports) (Exports, error) {
src, err := build(src)
if err != nil {
- return nil, nil, err
+ return nil, err
}
-
- return vm(src)
+ return vm(src, imports)
}
func build(src string) (string, error) {
res := api.Transform(src, api.TransformOptions{
- Platform: api.PlatformBrowser,
- Format: api.FormatIIFE,
+ Platform: api.PlatformNode,
+ Format: api.FormatCommonJS,
})
var errs []error
@@ -76,31 +87,67 @@ func build(src string) (string, error) {
return string(res.Code), nil
}
-func vm(src string) (Config, ScrapeFunc, error) {
+func vm(src string, imports Imports) (Exports, error) {
vm := goja.New()
-
registry := &require.Registry{}
- registry.Enable(vm)
+ registry.Enable(vm)
console.Enable(vm)
- if _, err := vm.RunString(removeIIFE(src)); err != nil {
- return nil, nil, fmt.Errorf("running user script: %w", err)
+ for module, pkg := range imports {
+ pkg := pkg
+ registry.RegisterNativeModule(module, func(vm *goja.Runtime, o *goja.Object) {
+ exports := vm.NewObject()
+
+ for ident, val := range pkg {
+ exports.Set(ident, val)
+ }
+
+ o.Set("exports", exports)
+ })
+ }
+
+ if _, err := vm.RunString("module = {}"); err != nil {
+ return nil, fmt.Errorf("running defining module: %w", err)
+ }
+ if _, err := vm.RunString(src); err != nil {
+ return nil, fmt.Errorf("running user script: %w", err)
}
- cfg, err := vm.RunString("JSON.stringify(config)")
+ v, err := vm.RunString("module.exports")
if err != nil {
- return nil, nil, fmt.Errorf("reading config: %w", err)
+ return nil, fmt.Errorf("reading config: %w", err)
+ }
+
+ exports := Exports{}
+ obj := v.ToObject(vm)
+ for _, key := range obj.Keys() {
+ exports[key] = obj.Get(key).Export()
}
- var c atomic.Uint64
+ exports["__scrape"] = scrape(vm)
+
+ return exports, nil
+}
+
+func scrape(vm *goja.Runtime) ScrapeFunc {
var lock sync.Mutex
- scrape := func(p ScrapeParams) (any, error) {
+ defaultfn, err := vm.RunString("module.exports.default")
+ if err != nil {
+ return func(ScrapeParams) (any, error) { return nil, errors.New("no scrape function defined") }
+ }
+
+ scrapefn, ok := defaultfn.Export().(func(goja.FunctionCall) goja.Value)
+ if !ok {
+ return func(ScrapeParams) (any, error) { return nil, errors.New("default export is not a function") }
+ }
+
+ return func(p ScrapeParams) (any, error) {
lock.Lock()
defer lock.Unlock()
- doc, err := goquery.NewDocumentFromReader(strings.NewReader(p.HTML))
+ doc, err := DocumentFromString(p.HTML)
if err != nil {
log.Println(err)
return nil, err
@@ -112,37 +159,35 @@ func vm(src string) (Config, ScrapeFunc, error) {
return nil, err
}
- suffix := strconv.FormatUint(c.Add(1), 10)
- vm.Set("url_"+suffix, p.URL)
- vm.Set("doc_"+suffix, wrap(vm, doc.Selection))
- vm.Set("absurl_"+suffix, func(ref string) string {
+ absoluteURL := func(ref string) string {
abs, err := baseurl.Parse(ref)
if err != nil {
log.Println(err)
return ref
}
return abs.String()
- })
-
- data, err := vm.RunString(fmt.Sprintf(`JSON.stringify(stdin_default({doc: doc_%s, url: url_%s, absoluteURL: absurl_%s}))`, suffix, suffix, suffix))
- if err != nil {
- log.Println(err)
- return nil, err
}
- var obj any
- if err := json.Unmarshal([]byte(data.String()), &obj); err != nil {
- log.Println(err)
- return nil, err
- }
+ o := vm.NewObject()
+ o.Set("url", p.URL)
+ o.Set("doc", doc)
+ o.Set("absoluteURL", absoluteURL)
+
+ ret := scrapefn(goja.FunctionCall{Arguments: []goja.Value{o}}).Export()
+ return ret, nil
+ }
+}
- return obj, nil
+func DocumentFromString(s string) (map[string]any, error) {
+ doc, err := goquery.NewDocumentFromReader(strings.NewReader(s))
+ if err != nil {
+ return nil, err
}
- return Config(cfg.String()), scrape, nil
+ return Document(doc.Selection), nil
}
-func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any {
+func Document(sel *goquery.Selection) map[string]any {
o := map[string]any{}
o["WARNING"] = "Forgot to call text(), html() or attr()?"
o["text"] = sel.Text
@@ -151,19 +196,19 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any {
o["hasAttr"] = func(name string) bool { _, ok := sel.Attr(name); return ok }
o["hasClass"] = sel.HasClass
o["length"] = sel.Length()
- o["first"] = func() map[string]any { return wrap(vm, sel.First()) }
- o["last"] = func() map[string]any { return wrap(vm, sel.Last()) }
- o["get"] = func(index int) map[string]any { return wrap(vm, sel.Eq(index)) }
- o["find"] = func(s string) map[string]any { return wrap(vm, sel.Find(s)) }
- o["next"] = func() map[string]any { return wrap(vm, sel.Next()) }
- o["prev"] = func() map[string]any { return wrap(vm, sel.Prev()) }
- o["siblings"] = func() map[string]any { return wrap(vm, sel.Siblings()) }
- o["children"] = func() map[string]any { return wrap(vm, sel.Children()) }
- o["parent"] = func() map[string]any { return wrap(vm, sel.Parent()) }
+ o["first"] = func() map[string]any { return Document(sel.First()) }
+ o["last"] = func() map[string]any { return Document(sel.Last()) }
+ o["get"] = func(index int) map[string]any { return Document(sel.Eq(index)) }
+ o["find"] = func(s string) map[string]any { return Document(sel.Find(s)) }
+ o["next"] = func() map[string]any { return Document(sel.Next()) }
+ o["prev"] = func() map[string]any { return Document(sel.Prev()) }
+ o["siblings"] = func() map[string]any { return Document(sel.Siblings()) }
+ o["children"] = func() map[string]any { return Document(sel.Children()) }
+ o["parent"] = func() map[string]any { return Document(sel.Parent()) }
o["map"] = func(callback func(map[string]any, int) any) []any {
var vals []any
sel.Map(func(i int, s *goquery.Selection) string {
- vals = append(vals, callback(wrap(vm, s), i))
+ vals = append(vals, callback(Document(s), i))
return ""
})
return vals
@@ -171,7 +216,7 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any {
o["filter"] = func(callback func(map[string]any, int) bool) []any {
var vals []any
sel.Each(func(i int, s *goquery.Selection) {
- el := wrap(vm, s)
+ el := Document(s)
ok := callback(el, i)
if ok {
vals = append(vals, el)
@@ -181,9 +226,3 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any {
}
return o
}
-
-func removeIIFE(s string) string {
- s = strings.TrimPrefix(s, "(() => {\n")
- s = strings.TrimSuffix(s, "})();\n")
- return s
-}
diff --git a/js_test.go b/js_test.go
index 2cf8f25..acefa38 100644
--- a/js_test.go
+++ b/js_test.go
@@ -8,6 +8,7 @@ import (
"encoding/json"
"testing"
+ "github.com/dop251/goja"
"github.com/philippta/flyscrape"
"github.com/stretchr/testify/require"
)
@@ -37,12 +38,12 @@ export default function({ doc, url }) {
`
func TestJSScrape(t *testing.T) {
- cfg, run, err := flyscrape.Compile(script)
+ exports, err := flyscrape.Compile(script, nil)
require.NoError(t, err)
- require.NotNil(t, cfg)
- require.NotNil(t, run)
+ require.NotNil(t, exports)
+ require.NotEmpty(t, exports.Config)
- result, err := run(flyscrape.ScrapeParams{
+ result, err := exports.Scrape(flyscrape.ScrapeParams{
HTML: html,
URL: "http://localhost/",
})
@@ -56,11 +57,89 @@ func TestJSScrape(t *testing.T) {
require.Equal(t, "http://localhost/", m["url"])
}
+func TestJSScrapeObject(t *testing.T) {
+ js := `
+ export default function() {
+ return {foo: "bar"}
+ }
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+
+ result, err := exports.Scrape(flyscrape.ScrapeParams{
+ HTML: html,
+ URL: "http://localhost/",
+ })
+ require.NoError(t, err)
+
+ m, ok := result.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "bar", m["foo"])
+}
+
+func TestJSScrapeNull(t *testing.T) {
+ js := `
+ export default function() {
+ return null
+ }
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+
+ result, err := exports.Scrape(flyscrape.ScrapeParams{
+ HTML: html,
+ URL: "http://localhost/",
+ })
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
+
+func TestJSScrapeString(t *testing.T) {
+ js := `
+ export default function() {
+ return "foo"
+ }
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+
+ result, err := exports.Scrape(flyscrape.ScrapeParams{
+ HTML: html,
+ URL: "http://localhost/",
+ })
+ require.NoError(t, err)
+
+ m, ok := result.(string)
+ require.True(t, ok)
+ require.Equal(t, "foo", m)
+}
+
+func TestJSScrapeArray(t *testing.T) {
+ js := `
+ export default function() {
+ return [1,2,3]
+ }
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+
+ result, err := exports.Scrape(flyscrape.ScrapeParams{
+ HTML: html,
+ URL: "http://localhost/",
+ })
+ require.NoError(t, err)
+
+ m, ok := result.([]any)
+ require.True(t, ok)
+ require.Equal(t, int64(1), m[0])
+ require.Equal(t, int64(2), m[1])
+ require.Equal(t, int64(3), m[2])
+}
+
func TestJSCompileError(t *testing.T) {
- cfg, run, err := flyscrape.Compile("import foo;")
+ exports, err := flyscrape.Compile("import foo;", nil)
require.Error(t, err)
- require.Empty(t, cfg)
- require.Nil(t, run)
+ require.Nil(t, exports)
var terr flyscrape.TransformError
require.ErrorAs(t, err, &terr)
@@ -81,8 +160,10 @@ func TestJSConfig(t *testing.T) {
}
export default function() {}
`
- rawCfg, _, err := flyscrape.Compile(js)
+ exports, err := flyscrape.Compile(js, nil)
require.NoError(t, err)
+ require.NotNil(t, exports)
+ require.NotEmpty(t, exports.Config())
type config struct {
URL string `json:"url"`
@@ -91,7 +172,7 @@ func TestJSConfig(t *testing.T) {
}
var cfg config
- err = json.Unmarshal(rawCfg, &cfg)
+ err = json.Unmarshal(exports.Config(), &cfg)
require.NoError(t, err)
require.Equal(t, config{
@@ -100,3 +181,66 @@ func TestJSConfig(t *testing.T) {
AllowedDomains: []string{"example.com"},
}, cfg)
}
+
+func TestJSImports(t *testing.T) {
+ js := `
+ import A from "pkg-a"
+ import { bar } from "pkg-a/pkg-b"
+
+ export const config = {}
+ export default function() {}
+
+ export const a = A.foo
+ export const b = bar()
+ `
+ imports := flyscrape.Imports{
+ "pkg-a": map[string]any{
+ "foo": 10,
+ },
+ "pkg-a/pkg-b": map[string]any{
+ "bar": func() string {
+ return "baz"
+ },
+ },
+ }
+
+ exports, err := flyscrape.Compile(js, imports)
+ require.NoError(t, err)
+ require.NotNil(t, exports)
+
+ require.Equal(t, int64(10), exports["a"].(int64))
+ require.Equal(t, "baz", exports["b"].(string))
+}
+
+func TestJSArbitraryFunction(t *testing.T) {
+ js := `
+ export const config = {}
+ export default function() {}
+ export function foo() {
+ return "bar";
+ }
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+ require.NotNil(t, exports)
+
+ foo := func() string {
+ fn := exports["foo"].(func(goja.FunctionCall) goja.Value)
+ return fn(goja.FunctionCall{}).String()
+ }
+
+ require.Equal(t, "bar", foo())
+}
+
+func TestJSArbitraryConstString(t *testing.T) {
+ js := `
+ export const config = {}
+ export default function() {}
+ export const foo = "bar"
+ `
+ exports, err := flyscrape.Compile(js, nil)
+ require.NoError(t, err)
+ require.NotNil(t, exports)
+
+ require.Equal(t, "bar", exports["foo"].(string))
+}
diff --git a/mock.go b/mock.go
deleted file mode 100644
index 147ca82..0000000
--- a/mock.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package flyscrape
-
-import (
- "fmt"
- "io"
- "net/http"
- "strings"
-)
-
-func MockTransport(statusCode int, html string) RoundTripFunc {
- return func(*http.Request) (*http.Response, error) {
- return MockResponse(statusCode, html)
- }
-}
-
-func MockResponse(statusCode int, html string) (*http.Response, error) {
- return &http.Response{
- StatusCode: statusCode,
- Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
- Body: io.NopCloser(strings.NewReader(html)),
- }, nil
-}
diff --git a/module.go b/module.go
index 0540c91..9b33de4 100644
--- a/module.go
+++ b/module.go
@@ -54,11 +54,12 @@ func RegisterModule(mod Module) {
modules[mod.ModuleInfo().ID] = mod
}
-func LoadModules(s *Scraper, cfg Config) {
+func LoadModules(cfg Config) []Module {
modulesMu.RLock()
defer modulesMu.RUnlock()
loaded := map[string]struct{}{}
+ mods := []Module{}
// load standard modules in order
for _, id := range moduleOrder {
@@ -66,7 +67,7 @@ func LoadModules(s *Scraper, cfg Config) {
if err := json.Unmarshal(cfg, mod); err != nil {
panic("failed to decode config: " + err.Error())
}
- s.LoadModule(mod)
+ mods = append(mods, mod)
loaded[id] = struct{}{}
}
@@ -79,9 +80,11 @@ func LoadModules(s *Scraper, cfg Config) {
if err := json.Unmarshal(cfg, mod); err != nil {
panic("failed to decode config: " + err.Error())
}
- s.LoadModule(mod)
+ mods = append(mods, mod)
loaded[id] = struct{}{}
}
+
+ return mods
}
var (
diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go
index 10b67e9..a596eb4 100644
--- a/modules/depth/depth_test.go
+++ b/modules/depth/depth_test.go
@@ -21,31 +21,34 @@ func TestDepth(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&depth.Module{Depth: 2})
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
- switch r.URL.String() {
- case "http://www.example.com":
- return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`)
- case "http://www.google.com":
- return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
- case "http://www.duckduckgo.com":
- return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`)
- }
- return flyscrape.MockResponse(200, "")
- })
- },
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ &depth.Module{Depth: 2},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ switch r.URL.String() {
+ case "http://www.example.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`)
+ case "http://www.google.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ case "http://www.duckduckgo.com":
+ return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`)
+ }
+ return flyscrape.MockResponse(200, "")
+ })
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 3)
diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go
index a1c8401..ace9430 100644
--- a/modules/domainfilter/domainfilter_test.go
+++ b/modules/domainfilter/domainfilter_test.go
@@ -21,26 +21,29 @@ func TestDomainfilterAllowed(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&domainfilter.Module{
- URL: "http://www.example.com",
- AllowedDomains: []string{"www.google.com"},
- })
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ &domainfilter.Module{
+ URL: "http://www.example.com",
+ AllowedDomains: []string{"www.google.com"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="http://www.google.com">Google</a>
<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 2)
@@ -52,26 +55,29 @@ func TestDomainfilterAllowedAll(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&domainfilter.Module{
- URL: "http://www.example.com",
- AllowedDomains: []string{"*"},
- })
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ &domainfilter.Module{
+ URL: "http://www.example.com",
+ AllowedDomains: []string{"*"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="http://www.google.com">Google</a>
<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 3)
@@ -84,27 +90,30 @@ func TestDomainfilterBlocked(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&domainfilter.Module{
- URL: "http://www.example.com",
- AllowedDomains: []string{"*"},
- BlockedDomains: []string{"www.google.com"},
- })
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ &domainfilter.Module{
+ URL: "http://www.example.com",
+ AllowedDomains: []string{"*"},
+ BlockedDomains: []string{"www.google.com"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="http://www.google.com">Google</a>
<a href="http://www.duckduckgo.com">DuckDuckGo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 2)
diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go
index f3eb4fe..af186f9 100644
--- a/modules/followlinks/followlinks_test.go
+++ b/modules/followlinks/followlinks_test.go
@@ -20,24 +20,26 @@ func TestFollowLinks(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.LoadModule(&followlinks.Module{})
-
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/foo/bar"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="/baz">Baz</a>
<a href="baz">Baz</a>
<a href="http://www.google.com">Google</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 5)
@@ -52,28 +54,30 @@ func TestFollowSelector(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.LoadModule(&followlinks.Module{
- Follow: []string{".next a[href]"},
- })
-
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/foo/bar"},
+ &followlinks.Module{
+ Follow: []string{".next a[href]"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="/baz">Baz</a>
<a href="baz">Baz</a>
<div class="next">
<a href="http://www.google.com">Google</a>
</div>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 2)
@@ -85,26 +89,28 @@ func TestFollowDataAttr(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.LoadModule(&followlinks.Module{
- Follow: []string{"[data-url]"},
- })
-
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/foo/bar"},
+ &followlinks.Module{
+ Follow: []string{"[data-url]"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="/baz">Baz</a>
<a href="baz">Baz</a>
<div data-url="http://www.google.com">Google</div>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 2)
@@ -116,26 +122,28 @@ func TestFollowMultiple(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.LoadModule(&followlinks.Module{
- Follow: []string{"a.prev", "a.next"},
- })
-
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/foo/bar"},
+ &followlinks.Module{
+ Follow: []string{"a.prev", "a.next"},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="/baz">Baz</a>
<a class="prev" href="a">a</a>
<a class="next" href="b">b</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 3)
diff --git a/modules/proxy/proxy_test.go b/modules/proxy/proxy_test.go
index e6058b8..62da23a 100644
--- a/modules/proxy/proxy_test.go
+++ b/modules/proxy/proxy_test.go
@@ -20,13 +20,17 @@ func TestProxy(t *testing.T) {
p := newProxy(func() { called = true })
defer p.Close()
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&proxy.Module{
- Proxies: []string{p.URL},
- })
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &proxy.Module{
+ Proxies: []string{p.URL},
+ },
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
+
require.True(t, called)
}
diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go
index 1fe22b1..7be29a1 100644
--- a/modules/ratelimit/ratelimit_test.go
+++ b/modules/ratelimit/ratelimit_test.go
@@ -20,22 +20,25 @@ import (
func TestRatelimit(t *testing.T) {
var times []time.Time
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `<a href="foo">foo</a>`)
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `<a href="foo">foo</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ times = append(times, time.Now())
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- times = append(times, time.Now())
+ &ratelimit.Module{
+ Rate: 100,
},
- })
- scraper.LoadModule(&ratelimit.Module{
- Rate: 100,
- })
+ }
start := time.Now()
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
first := times[0].Add(-10 * time.Millisecond)
diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go
index 86e4ad7..78efa6a 100644
--- a/modules/starturl/starturl_test.go
+++ b/modules/starturl/starturl_test.go
@@ -18,18 +18,21 @@ func TestStartURL(t *testing.T) {
var url string
var depth int
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"})
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, "")
- },
- BuildRequestFn: func(r *flyscrape.Request) {
- url = r.URL
- depth = r.Depth
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/foo/bar"},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, "")
+ },
+ BuildRequestFn: func(r *flyscrape.Request) {
+ url = r.URL
+ depth = r.Depth
+ },
},
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Equal(t, "http://www.example.com/foo/bar", url)
diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go
index 9ebb8a5..442780d 100644
--- a/modules/urlfilter/urlfilter_test.go
+++ b/modules/urlfilter/urlfilter_test.go
@@ -21,28 +21,31 @@ func TestURLFilterAllowed(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&urlfilter.Module{
- URL: "http://www.example.com/",
- AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`},
- })
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/"},
+ &followlinks.Module{},
+ &urlfilter.Module{
+ URL: "http://www.example.com/",
+ AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="foo?id=123">123</a>
<a href="foo?id=ABC">ABC</a>
<a href="/bar">bar</a>
<a href="/barz">barz</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 3)
@@ -55,28 +58,31 @@ func TestURLFilterBlocked(t *testing.T) {
var urls []string
var mu sync.Mutex
- scraper := flyscrape.NewScraper()
- scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"})
- scraper.LoadModule(&followlinks.Module{})
- scraper.LoadModule(&urlfilter.Module{
- URL: "http://www.example.com/",
- BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`},
- })
- scraper.LoadModule(hook.Module{
- AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
- return flyscrape.MockTransport(200, `
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com/"},
+ &followlinks.Module{},
+ &urlfilter.Module{
+ URL: "http://www.example.com/",
+ BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`},
+ },
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.MockTransport(200, `
<a href="foo?id=123">123</a>
<a href="foo?id=ABC">ABC</a>
<a href="/bar">bar</a>
<a href="/barz">barz</a>`)
+ },
+ ReceiveResponseFn: func(r *flyscrape.Response) {
+ mu.Lock()
+ urls = append(urls, r.Request.URL)
+ mu.Unlock()
+ },
},
- ReceiveResponseFn: func(r *flyscrape.Response) {
- mu.Lock()
- urls = append(urls, r.Request.URL)
- mu.Unlock()
- },
- })
+ }
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
scraper.Run()
require.Len(t, urls, 3)
diff --git a/scrape.go b/scrape.go
index 00a74bf..764ef39 100644
--- a/scrape.go
+++ b/scrape.go
@@ -9,7 +9,6 @@ import (
"log"
"net/http"
"net/http/cookiejar"
- "slices"
"sync"
"github.com/cornelk/hashmap"
@@ -22,7 +21,6 @@ type Context interface {
Visit(url string)
MarkVisited(url string)
MarkUnvisited(url string)
- DisableModule(id string)
}
type Request struct {
@@ -50,39 +48,18 @@ type target struct {
}
func NewScraper() *Scraper {
- return &Scraper{
- jobs: make(chan target, 1024),
- visited: hashmap.New[string, struct{}](),
- }
+ return &Scraper{}
}
type Scraper struct {
ScrapeFunc ScrapeFunc
Script string
+ Modules []Module
+ Client *http.Client
wg sync.WaitGroup
jobs chan target
visited *hashmap.Map[string, struct{}]
-
- modules []Module
- moduleIDs []string
- client *http.Client
-}
-
-func (s *Scraper) LoadModule(mod Module) {
- id := mod.ModuleInfo().ID
-
- s.modules = append(s.modules, mod)
- s.moduleIDs = append(s.moduleIDs, id)
-}
-
-func (s *Scraper) DisableModule(id string) {
- idx := slices.Index(s.moduleIDs, id)
- if idx == -1 {
- return
- }
- s.modules = slices.Delete(s.modules, idx, idx+1)
- s.moduleIDs = slices.Delete(s.moduleIDs, idx, idx+1)
}
func (s *Scraper) Visit(url string) {
@@ -102,18 +79,28 @@ func (s *Scraper) ScriptName() string {
}
func (s *Scraper) Run() {
- for _, mod := range s.modules {
+ s.jobs = make(chan target, 1024)
+ s.visited = hashmap.New[string, struct{}]()
+
+ s.initClient()
+
+ for _, mod := range s.Modules {
if v, ok := mod.(Provisioner); ok {
v.Provision(s)
}
}
- s.initClient()
+ for _, mod := range s.Modules {
+ if v, ok := mod.(TransportAdapter); ok {
+ s.Client.Transport = v.AdaptTransport(s.Client.Transport)
+ }
+ }
+
go s.scrape()
s.wg.Wait()
close(s.jobs)
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(Finalizer); ok {
v.Finalize()
}
@@ -121,13 +108,14 @@ func (s *Scraper) Run() {
}
func (s *Scraper) initClient() {
- jar, _ := cookiejar.New(nil)
- s.client = &http.Client{Jar: jar, Transport: http.DefaultTransport}
-
- for _, mod := range s.modules {
- if v, ok := mod.(TransportAdapter); ok {
- s.client.Transport = v.AdaptTransport(s.client.Transport)
- }
+ if s.Client == nil {
+ s.Client = &http.Client{}
+ }
+ if s.Client.Jar == nil {
+ s.Client.Jar, _ = cookiejar.New(nil)
+ }
+ if s.Client.Transport == nil {
+ s.Client.Transport = http.DefaultTransport
}
}
@@ -146,7 +134,7 @@ func (s *Scraper) process(url string, depth int) {
Method: http.MethodGet,
URL: url,
Headers: defaultHeaders(),
- Cookies: s.client.Jar,
+ Cookies: s.Client.Jar,
Depth: depth,
}
@@ -157,7 +145,7 @@ func (s *Scraper) process(url string, depth int) {
},
}
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(RequestBuilder); ok {
v.BuildRequest(request)
}
@@ -170,7 +158,7 @@ func (s *Scraper) process(url string, depth int) {
}
req.Header = request.Headers
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(RequestValidator); ok {
if !v.ValidateRequest(request) {
return
@@ -179,14 +167,14 @@ func (s *Scraper) process(url string, depth int) {
}
defer func() {
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(ResponseReceiver); ok {
v.ReceiveResponse(response)
}
}
}()
- resp, err := s.client.Do(req)
+ resp, err := s.Client.Do(req)
if err != nil {
response.Error = err
return
diff --git a/utils.go b/utils.go
index faa4937..161cff8 100644
--- a/utils.go
+++ b/utils.go
@@ -7,6 +7,8 @@ package flyscrape
import (
"bytes"
"encoding/json"
+ "fmt"
+ "io"
"net/http"
"strings"
)
@@ -25,3 +27,17 @@ type RoundTripFunc func(*http.Request) (*http.Response, error)
func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
+
+func MockTransport(statusCode int, html string) RoundTripFunc {
+ return func(*http.Request) (*http.Response, error) {
+ return MockResponse(statusCode, html)
+ }
+}
+
+func MockResponse(statusCode int, html string) (*http.Response, error) {
+ return &http.Response{
+ StatusCode: statusCode,
+ Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
+ Body: io.NopCloser(strings.NewReader(html)),
+ }, nil
+}