// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package flyscrape

import (
	"io"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/PuerkitoBio/goquery"
	"github.com/cornelk/hashmap"
	"github.com/nlnwa/whatwg-url/url"
)

// ScrapeParams carries the fetched HTML and its URL into a ScrapeFunc.
type ScrapeParams struct {
	HTML string
	URL  string
}

// ScrapeResult is the outcome of scraping a single URL.
type ScrapeResult struct {
	URL       string    `json:"url"`
	Data      any       `json:"data,omitempty"`
	Links     []string  `json:"-"`
	Error     error     `json:"error,omitempty"`
	Timestamp time.Time `json:"timestamp"`
}

// omit reports whether the result carries neither data nor an error.
func (s *ScrapeResult) omit() bool {
	return s.Error == nil && s.Data == nil
}

// ScrapeFunc extracts data from a fetched page.
type ScrapeFunc func(ScrapeParams) (any, error)

// FetchFunc fetches the raw HTML for a URL.
type FetchFunc func(url string) (string, error)

// Visitor enqueues URLs for scraping and records URLs as visited.
type Visitor interface {
	Visit(url string)
	MarkVisited(url string)
}

type (
	// Request describes an outgoing request passed to OnRequest handlers.
	Request struct {
		URL   string
		Depth int
	}

	// Response is passed to OnResponse handlers. Visit enqueues a follow-up
	// URL one level deeper than the current page.
	Response struct {
		ScrapeResult
		HTML  string
		Visit func(url string)
	}

	// target is an internal job queue entry.
	target struct {
		url   string
		depth int
	}
)

// Scraper crawls pages starting from the URLs passed to Visit, running the
// registered hooks around every request.
type Scraper struct {
	ScrapeFunc ScrapeFunc

	opts    Options
	wg      sync.WaitGroup
	jobs    chan target
	visited *hashmap.Map[string, struct{}]
	modules *hashmap.Map[string, Module]

	canRequestHandlers []func(url string, depth int) bool
	onRequestHandlers  []func(*Request)
	onResponseHandlers []func(*Response)
	onCompleteHandlers []func()
	transport          func(*http.Request) (*http.Response, error)
}

// NewScraper returns a Scraper with a buffered job queue and a default
// transport that sets the flyscrape User-Agent.
func NewScraper() *Scraper {
	s := &Scraper{
		jobs:    make(chan target, 1024),
		visited: hashmap.New[string, struct{}](),
		modules: hashmap.New[string, Module](),
		transport: func(r *http.Request) (*http.Response, error) {
			r.Header.Set("User-Agent", "flyscrape/0.1")
			return http.DefaultClient.Do(r)
		},
	}
	return s
}

// LoadModule registers all hooks that the given module implements.
func (s *Scraper) LoadModule(mod Module) {
	if v, ok := mod.(Transport); ok {
		s.SetTransport(v.Transport)
	}
	if v, ok := mod.(CanRequest); ok {
		s.CanRequest(v.CanRequest)
	}
	if v, ok := mod.(OnRequest); ok {
		s.OnRequest(v.OnRequest)
	}
	if v, ok := mod.(OnResponse); ok {
		s.OnResponse(v.OnResponse)
	}
	if v, ok := mod.(OnLoad); ok {
		v.OnLoad(s)
	}
	if v, ok := mod.(OnComplete); ok {
		s.OnComplete(v.OnComplete)
	}
}

// Visit enqueues a URL at depth 0.
func (s *Scraper) Visit(url string) {
	s.enqueueJob(url, 0)
}

// MarkVisited records a URL so it will not be fetched again.
func (s *Scraper) MarkVisited(url string) {
	s.visited.Insert(url, struct{}{})
}

// SetTransport replaces the function used to perform HTTP requests.
func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
	s.transport = f
}

// CanRequest registers a handler that can veto a request before it is made.
func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
	s.canRequestHandlers = append(s.canRequestHandlers, f)
}

// OnRequest registers a handler that is called before each request is sent.
func (s *Scraper) OnRequest(f func(req *Request)) {
	s.onRequestHandlers = append(s.onRequestHandlers, f)
}

// OnResponse registers a handler that is called with each scrape result.
func (s *Scraper) OnResponse(f func(resp *Response)) {
	s.onResponseHandlers = append(s.onResponseHandlers, f)
}

// OnComplete registers a handler that is called after all jobs have finished.
func (s *Scraper) OnComplete(f func()) {
	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
}

// Run starts the worker, waits until all enqueued jobs have been processed
// and then calls the OnComplete handlers. URLs should be enqueued via Visit
// before Run is called.
func (s *Scraper) Run() {
	go s.worker()
	s.wg.Wait()
	close(s.jobs)

	for _, handler := range s.onCompleteHandlers {
		handler()
	}
}

// worker consumes the job queue and processes each job in its own goroutine.
func (s *Scraper) worker() {
	for job := range s.jobs {
		go func(job target) {
			defer s.wg.Done()

			for _, handler := range s.canRequestHandlers {
				if !handler(job.url, job.depth) {
					return
				}
			}

			res, html := s.process(job)

			for _, handler := range s.onResponseHandlers {
				handler(&Response{
					ScrapeResult: res,
					HTML:         html,
					Visit: func(url string) {
						s.enqueueJob(url, job.depth+1)
					},
				})
			}
		}(job)
	}
}

// process fetches a single URL, runs the scrape function on the response
// body and returns the result together with the raw HTML.
func (s *Scraper) process(job target) (res ScrapeResult, html string) {
	res.URL = job.url
	res.Timestamp = time.Now()

	req, err := http.NewRequest(http.MethodGet, job.url, nil)
	if err != nil {
		res.Error = err
		return
	}

	for _, handler := range s.onRequestHandlers {
		handler(&Request{URL: job.url, Depth: job.depth})
	}

	resp, err := s.transport(req)
	if err != nil {
		res.Error = err
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		res.Error = err
		return
	}
	html = string(body)

	if s.ScrapeFunc != nil {
		res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
		if err != nil {
			res.Error = err
			return
		}
	}

	return
}
// enqueueJob adds a URL to the job queue unless it has already been visited
// or the queue is full.
func (s *Scraper) enqueueJob(url string, depth int) {
	if _, ok := s.visited.Get(url); ok {
		return
	}

	s.wg.Add(1)
	select {
	case s.jobs <- target{url: url, depth: depth}:
		s.MarkVisited(url)
	default:
		log.Println("queue is full, can't add url:", url)
		s.wg.Done()
	}
}

// ParseLinks extracts the unique, absolute links from all anchor tags in the
// HTML, resolved against the origin URL.
func ParseLinks(html string, origin string) []string {
	var links []string

	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		return nil
	}

	urlParser := url.NewParser(url.WithPercentEncodeSinglePercentSign())
	uniqueLinks := make(map[string]bool)

	doc.Find("a").Each(func(i int, s *goquery.Selection) {
		link, _ := s.Attr("href")
		parsedLink, err := urlParser.ParseRef(origin, link)
		if err != nil || !isValidLink(parsedLink) {
			return
		}

		absLink := parsedLink.Href(true)
		if !uniqueLinks[absLink] {
			links = append(links, absLink)
			uniqueLinks[absLink] = true
		}
	})

	return links
}

// isValidLink reports whether a link is a plain http(s) or relative link and
// not a javascript: URL.
func isValidLink(link *url.Url) bool {
	if link.Scheme() != "" && link.Scheme() != "http" && link.Scheme() != "https" {
		return false
	}

	if strings.HasPrefix(link.String(), "javascript:") {
		return false
	}

	return true
}
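// exampleRun is an illustrative usage sketch, not part of the scraper itself:
// it shows how the hooks declared above fit together. The seed URL, the depth
// limit and the extracted data shape are assumptions made for the example only.
func exampleRun() {
	s := NewScraper()

	// Extract data from each fetched page; here just the outgoing links.
	s.ScrapeFunc = func(p ScrapeParams) (any, error) {
		return map[string]any{"links": ParseLinks(p.HTML, p.URL)}, nil
	}

	// Stop following links more than two levels below the seed URL.
	s.CanRequest(func(url string, depth int) bool {
		return depth <= 2
	})

	// Log each result and enqueue the links found on the page.
	s.OnResponse(func(resp *Response) {
		log.Println("scraped:", resp.URL)
		for _, link := range ParseLinks(resp.HTML, resp.URL) {
			resp.Visit(link)
		}
	})

	s.Visit("https://example.com") // placeholder seed URL
	s.Run()
}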