// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package flyscrape

import (
	"io"
	"log"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strings"
	"sync"

	"github.com/PuerkitoBio/goquery"
	"github.com/cornelk/hashmap"
)

// ScrapeParams carries a fetched page and its URL into a ScrapeFunc.
type ScrapeParams struct {
	HTML string
	URL  string
}

// ScrapeFunc extracts structured data from a fetched page.
type ScrapeFunc func(ScrapeParams) (any, error)

// FetchFunc fetches a URL and returns its response body.
type FetchFunc func(url string) (string, error)

// Visitor enqueues URLs for scraping and records already-visited ones.
type Visitor interface {
	Visit(url string)
	MarkVisited(url string)
}

// Request describes an outgoing HTTP request before it is sent.
type Request struct {
	Method  string
	URL     string
	Headers http.Header
	Cookies http.CookieJar
	Depth   int
}

// Response wraps the HTTP response, any scraped data and a Visit callback
// for enqueueing follow-up URLs discovered on the page.
type Response struct {
	StatusCode int
	Headers    http.Header
	Body       []byte
	Data       any
	Error      error
	Request    *Request
	Visit      func(url string)
}

// target is a queued URL together with its link depth.
type target struct {
	url   string
	depth int
}

// Scraper crawls pages concurrently, dispatching each request and response
// to the registered module hooks and an optional ScrapeFunc.
type Scraper struct {
	ScrapeFunc ScrapeFunc

	cfg       Config
	wg        sync.WaitGroup
	jobs      chan target
	visited   *hashmap.Map[string, struct{}]
	modules   *hashmap.Map[string, Module]
	cookieJar *cookiejar.Jar

	canRequestHandlers []func(url string, depth int) bool
	onRequestHandlers  []func(*Request)
	onResponseHandlers []func(*Response)
	onCompleteHandlers []func()
	transport          func(*http.Request) (*http.Response, error)
}

// NewScraper returns a Scraper with an in-memory cookie jar, a buffered job
// queue and a default transport that sets the flyscrape User-Agent.
func NewScraper() *Scraper {
	jar, _ := cookiejar.New(nil)
	s := &Scraper{
		jobs:    make(chan target, 1024),
		visited: hashmap.New[string, struct{}](),
		modules: hashmap.New[string, Module](),
		transport: func(r *http.Request) (*http.Response, error) {
			r.Header.Set("User-Agent", "flyscrape/0.1")
			return http.DefaultClient.Do(r)
		},
		cookieJar: jar,
	}
	return s
}

// LoadModule registers every hook interface that the module implements.
func (s *Scraper) LoadModule(mod Module) {
	if v, ok := mod.(Transport); ok {
		s.SetTransport(v.Transport)
	}
	if v, ok := mod.(CanRequest); ok {
		s.CanRequest(v.CanRequest)
	}
	if v, ok := mod.(OnRequest); ok {
		s.OnRequest(v.OnRequest)
	}
	if v, ok := mod.(OnResponse); ok {
		s.OnResponse(v.OnResponse)
	}
	if v, ok := mod.(OnLoad); ok {
		v.OnLoad(s)
	}
	if v, ok := mod.(OnComplete); ok {
		s.OnComplete(v.OnComplete)
	}
}

// Visit enqueues a URL at depth 0.
func (s *Scraper) Visit(url string) {
	s.enqueueJob(url, 0)
}

// MarkVisited records a URL so it will not be fetched again.
func (s *Scraper) MarkVisited(url string) {
	s.visited.Insert(url, struct{}{})
}

// SetTransport replaces the function used to execute HTTP requests.
func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
	s.transport = f
}

// CanRequest registers a filter that can veto a URL before it is fetched.
func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
	s.canRequestHandlers = append(s.canRequestHandlers, f)
}

// OnRequest registers a hook that runs before each request is sent.
func (s *Scraper) OnRequest(f func(req *Request)) {
	s.onRequestHandlers = append(s.onRequestHandlers, f)
}

// OnResponse registers a hook that runs after each response or request error.
func (s *Scraper) OnResponse(f func(resp *Response)) {
	s.onResponseHandlers = append(s.onResponseHandlers, f)
}

// OnComplete registers a hook that runs once the scrape has finished.
func (s *Scraper) OnComplete(f func()) {
	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
}

// Run starts the worker, waits until all queued jobs have been processed,
// closes the queue and then runs the completion hooks.
func (s *Scraper) Run() {
	go s.worker()
	s.wg.Wait()
	close(s.jobs)

	for _, handler := range s.onCompleteHandlers {
		handler()
	}
}

// worker spawns a goroutine per queued job, skipping jobs that are rejected
// by any CanRequest handler.
func (s *Scraper) worker() {
	for job := range s.jobs {
		go func(job target) {
			defer s.wg.Done()
			for _, handler := range s.canRequestHandlers {
				if !handler(job.url, job.depth) {
					return
				}
			}
			s.process(job.url, job.depth)
		}(job)
	}
}

// process fetches a single URL, runs the ScrapeFunc on the body and notifies
// the response handlers, even when the request fails.
func (s *Scraper) process(url string, depth int) {
	request := &Request{
		Method:  http.MethodGet,
		URL:     url,
		Headers: http.Header{},
		Cookies: s.cookieJar,
		Depth:   depth,
	}

	response := &Response{
		Request: request,
		Visit: func(url string) {
			s.enqueueJob(url, depth+1)
		},
	}

	// Response handlers always run, whether the request succeeded or not.
	defer func() {
		for _, handler := range s.onResponseHandlers {
			handler(response)
		}
	}()

	req, err := http.NewRequest(request.Method, request.URL, nil)
	if err != nil {
		response.Error = err
		return
	}
	req.Header = request.Headers

	// Request handlers may mutate the shared header map before sending.
	for _, handler := range s.onRequestHandlers {
		handler(request)
	}

	resp, err := s.transport(req)
	if err != nil {
		response.Error = err
		return
	}
	defer resp.Body.Close()

	response.StatusCode = resp.StatusCode
	response.Headers = resp.Header

	response.Body, err = io.ReadAll(resp.Body)
	if err != nil {
		response.Error = err
		return
	}

	if s.ScrapeFunc != nil {
		response.Data, err = s.ScrapeFunc(ScrapeParams{HTML: string(response.Body), URL: request.URL})
		if err != nil {
			response.Error = err
			return
		}
	}
}

// enqueueJob schedules a URL unless it was already visited or the queue is full.
func (s *Scraper) enqueueJob(url string, depth int) {
	if _, ok := s.visited.Get(url); ok {
		return
	}

	s.wg.Add(1)
	select {
	case s.jobs <- target{url: url, depth: depth}:
		s.MarkVisited(url)
	default:
		log.Println("queue is full, can't add url:", url)
		s.wg.Done()
	}
}

// ParseLinks extracts the unique absolute links from an HTML document,
// resolving relative hrefs against the origin URL.
func ParseLinks(html string, origin string) []string {
	var links []string

	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		return nil
	}

	originurl, err := url.Parse(origin)
	if err != nil {
		return nil
	}

	uniqueLinks := make(map[string]bool)
	doc.Find("a").Each(func(i int, s *goquery.Selection) {
		link, _ := s.Attr("href")
		parsedLink, err := originurl.Parse(link)
		if err != nil || !isValidLink(parsedLink) {
			return
		}
		absLink := parsedLink.String()
		if !uniqueLinks[absLink] {
			links = append(links, absLink)
			uniqueLinks[absLink] = true
		}
	})

	return links
}

// isValidLink reports whether a link is relative or uses http(s).
func isValidLink(link *url.URL) bool {
	if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" {
		return false
	}
	return true
}
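// The snippet below is a minimal usage sketch kept as documentation only; it
// is not part of the package. The import path and the example URL are
// assumptions for illustration. It shows how the pieces in this file fit
// together: Visit seeds the queue, ScrapeFunc extracts data, and an
// OnResponse hook re-enqueues discovered links via resp.Visit.
//
//	package main
//
//	import "github.com/philippta/flyscrape"
//
//	func main() {
//		s := flyscrape.NewScraper()
//		s.ScrapeFunc = func(p flyscrape.ScrapeParams) (any, error) {
//			// Collect the links found on each page as the scraped data.
//			return flyscrape.ParseLinks(p.HTML, p.URL), nil
//		}
//		s.OnResponse(func(resp *flyscrape.Response) {
//			// Follow every link discovered in the response body.
//			for _, link := range flyscrape.ParseLinks(string(resp.Body), resp.Request.URL) {
//				resp.Visit(link)
//			}
//		})
//		s.Visit("https://example.com")
//		s.Run()
//	}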