// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package flyscrape

import (
	"io"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/PuerkitoBio/goquery"
	"github.com/cornelk/hashmap"
	"github.com/nlnwa/whatwg-url/url"
)

// ScrapeParams carries the inputs to a ScrapeFunc: the fetched page's
// HTML and the URL it was fetched from.
type ScrapeParams struct {
	HTML string
	URL  string
}

// ScrapeResult is the outcome of scraping a single URL. Links is
// excluded from JSON output; Data and Error are omitted when empty.
type ScrapeResult struct {
	URL       string    `json:"url"`
	Data      any       `json:"data,omitempty"`
	Links     []string  `json:"-"`
	Error     error     `json:"error,omitempty"`
	Timestamp time.Time `json:"timestamp"`
}
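
// Encoded as JSON, a successful result looks roughly like this (the
// values below are illustrative):
//
//	{"url": "https://example.com", "data": {"title": "Example"}, "timestamp": "2023-10-01T12:00:00Z"}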

// omit reports whether the result carries neither data nor an error
// and can therefore be skipped in output.
func (s *ScrapeResult) omit() bool {
	return s.Error == nil && s.Data == nil
}

// ScrapeFunc extracts structured data from a fetched page.
type ScrapeFunc func(ScrapeParams) (any, error)

// FetchFunc retrieves the raw document at the given URL.
type FetchFunc func(url string) (string, error)

// Visitor enqueues URLs and records which ones have been seen.
type Visitor interface {
	Visit(url string)
	MarkVisited(url string)
}

type (
	// Request describes a pending fetch: the target URL and its link depth.
	Request struct {
		URL   string
		Depth int
	}

	// Response bundles the scrape result with the raw HTML and a Visit
	// callback for enqueueing follow-up links.
	Response struct {
		ScrapeResult
		HTML  string
		Visit func(url string)
	}

	// target is an internal job-queue entry.
	target struct {
		url   string
		depth int
	}
)

// Scraper fetches queued URLs concurrently, invoking the registered
// handlers before and after each request.
type Scraper struct {
	ScrapeFunc ScrapeFunc

	opts    Options
	wg      sync.WaitGroup
	jobs    chan target
	visited *hashmap.Map[string, struct{}]
	modules *hashmap.Map[string, Module]

	canRequestHandlers []func(url string, depth int) bool
	onRequestHandlers  []func(*Request)
	onResponseHandlers []func(*Response)
	onCompleteHandlers []func()
	transport          func(*http.Request) (*http.Response, error)
}

// NewScraper returns a Scraper with a buffered job queue and a default
// transport that sets the flyscrape User-Agent.
func NewScraper() *Scraper {
	s := &Scraper{
		jobs:    make(chan target, 1024),
		visited: hashmap.New[string, struct{}](),
		modules: hashmap.New[string, Module](),
		transport: func(r *http.Request) (*http.Response, error) {
			r.Header.Set("User-Agent", "flyscrape/0.1")
			return http.DefaultClient.Do(r)
		},
	}
	return s
}
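
// A minimal usage sketch (the target URL and the extraction logic are
// illustrative, not part of this package):
//
//	s := NewScraper()
//	s.ScrapeFunc = func(p ScrapeParams) (any, error) {
//		return map[string]any{"size": len(p.HTML)}, nil
//	}
//	s.OnResponse(func(resp *Response) {
//		for _, link := range ParseLinks(resp.HTML, resp.URL) {
//			resp.Visit(link) // follow discovered links at depth+1
//		}
//	})
//	s.Visit("https://example.com")
//	s.Run()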

// LoadModule wires a module's optional hook interfaces into the scraper.
func (s *Scraper) LoadModule(mod Module) {
	if v, ok := mod.(Transport); ok {
		s.SetTransport(v.Transport)
	}

	if v, ok := mod.(CanRequest); ok {
		s.CanRequest(v.CanRequest)
	}

	if v, ok := mod.(OnRequest); ok {
		s.OnRequest(v.OnRequest)
	}

	if v, ok := mod.(OnResponse); ok {
		s.OnResponse(v.OnResponse)
	}

	if v, ok := mod.(OnLoad); ok {
		v.OnLoad(s)
	}

	if v, ok := mod.(OnComplete); ok {
		s.OnComplete(v.OnComplete)
	}
}
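
// A hypothetical module wiring (logModule is illustrative and must also
// satisfy the Module interface defined elsewhere in this package):
//
//	type logModule struct{}
//
//	func (logModule) OnRequest(req *Request) {
//		log.Println("fetching", req.URL, "at depth", req.Depth)
//	}
//
//	s.LoadModule(logModule{})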

// Visit enqueues a URL at depth 0.
func (s *Scraper) Visit(url string) {
	s.enqueueJob(url, 0)
}

// MarkVisited records a URL so it will not be fetched again.
func (s *Scraper) MarkVisited(url string) {
	s.visited.Insert(url, struct{}{})
}

// SetTransport replaces the function used to perform HTTP requests.
func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
	s.transport = f
}

// CanRequest registers a predicate that can veto a fetch before it happens.
func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
	s.canRequestHandlers = append(s.canRequestHandlers, f)
}

// OnRequest registers a handler that runs before each request.
func (s *Scraper) OnRequest(f func(req *Request)) {
	s.onRequestHandlers = append(s.onRequestHandlers, f)
}

// OnResponse registers a handler that runs after each response.
func (s *Scraper) OnResponse(f func(resp *Response)) {
	s.onResponseHandlers = append(s.onResponseHandlers, f)
}

// OnComplete registers a handler that runs once all jobs have finished.
func (s *Scraper) OnComplete(f func()) {
	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
}

// Run starts the worker, blocks until every queued job has finished,
// closes the queue, and then fires the OnComplete handlers.
func (s *Scraper) Run() {
	go s.worker()
	s.wg.Wait()
	close(s.jobs)

	for _, handler := range s.onCompleteHandlers {
		handler()
	}
}

// worker drains the job queue, processing each target in its own goroutine.
func (s *Scraper) worker() {
	for job := range s.jobs {
		go func(job target) {
			defer s.wg.Done()

			for _, handler := range s.canRequestHandlers {
				if !handler(job.url, job.depth) {
					return
				}
			}

			res, html := s.process(job)
			for _, handler := range s.onResponseHandlers {
				handler(&Response{
					ScrapeResult: res,
					HTML:         html,
					Visit: func(url string) {
						s.enqueueJob(url, job.depth+1)
					},
				})
			}
		}(job)
	}
}

// process fetches a single target and, if a ScrapeFunc is set, extracts
// data from the response body.
func (s *Scraper) process(job target) (res ScrapeResult, html string) {
	res.URL = job.url
	res.Timestamp = time.Now()

	req, err := http.NewRequest(http.MethodGet, job.url, nil)
	if err != nil {
		res.Error = err
		return
	}

	for _, handler := range s.onRequestHandlers {
		handler(&Request{URL: job.url, Depth: job.depth})
	}

	resp, err := s.transport(req)
	if err != nil {
		res.Error = err
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		res.Error = err
		return
	}

	html = string(body)

	if s.ScrapeFunc != nil {
		res.Data, err = s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
		if err != nil {
			res.Error = err
			return
		}
	}

	return
}

// enqueueJob adds an unvisited URL to the queue, dropping it with a log
// message if the queue is full.
func (s *Scraper) enqueueJob(url string, depth int) {
	if _, ok := s.visited.Get(url); ok {
		return
	}

	s.wg.Add(1)
	select {
	case s.jobs <- target{url: url, depth: depth}:
		s.MarkVisited(url)
	default:
		log.Println("queue is full, can't add url:", url)
		s.wg.Done()
	}
}

// ParseLinks extracts the unique, absolute http(s) links from html,
// resolving relative hrefs against origin.
func ParseLinks(html string, origin string) []string {
	var links []string
	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		return nil
	}

	urlParser := url.NewParser(url.WithPercentEncodeSinglePercentSign())

	uniqueLinks := make(map[string]bool)
	doc.Find("a").Each(func(i int, s *goquery.Selection) {
		link, _ := s.Attr("href")

		parsedLink, err := urlParser.ParseRef(origin, link)
		if err != nil || !isValidLink(parsedLink) {
			return
		}

		absLink := parsedLink.Href(true)

		if !uniqueLinks[absLink] {
			links = append(links, absLink)
			uniqueLinks[absLink] = true
		}
	})

	return links
}
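
// For example, with origin "https://example.com/docs/", an href of
// "../about" resolves to "https://example.com/about", while links such
// as "mailto:user@example.com" are rejected by isValidLink below.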

// isValidLink reports whether a parsed link is fetchable over http(s).
func isValidLink(link *url.Url) bool {
	if link.Scheme() != "" && link.Scheme() != "http" && link.Scheme() != "https" {
		return false
	}

	if strings.HasPrefix(link.String(), "javascript:") {
		return false
	}

	return true
}