Diffstat (limited to 'scrape.go')
-rw-r--r--  scrape.go | 70
1 file changed, 29 insertions(+), 41 deletions(-)
diff --git a/scrape.go b/scrape.go
index 00a74bf..764ef39 100644
--- a/scrape.go
+++ b/scrape.go
@@ -9,7 +9,6 @@ import (
"log"
"net/http"
"net/http/cookiejar"
- "slices"
"sync"
"github.com/cornelk/hashmap"
@@ -22,7 +21,6 @@ type Context interface {
Visit(url string)
MarkVisited(url string)
MarkUnvisited(url string)
- DisableModule(id string)
}
type Request struct {
@@ -50,39 +48,18 @@ type target struct {
}
func NewScraper() *Scraper {
- return &Scraper{
- jobs: make(chan target, 1024),
- visited: hashmap.New[string, struct{}](),
- }
+ return &Scraper{}
}
type Scraper struct {
ScrapeFunc ScrapeFunc
Script string
+ Modules []Module
+ Client *http.Client
wg sync.WaitGroup
jobs chan target
visited *hashmap.Map[string, struct{}]
-
- modules []Module
- moduleIDs []string
- client *http.Client
-}
-
-func (s *Scraper) LoadModule(mod Module) {
- id := mod.ModuleInfo().ID
-
- s.modules = append(s.modules, mod)
- s.moduleIDs = append(s.moduleIDs, id)
-}
-
-func (s *Scraper) DisableModule(id string) {
- idx := slices.Index(s.moduleIDs, id)
- if idx == -1 {
- return
- }
- s.modules = slices.Delete(s.modules, idx, idx+1)
- s.moduleIDs = slices.Delete(s.moduleIDs, idx, idx+1)
}
func (s *Scraper) Visit(url string) {
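With LoadModule and DisableModule removed, configuration moves to the exported fields: callers populate Modules (and optionally Client) before calling Run. A minimal usage sketch, where LogModule stands in for any concrete Module implementation (hypothetical, not part of this repo):

	s := NewScraper()
	s.Modules = []Module{&LogModule{}} // LogModule: hypothetical module value
	s.Run()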
@@ -102,18 +79,28 @@ func (s *Scraper) ScriptName() string {
}
func (s *Scraper) Run() {
- for _, mod := range s.modules {
+ s.jobs = make(chan target, 1024)
+ s.visited = hashmap.New[string, struct{}]()
+
+ s.initClient()
+
+ for _, mod := range s.Modules {
if v, ok := mod.(Provisioner); ok {
v.Provision(s)
}
}
- s.initClient()
+ for _, mod := range s.Modules {
+ if v, ok := mod.(TransportAdapter); ok {
+ s.Client.Transport = v.AdaptTransport(s.Client.Transport)
+ }
+ }
+
go s.scrape()
s.wg.Wait()
close(s.jobs)
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(Finalizer); ok {
v.Finalize()
}
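Run now wires modules in two passes: Provisioners first, then TransportAdapters, which wrap s.Client.Transport in load order. A sketch of what an adapter module might look like, inferred from the AdaptTransport call site above (HeaderModule and headerTransport are hypothetical names):

	// HeaderModule is a hypothetical module that stamps a header on every request.
	type HeaderModule struct{}

	// headerTransport wraps the next RoundTripper in the chain.
	type headerTransport struct{ next http.RoundTripper }

	func (t headerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
		r2 := r.Clone(r.Context()) // RoundTrippers must not mutate the caller's request
		r2.Header.Set("X-Scraper", "1")
		return t.next.RoundTrip(r2)
	}

	// AdaptTransport satisfies TransportAdapter by returning the wrapped transport.
	func (HeaderModule) AdaptTransport(next http.RoundTripper) http.RoundTripper {
		return headerTransport{next: next}
	}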
@@ -121,13 +108,14 @@ func (s *Scraper) Run() {
}
func (s *Scraper) initClient() {
- jar, _ := cookiejar.New(nil)
- s.client = &http.Client{Jar: jar, Transport: http.DefaultTransport}
-
- for _, mod := range s.modules {
- if v, ok := mod.(TransportAdapter); ok {
- s.client.Transport = v.AdaptTransport(s.client.Transport)
- }
+ if s.Client == nil {
+ s.Client = &http.Client{}
+ }
+ if s.Client.Jar == nil {
+ s.Client.Jar, _ = cookiejar.New(nil)
+ }
+ if s.Client.Transport == nil {
+ s.Client.Transport = http.DefaultTransport
}
}
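initClient now treats the exported Client as caller-owned: a supplied client is kept, and only a nil Jar or Transport is filled in with defaults. A small sketch of the effect, assuming a caller-set timeout:

	s := NewScraper()
	s.Client = &http.Client{Timeout: 10 * time.Second}
	// During Run, initClient adds a cookie jar and http.DefaultTransport
	// only because those fields are nil; the Timeout is left untouched.
	s.Run()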
@@ -146,7 +134,7 @@ func (s *Scraper) process(url string, depth int) {
Method: http.MethodGet,
URL: url,
Headers: defaultHeaders(),
- Cookies: s.client.Jar,
+ Cookies: s.Client.Jar,
Depth: depth,
}
@@ -157,7 +145,7 @@ func (s *Scraper) process(url string, depth int) {
},
}
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(RequestBuilder); ok {
v.BuildRequest(request)
}
@@ -170,7 +158,7 @@ func (s *Scraper) process(url string, depth int) {
}
req.Header = request.Headers
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(RequestValidator); ok {
if !v.ValidateRequest(request) {
return
@@ -179,14 +167,14 @@ func (s *Scraper) process(url string, depth int) {
}
defer func() {
- for _, mod := range s.modules {
+ for _, mod := range s.Modules {
if v, ok := mod.(ResponseReceiver); ok {
v.ReceiveResponse(response)
}
}
}()
- resp, err := s.client.Do(req)
+ resp, err := s.Client.Do(req)
if err != nil {
response.Error = err
return
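Because the ResponseReceiver loop runs in a defer, modules observe every response, including ones where s.Client.Do failed and only response.Error is set. A hypothetical receiver that relies on this, assuming ReceiveResponse takes the *Response built above (the mutex guards against the scraper's concurrent workers):

	// StatsModule is a hypothetical module counting failed fetches.
	type StatsModule struct {
		mu       sync.Mutex
		failures int
	}

	func (m *StatsModule) ReceiveResponse(res *Response) {
		m.mu.Lock()
		defer m.mu.Unlock()
		if res.Error != nil {
			m.failures++
		}
	}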