Diffstat (limited to 'scrape.go')
| -rw-r--r-- | scrape.go | 310 |
1 file changed, 138 insertions, 172 deletions
@@ -5,8 +5,9 @@
 package flyscrape
 
 import (
+	"io"
 	"log"
-	"regexp"
+	"net/http"
 	"strings"
 	"sync"
 	"time"
@@ -21,17 +22,6 @@ type ScrapeParams struct {
 	URL  string
 }
 
-type ScrapeOptions struct {
-	URL            string   `json:"url"`
-	AllowedDomains []string `json:"allowedDomains"`
-	BlockedDomains []string `json:"blockedDomains"`
-	AllowedURLs    []string `json:"allowedURLs"`
-	BlockedURLs    []string `json:"blockedURLs"`
-	Proxy          string   `json:"proxy"`
-	Depth          int      `json:"depth"`
-	Rate           float64  `json:"rate"`
-}
-
 type ScrapeResult struct {
 	URL       string `json:"url"`
 	Data      any    `json:"data,omitempty"`
@@ -48,204 +38,203 @@ type ScrapeFunc func(ScrapeParams) (any, error)
 
 type FetchFunc func(url string) (string, error)
 
-type target struct {
-	url   string
-	depth int
+type Visitor interface {
+	Visit(url string)
+	MarkVisited(url string)
 }
 
+type (
+	Request struct {
+		URL   string
+		Depth int
+	}
+
+	Response struct {
+		ScrapeResult
+		HTML  string
+		Visit func(url string)
+	}
+
+	target struct {
+		url   string
+		depth int
+	}
+)
+
 type Scraper struct {
-	ScrapeOptions ScrapeOptions
-	ScrapeFunc    ScrapeFunc
-	FetchFunc     FetchFunc
-
-	visited       *hashmap.Map[string, struct{}]
-	wg            *sync.WaitGroup
-	jobs          chan target
-	results       chan ScrapeResult
-	allowedURLsRE []*regexp.Regexp
-	blockedURLsRE []*regexp.Regexp
+	ScrapeFunc ScrapeFunc
+
+	opts    Options
+	wg      sync.WaitGroup
+	jobs    chan target
+	visited *hashmap.Map[string, struct{}]
+	modules *hashmap.Map[string, Module]
+
+	canRequestHandlers []func(url string, depth int) bool
+	onRequestHandlers  []func(*Request)
+	onResponseHandlers []func(*Response)
+	onCompleteHandlers []func()
+	transport          func(*http.Request) (*http.Response, error)
}
 
-func (s *Scraper) init() {
-	s.visited = hashmap.New[string, struct{}]()
-	s.wg = &sync.WaitGroup{}
-	s.jobs = make(chan target, 1024)
-	s.results = make(chan ScrapeResult)
+func NewScraper() *Scraper {
+	s := &Scraper{
+		jobs:    make(chan target, 1024),
+		visited: hashmap.New[string, struct{}](),
+		modules: hashmap.New[string, Module](),
+		transport: func(r *http.Request) (*http.Response, error) {
+			r.Header.Set("User-Agent", "flyscrape/0.1")
+			return http.DefaultClient.Do(r)
+		},
+	}
+	return s
+}
 
-	if s.FetchFunc == nil {
-		s.FetchFunc = Fetch()
+func (s *Scraper) LoadModule(mod Module) {
+	if v, ok := mod.(Transport); ok {
+		s.SetTransport(v.Transport)
 	}
-	if s.ScrapeOptions.Proxy != "" {
-		s.FetchFunc = ProxiedFetch(s.ScrapeOptions.Proxy)
+
+	if v, ok := mod.(CanRequest); ok {
+		s.CanRequest(v.CanRequest)
 	}
 
-	if s.ScrapeOptions.Rate == 0 {
-		s.ScrapeOptions.Rate = 100
+	if v, ok := mod.(OnRequest); ok {
+		s.OnRequest(v.OnRequest)
 	}
 
-	if u, err := url.Parse(s.ScrapeOptions.URL); err == nil {
-		s.ScrapeOptions.AllowedDomains = append(s.ScrapeOptions.AllowedDomains, u.Host)
+	if v, ok := mod.(OnResponse); ok {
+		s.OnResponse(v.OnResponse)
 	}
 
-	for _, pat := range s.ScrapeOptions.AllowedURLs {
-		re, err := regexp.Compile(pat)
-		if err != nil {
-			continue
-		}
-		s.allowedURLsRE = append(s.allowedURLsRE, re)
+	if v, ok := mod.(OnLoad); ok {
+		v.OnLoad(s)
 	}
 
-	for _, pat := range s.ScrapeOptions.BlockedURLs {
-		re, err := regexp.Compile(pat)
-		if err != nil {
-			continue
-		}
-		s.blockedURLsRE = append(s.blockedURLsRE, re)
+	if v, ok := mod.(OnComplete); ok {
+		s.OnComplete(v.OnComplete)
 	}
 }
 
-func (s *Scraper) Scrape() <-chan ScrapeResult {
-	s.init()
-	s.enqueueJob(s.ScrapeOptions.URL, s.ScrapeOptions.Depth)
+func (s *Scraper) Visit(url string) {
+	s.enqueueJob(url, 0)
+}
 
-	go s.worker()
-	go s.waitClose()
+func (s *Scraper) MarkVisited(url string) {
+	s.visited.Insert(url, struct{}{})
+}
 
-	return s.results
+func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
+	s.transport = f
 }
 
-func (s *Scraper) worker() {
-	var (
-		rate      = time.Duration(float64(time.Second) / s.ScrapeOptions.Rate)
-		leakyjobs = leakychan(s.jobs, rate)
-	)
+func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
+	s.canRequestHandlers = append(s.canRequestHandlers, f)
+}
 
-	for job := range leakyjobs {
-		go func(job target) {
-			defer s.wg.Done()
+func (s *Scraper) OnRequest(f func(req *Request)) {
+	s.onRequestHandlers = append(s.onRequestHandlers, f)
+}
 
-			res := s.process(job)
-			if !res.omit() {
-				s.results <- res
-			}
+func (s *Scraper) OnResponse(f func(resp *Response)) {
+	s.onResponseHandlers = append(s.onResponseHandlers, f)
+}
 
-			if job.depth <= 0 {
-				return
-			}
+func (s *Scraper) OnComplete(f func()) {
+	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
+}
 
-			for _, l := range res.Links {
-				if _, ok := s.visited.Get(l); ok {
-					continue
-				}
+func (s *Scraper) Run() {
+	go s.worker()
+	s.wg.Wait()
+	close(s.jobs)
+
+	for _, handler := range s.onCompleteHandlers {
+		handler()
+	}
+}
 
-				allowed := s.isDomainAllowed(l) && s.isURLAllowed(l)
-				if !allowed {
-					continue
+func (s *Scraper) worker() {
+	for job := range s.jobs {
+		go func(job target) {
+			defer s.wg.Done()
+
+			for _, handler := range s.canRequestHandlers {
+				if !handler(job.url, job.depth) {
+					return
 				}
+			}
 
-				s.enqueueJob(l, job.depth-1)
+			res, html := s.process(job)
+			for _, handler := range s.onResponseHandlers {
+				handler(&Response{
+					ScrapeResult: res,
+					HTML:         html,
+					Visit: func(url string) {
+						s.enqueueJob(url, job.depth+1)
+					},
+				})
 			}
 		}(job)
 	}
 }
 
-func (s *Scraper) process(job target) (res ScrapeResult) {
+func (s *Scraper) process(job target) (res ScrapeResult, html string) {
 	res.URL = job.url
 	res.Timestamp = time.Now()
 
-	html, err := s.FetchFunc(job.url)
+	req, err := http.NewRequest(http.MethodGet, job.url, nil)
 	if err != nil {
 		res.Error = err
 		return
 	}
 
-	res.Links = links(html, job.url)
-	res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
+	for _, handler := range s.onRequestHandlers {
+		handler(&Request{URL: job.url, Depth: job.depth})
+	}
+
+	resp, err := s.transport(req)
 	if err != nil {
 		res.Error = err
 		return
 	}
+	defer resp.Body.Close()
 
-	return
-}
-
-func (s *Scraper) enqueueJob(url string, depth int) {
-	s.wg.Add(1)
-	select {
-	case s.jobs <- target{url: url, depth: depth}:
-		s.visited.Set(url, struct{}{})
-	default:
-		log.Println("queue is full, can't add url:", url)
-		s.wg.Done()
-	}
-}
-
-func (s *Scraper) isDomainAllowed(rawurl string) bool {
-	u, err := url.Parse(rawurl)
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return false
+		res.Error = err
+		return
 	}
 
-	host := u.Host
-	ok := false
+	html = string(body)
 
-	for _, domain := range s.ScrapeOptions.AllowedDomains {
-		if domain == "*" || host == domain {
-			ok = true
-			break
-		}
-	}
-
-	for _, domain := range s.ScrapeOptions.BlockedDomains {
-		if host == domain {
-			ok = false
-			break
+	if s.ScrapeFunc != nil {
+		res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
+		if err != nil {
+			res.Error = err
+			return
		}
 	}
 
-	return ok
+	return
 }
 
-func (s *Scraper) isURLAllowed(rawurl string) bool {
-	// allow root url
-	if rawurl == s.ScrapeOptions.URL {
-		return true
-	}
-
-	// allow if no filter is set
-	if len(s.allowedURLsRE) == 0 && len(s.blockedURLsRE) == 0 {
-		return true
-	}
-
-	ok := false
-	if len(s.allowedURLsRE) == 0 {
-		ok = true
-	}
-
-	for _, re := range s.allowedURLsRE {
-		if re.MatchString(rawurl) {
-			ok = true
-			break
-		}
+func (s *Scraper) enqueueJob(url string, depth int) {
+	if _, ok := s.visited.Get(url); ok {
+		return
 	}
 
-	for _, re := range s.blockedURLsRE {
-		if re.MatchString(rawurl) {
-			ok = false
-			break
-		}
+	s.wg.Add(1)
+	select {
+	case s.jobs <- target{url: url, depth: depth}:
+		s.MarkVisited(url)
+	default:
+		log.Println("queue is full, can't add url:", url)
+		s.wg.Done()
 	}
-
-	return ok
 }
 
-func (s *Scraper) waitClose() {
-	s.wg.Wait()
-	close(s.jobs)
-	close(s.results)
-}
-
-func links(html string, origin string) []string {
+func ParseLinks(html string, origin string) []string {
 	var links []string
 	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
 	if err != nil {
@@ -285,26 +274,3 @@ func isValidLink(link *url.URL) bool {
 
 	return true
 }
-
-func leakychan[T any](in chan T, rate time.Duration) chan T {
-	ticker := time.NewTicker(rate)
-	sem := make(chan struct{}, 1)
-	c := make(chan T)
-
-	go func() {
-		for range ticker.C {
-			sem <- struct{}{}
-		}
-	}()
-
-	go func() {
-		for v := range in {
-			<-sem
-			c <- v
-		}
-		ticker.Stop()
-		close(c)
-	}()
-
-	return c
-}
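For context, a minimal usage sketch of the refactored API (not part of the patch): it registers hooks directly on the Scraper instead of going through LoadModule, since the Module, Transport and On* interfaces are defined outside scrape.go. The crawl URL, depth limit and extracted data are made up for illustration.

package flyscrape

import "fmt"

// ExampleScraper sketches how a caller might drive the refactored Scraper.
func ExampleScraper() {
	s := NewScraper()

	// Extract something trivial from every fetched page (hypothetical data).
	s.ScrapeFunc = func(p ScrapeParams) (any, error) {
		return map[string]int{"htmlBytes": len(p.HTML)}, nil
	}

	// Stop following links after two levels (hypothetical policy).
	s.CanRequest(func(url string, depth int) bool {
		return depth <= 2
	})

	// Queue discovered links and print each result as it arrives.
	s.OnResponse(func(resp *Response) {
		for _, link := range ParseLinks(resp.HTML, resp.URL) {
			resp.Visit(link)
		}
		fmt.Println(resp.URL, resp.Data, resp.Error)
	})

	s.OnComplete(func() { fmt.Println("done") })

	s.Visit("https://example.com")
	s.Run() // blocks until the queue drains, then runs OnComplete handlers
}

The sketch also shows the main behavioral shift in this patch: link filtering and depth limits are no longer baked into the Scraper (the old ScrapeOptions, isDomainAllowed and isURLAllowed) but are expressed through CanRequest/OnResponse hooks supplied by the caller or by modules.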