diff options
| -rw-r--r-- | cmd/flyscrape/dev.go | 95 | ||||
| -rw-r--r-- | cmd/flyscrape/new.go | 4 | ||||
| -rw-r--r-- | cmd/flyscrape/run.go | 39 | ||||
| -rw-r--r-- | flyscrape.go | 131 | ||||
| -rw-r--r-- | js.go | 145 | ||||
| -rw-r--r-- | js_test.go | 162 | ||||
| -rw-r--r-- | mock.go | 26 | ||||
| -rw-r--r-- | module.go | 9 | ||||
| -rw-r--r-- | modules/depth/depth_test.go | 49 | ||||
| -rw-r--r-- | modules/domainfilter/domainfilter_test.go | 107 | ||||
| -rw-r--r-- | modules/followlinks/followlinks_test.go | 124 | ||||
| -rw-r--r-- | modules/proxy/proxy_test.go | 14 | ||||
| -rw-r--r-- | modules/ratelimit/ratelimit_test.go | 27 | ||||
| -rw-r--r-- | modules/starturl/starturl_test.go | 23 | ||||
| -rw-r--r-- | modules/urlfilter/urlfilter_test.go | 70 | ||||
| -rw-r--r-- | scrape.go | 70 | ||||
| -rw-r--r-- | utils.go | 16 |
17 files changed, 660 insertions, 451 deletions
diff --git a/cmd/flyscrape/dev.go b/cmd/flyscrape/dev.go index 9ddb3bf..84a436b 100644 --- a/cmd/flyscrape/dev.go +++ b/cmd/flyscrape/dev.go @@ -5,16 +5,9 @@ package main import ( - "encoding/json" "flag" "fmt" - "log" - "os" - "os/signal" - "path/filepath" - "syscall" - "github.com/inancgumus/screen" "github.com/philippta/flyscrape" ) @@ -27,46 +20,13 @@ func (c *DevCommand) Run(args []string) error { if err := fs.Parse(args); err != nil { return err } else if fs.NArg() == 0 || fs.Arg(0) == "" { - return fmt.Errorf("script path required") + c.Usage() + return flag.ErrHelp } else if fs.NArg() > 1 { return fmt.Errorf("too many arguments") } - script := fs.Arg(0) - cachefile, err := newCacheFile() - if err != nil { - return fmt.Errorf("failed to create cache file: %w", err) - } - - trapsignal(func() { os.RemoveAll(cachefile) }) - - err = flyscrape.Watch(script, func(s string) error { - cfg, scrape, err := flyscrape.Compile(s) - if err != nil { - printCompileErr(script, err) - return nil - } - - cfg = updateCfg(cfg, "depth", 0) - cfg = updateCfg(cfg, "cache", "file:"+cachefile) - - scraper := flyscrape.NewScraper() - scraper.ScrapeFunc = scrape - scraper.Script = script - - flyscrape.LoadModules(scraper, cfg) - - screen.Clear() - screen.MoveTopLeft() - scraper.Run() - - return nil - }) - if err != nil && err != flyscrape.StopWatch { - return fmt.Errorf("failed to watch script %q: %w", script, err) - } - - return nil + return flyscrape.Dev(fs.Arg(0)) } func (c *DevCommand) Usage() { @@ -78,58 +38,9 @@ Usage: flyscrape dev SCRIPT - Examples: # Run and watch script. $ flyscrape dev example.js `[1:]) } - -func printCompileErr(script string, err error) { - screen.Clear() - screen.MoveTopLeft() - - if errs, ok := err.(interface{ Unwrap() []error }); ok { - for _, err := range errs.Unwrap() { - log.Printf("%s:%v\n", script, err) - } - } else { - log.Println(err) - } -} - -func updateCfg(cfg flyscrape.Config, key string, value any) flyscrape.Config { - var m map[string]any - if err := json.Unmarshal(cfg, &m); err != nil { - return cfg - } - - m[key] = value - - b, err := json.Marshal(m) - if err != nil { - return cfg - } - - return b -} - -func newCacheFile() (string, error) { - cachedir, err := os.MkdirTemp("", "flyscrape-cache") - if err != nil { - return "", err - } - return filepath.Join(cachedir, "dev.cache"), nil -} - -func trapsignal(f func()) { - sig := make(chan os.Signal, 2) - signal.Notify(sig, os.Interrupt, syscall.SIGTERM) - - go func() { - <-sig - f() - os.Exit(0) - }() -} diff --git a/cmd/flyscrape/new.go b/cmd/flyscrape/new.go index 8c0d6c4..4ab248e 100644 --- a/cmd/flyscrape/new.go +++ b/cmd/flyscrape/new.go @@ -21,7 +21,8 @@ func (c *NewCommand) Run(args []string) error { if err := fs.Parse(args); err != nil { return err } else if fs.NArg() == 0 || fs.Arg(0) == "" { - return fmt.Errorf("script path required") + c.Usage() + return flag.ErrHelp } else if fs.NArg() > 1 { return fmt.Errorf("too many arguments") } @@ -47,7 +48,6 @@ Usage: flyscrape new SCRIPT - Examples: # Create a new scraping script. diff --git a/cmd/flyscrape/run.go b/cmd/flyscrape/run.go index 039574b..7a8930a 100644 --- a/cmd/flyscrape/run.go +++ b/cmd/flyscrape/run.go @@ -7,12 +7,8 @@ package main import ( "flag" "fmt" - "log" - "os" - "time" "github.com/philippta/flyscrape" - "github.com/philippta/flyscrape/modules/hook" ) type RunCommand struct{} @@ -24,41 +20,13 @@ func (c *RunCommand) Run(args []string) error { if err := fs.Parse(args); err != nil { return err } else if fs.NArg() == 0 || fs.Arg(0) == "" { - return fmt.Errorf("script path required") + c.Usage() + return flag.ErrHelp } else if fs.NArg() > 1 { return fmt.Errorf("too many arguments") } - script := fs.Arg(0) - src, err := os.ReadFile(script) - if err != nil { - return fmt.Errorf("failed to read script %q: %w", script, err) - } - - cfg, scrape, err := flyscrape.Compile(string(src)) - if err != nil { - return fmt.Errorf("failed to compile script: %w", err) - } - - scraper := flyscrape.NewScraper() - scraper.ScrapeFunc = scrape - scraper.Script = script - - flyscrape.LoadModules(scraper, cfg) - - count := 0 - start := time.Now() - - scraper.LoadModule(hook.Module{ - ReceiveResponseFn: func(r *flyscrape.Response) { - count++ - }, - }) - - scraper.Run() - - log.Printf("Scraped %d websites in %v\n", count, time.Since(start)) - return nil + return flyscrape.Run(fs.Arg(0)) } func (c *RunCommand) Usage() { @@ -69,7 +37,6 @@ Usage: flyscrape run SCRIPT - Examples: # Run the script. diff --git a/flyscrape.go b/flyscrape.go new file mode 100644 index 0000000..bb4ee30 --- /dev/null +++ b/flyscrape.go @@ -0,0 +1,131 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package flyscrape + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + + "github.com/inancgumus/screen" +) + +func Run(file string) error { + src, err := os.ReadFile(file) + if err != nil { + return fmt.Errorf("failed to read script %q: %w", file, err) + } + + client := &http.Client{} + + script, err := Compile(string(src), nil) + if err != nil { + return fmt.Errorf("failed to compile script: %w", err) + } + + scraper := NewScraper() + scraper.ScrapeFunc = script.Scrape + scraper.Script = file + scraper.Client = client + scraper.Modules = LoadModules(script.Config()) + + scraper.Run() + return nil +} + +func Dev(file string) error { + cachefile, err := newCacheFile() + if err != nil { + return fmt.Errorf("failed to create cache file: %w", err) + } + + trapsignal(func() { + os.RemoveAll(cachefile) + }) + + fn := func(s string) error { + client := &http.Client{} + + script, err := Compile(s, nil) + if err != nil { + printCompileErr(file, err) + return nil + } + + cfg := script.Config() + cfg = updateCfg(cfg, "depth", 0) + cfg = updateCfg(cfg, "cache", "file:"+cachefile) + + scraper := NewScraper() + scraper.ScrapeFunc = script.Scrape + scraper.Script = file + scraper.Client = client + scraper.Modules = LoadModules(cfg) + + screen.Clear() + screen.MoveTopLeft() + scraper.Run() + + return nil + } + + if err := Watch(file, fn); err != nil && err != StopWatch { + return fmt.Errorf("failed to watch script %q: %w", file, err) + } + return nil +} + +func printCompileErr(script string, err error) { + screen.Clear() + screen.MoveTopLeft() + + if errs, ok := err.(interface{ Unwrap() []error }); ok { + for _, err := range errs.Unwrap() { + log.Printf("%s:%v\n", script, err) + } + } else { + log.Println(err) + } +} + +func updateCfg(cfg Config, key string, value any) Config { + var m map[string]any + if err := json.Unmarshal(cfg, &m); err != nil { + return cfg + } + + m[key] = value + + b, err := json.Marshal(m) + if err != nil { + return cfg + } + + return b +} + +func newCacheFile() (string, error) { + cachedir, err := os.MkdirTemp("", "flyscrape-cache") + if err != nil { + return "", err + } + return filepath.Join(cachedir, "dev.cache"), nil +} + +func trapsignal(f func()) { + sig := make(chan os.Signal, 2) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM) + + go func() { + <-sig + f() + os.Exit(0) + }() +} @@ -11,10 +11,8 @@ import ( "fmt" "log" "net/url" - "strconv" "strings" "sync" - "sync/atomic" "github.com/PuerkitoBio/goquery" "github.com/dop251/goja" @@ -45,19 +43,32 @@ func (err TransformError) Error() string { return fmt.Sprintf("%d:%d: %s", err.Line, err.Column, err.Text) } -func Compile(src string) (Config, ScrapeFunc, error) { +type Exports map[string]any + +func (e Exports) Config() []byte { + b, _ := json.Marshal(e["config"]) + return b +} + +func (e Exports) Scrape(p ScrapeParams) (any, error) { + fn := e["__scrape"].(ScrapeFunc) + return fn(p) +} + +type Imports map[string]map[string]any + +func Compile(src string, imports Imports) (Exports, error) { src, err := build(src) if err != nil { - return nil, nil, err + return nil, err } - - return vm(src) + return vm(src, imports) } func build(src string) (string, error) { res := api.Transform(src, api.TransformOptions{ - Platform: api.PlatformBrowser, - Format: api.FormatIIFE, + Platform: api.PlatformNode, + Format: api.FormatCommonJS, }) var errs []error @@ -76,31 +87,67 @@ func build(src string) (string, error) { return string(res.Code), nil } -func vm(src string) (Config, ScrapeFunc, error) { +func vm(src string, imports Imports) (Exports, error) { vm := goja.New() - registry := &require.Registry{} - registry.Enable(vm) + registry.Enable(vm) console.Enable(vm) - if _, err := vm.RunString(removeIIFE(src)); err != nil { - return nil, nil, fmt.Errorf("running user script: %w", err) + for module, pkg := range imports { + pkg := pkg + registry.RegisterNativeModule(module, func(vm *goja.Runtime, o *goja.Object) { + exports := vm.NewObject() + + for ident, val := range pkg { + exports.Set(ident, val) + } + + o.Set("exports", exports) + }) + } + + if _, err := vm.RunString("module = {}"); err != nil { + return nil, fmt.Errorf("running defining module: %w", err) + } + if _, err := vm.RunString(src); err != nil { + return nil, fmt.Errorf("running user script: %w", err) } - cfg, err := vm.RunString("JSON.stringify(config)") + v, err := vm.RunString("module.exports") if err != nil { - return nil, nil, fmt.Errorf("reading config: %w", err) + return nil, fmt.Errorf("reading config: %w", err) + } + + exports := Exports{} + obj := v.ToObject(vm) + for _, key := range obj.Keys() { + exports[key] = obj.Get(key).Export() } - var c atomic.Uint64 + exports["__scrape"] = scrape(vm) + + return exports, nil +} + +func scrape(vm *goja.Runtime) ScrapeFunc { var lock sync.Mutex - scrape := func(p ScrapeParams) (any, error) { + defaultfn, err := vm.RunString("module.exports.default") + if err != nil { + return func(ScrapeParams) (any, error) { return nil, errors.New("no scrape function defined") } + } + + scrapefn, ok := defaultfn.Export().(func(goja.FunctionCall) goja.Value) + if !ok { + return func(ScrapeParams) (any, error) { return nil, errors.New("default export is not a function") } + } + + return func(p ScrapeParams) (any, error) { lock.Lock() defer lock.Unlock() - doc, err := goquery.NewDocumentFromReader(strings.NewReader(p.HTML)) + doc, err := DocumentFromString(p.HTML) if err != nil { log.Println(err) return nil, err @@ -112,37 +159,35 @@ func vm(src string) (Config, ScrapeFunc, error) { return nil, err } - suffix := strconv.FormatUint(c.Add(1), 10) - vm.Set("url_"+suffix, p.URL) - vm.Set("doc_"+suffix, wrap(vm, doc.Selection)) - vm.Set("absurl_"+suffix, func(ref string) string { + absoluteURL := func(ref string) string { abs, err := baseurl.Parse(ref) if err != nil { log.Println(err) return ref } return abs.String() - }) - - data, err := vm.RunString(fmt.Sprintf(`JSON.stringify(stdin_default({doc: doc_%s, url: url_%s, absoluteURL: absurl_%s}))`, suffix, suffix, suffix)) - if err != nil { - log.Println(err) - return nil, err } - var obj any - if err := json.Unmarshal([]byte(data.String()), &obj); err != nil { - log.Println(err) - return nil, err - } + o := vm.NewObject() + o.Set("url", p.URL) + o.Set("doc", doc) + o.Set("absoluteURL", absoluteURL) + + ret := scrapefn(goja.FunctionCall{Arguments: []goja.Value{o}}).Export() + return ret, nil + } +} - return obj, nil +func DocumentFromString(s string) (map[string]any, error) { + doc, err := goquery.NewDocumentFromReader(strings.NewReader(s)) + if err != nil { + return nil, err } - return Config(cfg.String()), scrape, nil + return Document(doc.Selection), nil } -func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any { +func Document(sel *goquery.Selection) map[string]any { o := map[string]any{} o["WARNING"] = "Forgot to call text(), html() or attr()?" o["text"] = sel.Text @@ -151,19 +196,19 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any { o["hasAttr"] = func(name string) bool { _, ok := sel.Attr(name); return ok } o["hasClass"] = sel.HasClass o["length"] = sel.Length() - o["first"] = func() map[string]any { return wrap(vm, sel.First()) } - o["last"] = func() map[string]any { return wrap(vm, sel.Last()) } - o["get"] = func(index int) map[string]any { return wrap(vm, sel.Eq(index)) } - o["find"] = func(s string) map[string]any { return wrap(vm, sel.Find(s)) } - o["next"] = func() map[string]any { return wrap(vm, sel.Next()) } - o["prev"] = func() map[string]any { return wrap(vm, sel.Prev()) } - o["siblings"] = func() map[string]any { return wrap(vm, sel.Siblings()) } - o["children"] = func() map[string]any { return wrap(vm, sel.Children()) } - o["parent"] = func() map[string]any { return wrap(vm, sel.Parent()) } + o["first"] = func() map[string]any { return Document(sel.First()) } + o["last"] = func() map[string]any { return Document(sel.Last()) } + o["get"] = func(index int) map[string]any { return Document(sel.Eq(index)) } + o["find"] = func(s string) map[string]any { return Document(sel.Find(s)) } + o["next"] = func() map[string]any { return Document(sel.Next()) } + o["prev"] = func() map[string]any { return Document(sel.Prev()) } + o["siblings"] = func() map[string]any { return Document(sel.Siblings()) } + o["children"] = func() map[string]any { return Document(sel.Children()) } + o["parent"] = func() map[string]any { return Document(sel.Parent()) } o["map"] = func(callback func(map[string]any, int) any) []any { var vals []any sel.Map(func(i int, s *goquery.Selection) string { - vals = append(vals, callback(wrap(vm, s), i)) + vals = append(vals, callback(Document(s), i)) return "" }) return vals @@ -171,7 +216,7 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any { o["filter"] = func(callback func(map[string]any, int) bool) []any { var vals []any sel.Each(func(i int, s *goquery.Selection) { - el := wrap(vm, s) + el := Document(s) ok := callback(el, i) if ok { vals = append(vals, el) @@ -181,9 +226,3 @@ func wrap(vm *goja.Runtime, sel *goquery.Selection) map[string]any { } return o } - -func removeIIFE(s string) string { - s = strings.TrimPrefix(s, "(() => {\n") - s = strings.TrimSuffix(s, "})();\n") - return s -} @@ -8,6 +8,7 @@ import ( "encoding/json" "testing" + "github.com/dop251/goja" "github.com/philippta/flyscrape" "github.com/stretchr/testify/require" ) @@ -37,12 +38,12 @@ export default function({ doc, url }) { ` func TestJSScrape(t *testing.T) { - cfg, run, err := flyscrape.Compile(script) + exports, err := flyscrape.Compile(script, nil) require.NoError(t, err) - require.NotNil(t, cfg) - require.NotNil(t, run) + require.NotNil(t, exports) + require.NotEmpty(t, exports.Config) - result, err := run(flyscrape.ScrapeParams{ + result, err := exports.Scrape(flyscrape.ScrapeParams{ HTML: html, URL: "http://localhost/", }) @@ -56,11 +57,89 @@ func TestJSScrape(t *testing.T) { require.Equal(t, "http://localhost/", m["url"]) } +func TestJSScrapeObject(t *testing.T) { + js := ` + export default function() { + return {foo: "bar"} + } + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + + result, err := exports.Scrape(flyscrape.ScrapeParams{ + HTML: html, + URL: "http://localhost/", + }) + require.NoError(t, err) + + m, ok := result.(map[string]any) + require.True(t, ok) + require.Equal(t, "bar", m["foo"]) +} + +func TestJSScrapeNull(t *testing.T) { + js := ` + export default function() { + return null + } + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + + result, err := exports.Scrape(flyscrape.ScrapeParams{ + HTML: html, + URL: "http://localhost/", + }) + require.NoError(t, err) + require.Nil(t, result) +} + +func TestJSScrapeString(t *testing.T) { + js := ` + export default function() { + return "foo" + } + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + + result, err := exports.Scrape(flyscrape.ScrapeParams{ + HTML: html, + URL: "http://localhost/", + }) + require.NoError(t, err) + + m, ok := result.(string) + require.True(t, ok) + require.Equal(t, "foo", m) +} + +func TestJSScrapeArray(t *testing.T) { + js := ` + export default function() { + return [1,2,3] + } + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + + result, err := exports.Scrape(flyscrape.ScrapeParams{ + HTML: html, + URL: "http://localhost/", + }) + require.NoError(t, err) + + m, ok := result.([]any) + require.True(t, ok) + require.Equal(t, int64(1), m[0]) + require.Equal(t, int64(2), m[1]) + require.Equal(t, int64(3), m[2]) +} + func TestJSCompileError(t *testing.T) { - cfg, run, err := flyscrape.Compile("import foo;") + exports, err := flyscrape.Compile("import foo;", nil) require.Error(t, err) - require.Empty(t, cfg) - require.Nil(t, run) + require.Nil(t, exports) var terr flyscrape.TransformError require.ErrorAs(t, err, &terr) @@ -81,8 +160,10 @@ func TestJSConfig(t *testing.T) { } export default function() {} ` - rawCfg, _, err := flyscrape.Compile(js) + exports, err := flyscrape.Compile(js, nil) require.NoError(t, err) + require.NotNil(t, exports) + require.NotEmpty(t, exports.Config()) type config struct { URL string `json:"url"` @@ -91,7 +172,7 @@ func TestJSConfig(t *testing.T) { } var cfg config - err = json.Unmarshal(rawCfg, &cfg) + err = json.Unmarshal(exports.Config(), &cfg) require.NoError(t, err) require.Equal(t, config{ @@ -100,3 +181,66 @@ func TestJSConfig(t *testing.T) { AllowedDomains: []string{"example.com"}, }, cfg) } + +func TestJSImports(t *testing.T) { + js := ` + import A from "pkg-a" + import { bar } from "pkg-a/pkg-b" + + export const config = {} + export default function() {} + + export const a = A.foo + export const b = bar() + ` + imports := flyscrape.Imports{ + "pkg-a": map[string]any{ + "foo": 10, + }, + "pkg-a/pkg-b": map[string]any{ + "bar": func() string { + return "baz" + }, + }, + } + + exports, err := flyscrape.Compile(js, imports) + require.NoError(t, err) + require.NotNil(t, exports) + + require.Equal(t, int64(10), exports["a"].(int64)) + require.Equal(t, "baz", exports["b"].(string)) +} + +func TestJSArbitraryFunction(t *testing.T) { + js := ` + export const config = {} + export default function() {} + export function foo() { + return "bar"; + } + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + require.NotNil(t, exports) + + foo := func() string { + fn := exports["foo"].(func(goja.FunctionCall) goja.Value) + return fn(goja.FunctionCall{}).String() + } + + require.Equal(t, "bar", foo()) +} + +func TestJSArbitraryConstString(t *testing.T) { + js := ` + export const config = {} + export default function() {} + export const foo = "bar" + ` + exports, err := flyscrape.Compile(js, nil) + require.NoError(t, err) + require.NotNil(t, exports) + + require.Equal(t, "bar", exports["foo"].(string)) +} diff --git a/mock.go b/mock.go deleted file mode 100644 index 147ca82..0000000 --- a/mock.go +++ /dev/null @@ -1,26 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package flyscrape - -import ( - "fmt" - "io" - "net/http" - "strings" -) - -func MockTransport(statusCode int, html string) RoundTripFunc { - return func(*http.Request) (*http.Response, error) { - return MockResponse(statusCode, html) - } -} - -func MockResponse(statusCode int, html string) (*http.Response, error) { - return &http.Response{ - StatusCode: statusCode, - Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)), - Body: io.NopCloser(strings.NewReader(html)), - }, nil -} @@ -54,11 +54,12 @@ func RegisterModule(mod Module) { modules[mod.ModuleInfo().ID] = mod } -func LoadModules(s *Scraper, cfg Config) { +func LoadModules(cfg Config) []Module { modulesMu.RLock() defer modulesMu.RUnlock() loaded := map[string]struct{}{} + mods := []Module{} // load standard modules in order for _, id := range moduleOrder { @@ -66,7 +67,7 @@ func LoadModules(s *Scraper, cfg Config) { if err := json.Unmarshal(cfg, mod); err != nil { panic("failed to decode config: " + err.Error()) } - s.LoadModule(mod) + mods = append(mods, mod) loaded[id] = struct{}{} } @@ -79,9 +80,11 @@ func LoadModules(s *Scraper, cfg Config) { if err := json.Unmarshal(cfg, mod); err != nil { panic("failed to decode config: " + err.Error()) } - s.LoadModule(mod) + mods = append(mods, mod) loaded[id] = struct{}{} } + + return mods } var ( diff --git a/modules/depth/depth_test.go b/modules/depth/depth_test.go index 10b67e9..a596eb4 100644 --- a/modules/depth/depth_test.go +++ b/modules/depth/depth_test.go @@ -21,31 +21,34 @@ func TestDepth(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&depth.Module{Depth: 2}) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - switch r.URL.String() { - case "http://www.example.com": - return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`) - case "http://www.google.com": - return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`) - case "http://www.duckduckgo.com": - return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`) - } - return flyscrape.MockResponse(200, "") - }) - }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + &depth.Module{Depth: 2}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "http://www.example.com": + return flyscrape.MockResponse(200, `<a href="http://www.google.com">Google</a>`) + case "http://www.google.com": + return flyscrape.MockResponse(200, `<a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + case "http://www.duckduckgo.com": + return flyscrape.MockResponse(200, `<a href="http://www.example.com">Example</a>`) + } + return flyscrape.MockResponse(200, "") + }) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 3) diff --git a/modules/domainfilter/domainfilter_test.go b/modules/domainfilter/domainfilter_test.go index a1c8401..ace9430 100644 --- a/modules/domainfilter/domainfilter_test.go +++ b/modules/domainfilter/domainfilter_test.go @@ -21,26 +21,29 @@ func TestDomainfilterAllowed(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&domainfilter.Module{ - URL: "http://www.example.com", - AllowedDomains: []string{"www.google.com"}, - }) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + &domainfilter.Module{ + URL: "http://www.example.com", + AllowedDomains: []string{"www.google.com"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="http://www.google.com">Google</a> <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 2) @@ -52,26 +55,29 @@ func TestDomainfilterAllowedAll(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&domainfilter.Module{ - URL: "http://www.example.com", - AllowedDomains: []string{"*"}, - }) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + &domainfilter.Module{ + URL: "http://www.example.com", + AllowedDomains: []string{"*"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="http://www.google.com">Google</a> <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 3) @@ -84,27 +90,30 @@ func TestDomainfilterBlocked(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&domainfilter.Module{ - URL: "http://www.example.com", - AllowedDomains: []string{"*"}, - BlockedDomains: []string{"www.google.com"}, - }) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + &domainfilter.Module{ + URL: "http://www.example.com", + AllowedDomains: []string{"*"}, + BlockedDomains: []string{"www.google.com"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="http://www.google.com">Google</a> <a href="http://www.duckduckgo.com">DuckDuckGo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 2) diff --git a/modules/followlinks/followlinks_test.go b/modules/followlinks/followlinks_test.go index f3eb4fe..af186f9 100644 --- a/modules/followlinks/followlinks_test.go +++ b/modules/followlinks/followlinks_test.go @@ -20,24 +20,26 @@ func TestFollowLinks(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.LoadModule(&followlinks.Module{}) - - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/foo/bar"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="/baz">Baz</a> <a href="baz">Baz</a> <a href="http://www.google.com">Google</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 5) @@ -52,28 +54,30 @@ func TestFollowSelector(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.LoadModule(&followlinks.Module{ - Follow: []string{".next a[href]"}, - }) - - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/foo/bar"}, + &followlinks.Module{ + Follow: []string{".next a[href]"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="/baz">Baz</a> <a href="baz">Baz</a> <div class="next"> <a href="http://www.google.com">Google</a> </div>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 2) @@ -85,26 +89,28 @@ func TestFollowDataAttr(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.LoadModule(&followlinks.Module{ - Follow: []string{"[data-url]"}, - }) - - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/foo/bar"}, + &followlinks.Module{ + Follow: []string{"[data-url]"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="/baz">Baz</a> <a href="baz">Baz</a> <div data-url="http://www.google.com">Google</div>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 2) @@ -116,26 +122,28 @@ func TestFollowMultiple(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.LoadModule(&followlinks.Module{ - Follow: []string{"a.prev", "a.next"}, - }) - - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/foo/bar"}, + &followlinks.Module{ + Follow: []string{"a.prev", "a.next"}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="/baz">Baz</a> <a class="prev" href="a">a</a> <a class="next" href="b">b</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 3) diff --git a/modules/proxy/proxy_test.go b/modules/proxy/proxy_test.go index e6058b8..62da23a 100644 --- a/modules/proxy/proxy_test.go +++ b/modules/proxy/proxy_test.go @@ -20,13 +20,17 @@ func TestProxy(t *testing.T) { p := newProxy(func() { called = true }) defer p.Close() - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&proxy.Module{ - Proxies: []string{p.URL}, - }) + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &proxy.Module{ + Proxies: []string{p.URL}, + }, + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() + require.True(t, called) } diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go index 1fe22b1..7be29a1 100644 --- a/modules/ratelimit/ratelimit_test.go +++ b/modules/ratelimit/ratelimit_test.go @@ -20,22 +20,25 @@ import ( func TestRatelimit(t *testing.T) { var times []time.Time - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, `<a href="foo">foo</a>`) + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, `<a href="foo">foo</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + times = append(times, time.Now()) + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - times = append(times, time.Now()) + &ratelimit.Module{ + Rate: 100, }, - }) - scraper.LoadModule(&ratelimit.Module{ - Rate: 100, - }) + } start := time.Now() + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() first := times[0].Add(-10 * time.Millisecond) diff --git a/modules/starturl/starturl_test.go b/modules/starturl/starturl_test.go index 86e4ad7..78efa6a 100644 --- a/modules/starturl/starturl_test.go +++ b/modules/starturl/starturl_test.go @@ -18,18 +18,21 @@ func TestStartURL(t *testing.T) { var url string var depth int - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/foo/bar"}) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, "") - }, - BuildRequestFn: func(r *flyscrape.Request) { - url = r.URL - depth = r.Depth + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/foo/bar"}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, "") + }, + BuildRequestFn: func(r *flyscrape.Request) { + url = r.URL + depth = r.Depth + }, }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Equal(t, "http://www.example.com/foo/bar", url) diff --git a/modules/urlfilter/urlfilter_test.go b/modules/urlfilter/urlfilter_test.go index 9ebb8a5..442780d 100644 --- a/modules/urlfilter/urlfilter_test.go +++ b/modules/urlfilter/urlfilter_test.go @@ -21,28 +21,31 @@ func TestURLFilterAllowed(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&urlfilter.Module{ - URL: "http://www.example.com/", - AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`}, - }) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/"}, + &followlinks.Module{}, + &urlfilter.Module{ + URL: "http://www.example.com/", + AllowedURLs: []string{`/foo\?id=\d+`, `/bar$`}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="foo?id=123">123</a> <a href="foo?id=ABC">ABC</a> <a href="/bar">bar</a> <a href="/barz">barz</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 3) @@ -55,28 +58,31 @@ func TestURLFilterBlocked(t *testing.T) { var urls []string var mu sync.Mutex - scraper := flyscrape.NewScraper() - scraper.LoadModule(&starturl.Module{URL: "http://www.example.com/"}) - scraper.LoadModule(&followlinks.Module{}) - scraper.LoadModule(&urlfilter.Module{ - URL: "http://www.example.com/", - BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`}, - }) - scraper.LoadModule(hook.Module{ - AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { - return flyscrape.MockTransport(200, ` + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com/"}, + &followlinks.Module{}, + &urlfilter.Module{ + URL: "http://www.example.com/", + BlockedURLs: []string{`/foo\?id=\d+`, `/bar$`}, + }, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.MockTransport(200, ` <a href="foo?id=123">123</a> <a href="foo?id=ABC">ABC</a> <a href="/bar">bar</a> <a href="/barz">barz</a>`) + }, + ReceiveResponseFn: func(r *flyscrape.Response) { + mu.Lock() + urls = append(urls, r.Request.URL) + mu.Unlock() + }, }, - ReceiveResponseFn: func(r *flyscrape.Response) { - mu.Lock() - urls = append(urls, r.Request.URL) - mu.Unlock() - }, - }) + } + scraper := flyscrape.NewScraper() + scraper.Modules = mods scraper.Run() require.Len(t, urls, 3) @@ -9,7 +9,6 @@ import ( "log" "net/http" "net/http/cookiejar" - "slices" "sync" "github.com/cornelk/hashmap" @@ -22,7 +21,6 @@ type Context interface { Visit(url string) MarkVisited(url string) MarkUnvisited(url string) - DisableModule(id string) } type Request struct { @@ -50,39 +48,18 @@ type target struct { } func NewScraper() *Scraper { - return &Scraper{ - jobs: make(chan target, 1024), - visited: hashmap.New[string, struct{}](), - } + return &Scraper{} } type Scraper struct { ScrapeFunc ScrapeFunc Script string + Modules []Module + Client *http.Client wg sync.WaitGroup jobs chan target visited *hashmap.Map[string, struct{}] - - modules []Module - moduleIDs []string - client *http.Client -} - -func (s *Scraper) LoadModule(mod Module) { - id := mod.ModuleInfo().ID - - s.modules = append(s.modules, mod) - s.moduleIDs = append(s.moduleIDs, id) -} - -func (s *Scraper) DisableModule(id string) { - idx := slices.Index(s.moduleIDs, id) - if idx == -1 { - return - } - s.modules = slices.Delete(s.modules, idx, idx+1) - s.moduleIDs = slices.Delete(s.moduleIDs, idx, idx+1) } func (s *Scraper) Visit(url string) { @@ -102,18 +79,28 @@ func (s *Scraper) ScriptName() string { } func (s *Scraper) Run() { - for _, mod := range s.modules { + s.jobs = make(chan target, 1024) + s.visited = hashmap.New[string, struct{}]() + + s.initClient() + + for _, mod := range s.Modules { if v, ok := mod.(Provisioner); ok { v.Provision(s) } } - s.initClient() + for _, mod := range s.Modules { + if v, ok := mod.(TransportAdapter); ok { + s.Client.Transport = v.AdaptTransport(s.Client.Transport) + } + } + go s.scrape() s.wg.Wait() close(s.jobs) - for _, mod := range s.modules { + for _, mod := range s.Modules { if v, ok := mod.(Finalizer); ok { v.Finalize() } @@ -121,13 +108,14 @@ func (s *Scraper) Run() { } func (s *Scraper) initClient() { - jar, _ := cookiejar.New(nil) - s.client = &http.Client{Jar: jar, Transport: http.DefaultTransport} - - for _, mod := range s.modules { - if v, ok := mod.(TransportAdapter); ok { - s.client.Transport = v.AdaptTransport(s.client.Transport) - } + if s.Client == nil { + s.Client = &http.Client{} + } + if s.Client.Jar == nil { + s.Client.Jar, _ = cookiejar.New(nil) + } + if s.Client.Transport == nil { + s.Client.Transport = http.DefaultTransport } } @@ -146,7 +134,7 @@ func (s *Scraper) process(url string, depth int) { Method: http.MethodGet, URL: url, Headers: defaultHeaders(), - Cookies: s.client.Jar, + Cookies: s.Client.Jar, Depth: depth, } @@ -157,7 +145,7 @@ func (s *Scraper) process(url string, depth int) { }, } - for _, mod := range s.modules { + for _, mod := range s.Modules { if v, ok := mod.(RequestBuilder); ok { v.BuildRequest(request) } @@ -170,7 +158,7 @@ func (s *Scraper) process(url string, depth int) { } req.Header = request.Headers - for _, mod := range s.modules { + for _, mod := range s.Modules { if v, ok := mod.(RequestValidator); ok { if !v.ValidateRequest(request) { return @@ -179,14 +167,14 @@ func (s *Scraper) process(url string, depth int) { } defer func() { - for _, mod := range s.modules { + for _, mod := range s.Modules { if v, ok := mod.(ResponseReceiver); ok { v.ReceiveResponse(response) } } }() - resp, err := s.client.Do(req) + resp, err := s.Client.Do(req) if err != nil { response.Error = err return @@ -7,6 +7,8 @@ package flyscrape import ( "bytes" "encoding/json" + "fmt" + "io" "net/http" "strings" ) @@ -25,3 +27,17 @@ type RoundTripFunc func(*http.Request) (*http.Response, error) func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func MockTransport(statusCode int, html string) RoundTripFunc { + return func(*http.Request) (*http.Response, error) { + return MockResponse(statusCode, html) + } +} + +func MockResponse(statusCode int, html string) (*http.Response, error) { + return &http.Response{ + StatusCode: statusCode, + Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)), + Body: io.NopCloser(strings.NewReader(html)), + }, nil +} |