summaryrefslogblamecommitdiff
path: root/scrape.go
blob: ac75c73880d5dea4b3fdabd1ffd7dcb3e4f45c0d (plain) (tree)
1
2
3
4
5
6
7
8


                 


                 
              
 


                                         

 






                                            
                                                       
                                                       
                                              
                                             


                          




                                                    


      
                                                  


                                                   
                     















                                              
                 



                      
                                                


                                 


                                     


                                          





                                                                           
 
                                       




                                                   
                                  

                   
                                                      










                                                                              




                                                            






                                    
 
























                                                                                               
                                 
                         
 


                                                              

                                   


         
                                              






                                                                         
                       
                                                                   
         
 


                                                             






                                                    
                   

                                                               


                                                    
                 


                                                               
                                   

                                  


                 
                 

 


                                                                          
                       































                                                                                       

         
                   
 






















                                                             
package flyscrape

import (
	"log"
	"strings"
	"sync"
	"time"

	"github.com/PuerkitoBio/goquery"
	"github.com/cornelk/hashmap"
	"github.com/nlnwa/whatwg-url/url"
)

type ScrapeParams struct {
	HTML string
	URL  string
}

type ScrapeOptions struct {
	URL            string   `json:"url"`
	AllowedDomains []string `json:"allowedDomains"`
	BlockedDomains []string `json:"blockedDomains"`
	Depth          int      `json:"depth"`
	Rate           float64  `json:"rate"`
}

type ScrapeResult struct {
	URL       string    `json:"url"`
	Data      any       `json:"data,omitempty"`
	Links     []string  `json:"-"`
	Error     error     `json:"error,omitempty"`
	Timestamp time.Time `json:"timestamp"`
}

type (
	ScrapeFunc func(ScrapeParams) (any, error)
	FetchFunc  func(url string) (string, error)
)

type Scraper struct {
	ScrapeOptions ScrapeOptions
	ScrapeFunc    ScrapeFunc
	FetchFunc     FetchFunc
	Concurrency   int

	visited *hashmap.Map[string, struct{}]
	wg      *sync.WaitGroup
}

type target struct {
	url   string
	depth int
}

type result struct {
	url   string
	data  any
	links []string
	err   error
}

func (s *Scraper) Scrape() <-chan ScrapeResult {
	if s.Concurrency == 0 {
		s.Concurrency = 1
	}
	if s.FetchFunc == nil {
		s.FetchFunc = Fetch()
	}
	if s.ScrapeOptions.Rate == 0 {
		s.ScrapeOptions.Rate = 100
	}
	if len(s.ScrapeOptions.AllowedDomains) == 0 {
		u, err := url.Parse(s.ScrapeOptions.URL)
		if err == nil {
			s.ScrapeOptions.AllowedDomains = []string{u.Host()}
		}
	}

	jobs := make(chan target, 1024)
	results := make(chan result)
	scraperesults := make(chan ScrapeResult)
	s.visited = hashmap.New[string, struct{}]()
	s.wg = &sync.WaitGroup{}

	go s.worker(jobs, results)

	s.wg.Add(1)
	s.visited.Set(s.ScrapeOptions.URL, struct{}{})
	jobs <- target{url: s.ScrapeOptions.URL, depth: s.ScrapeOptions.Depth}

	go func() {
		s.wg.Wait()
		close(jobs)
		close(results)
	}()

	go func() {
		for res := range results {
			scraperesults <- ScrapeResult{
				URL:       res.url,
				Data:      res.data,
				Links:     res.links,
				Error:     res.err,
				Timestamp: time.Now().UTC(),
			}
		}
		close(scraperesults)
	}()

	return scraperesults
}

func (s *Scraper) worker(jobs chan target, results chan<- result) {
	rate := time.Duration(float64(time.Second) / s.ScrapeOptions.Rate)
	for j := range leakychan(jobs, rate) {
		j := j
		go func() {
			res := s.process(j)

			if j.depth > 0 {
				for _, l := range res.links {
					if _, ok := s.visited.Get(l); ok {
						continue
					}

					if !s.isURLAllowed(l) {
						continue
					}

					s.wg.Add(1)
					select {
					case jobs <- target{url: l, depth: j.depth - 1}:
						s.visited.Set(l, struct{}{})
					default:
						log.Println("queue is full, can't add url:", l)
						s.wg.Done()
					}
				}
			}

			if res.err != nil || res.data != nil {
				results <- res
			}
			s.wg.Done()
		}()
	}
}

func (s *Scraper) process(job target) result {
	html, err := s.FetchFunc(job.url)
	if err != nil {
		return result{url: job.url, err: err}
	}

	links := Links(html, job.url)
	data, err := s.ScrapeFunc(ScrapeParams{HTML: html, URL: job.url})
	if err != nil {
		return result{url: job.url, links: links, err: err}
	}

	return result{url: job.url, data: data, links: links}
}

func (s *Scraper) isURLAllowed(rawurl string) bool {
	u, err := url.Parse(rawurl)
	if err != nil {
		return false
	}

	host := u.Host()
	ok := false

	for _, domain := range s.ScrapeOptions.AllowedDomains {
		if domain == "*" || host == domain {
			ok = true
			break
		}
	}

	for _, domain := range s.ScrapeOptions.BlockedDomains {
		if host == domain {
			ok = false
			break
		}
	}

	return ok
}

func Links(html string, origin string) []string {
	var links []string
	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		return nil
	}

	urlParser := url.NewParser(url.WithPercentEncodeSinglePercentSign())

	uniqueLinks := make(map[string]bool)
	doc.Find("a").Each(func(i int, s *goquery.Selection) {
		link, _ := s.Attr("href")

		parsedLink, err := urlParser.ParseRef(origin, link)
		if err != nil || !isValidLink(parsedLink) {
			return
		}

		absLink := parsedLink.Href(true)

		if !uniqueLinks[absLink] {
			links = append(links, absLink)
			uniqueLinks[absLink] = true
		}
	})

	return links
}

func isValidLink(link *url.Url) bool {
	if link.Scheme() != "" && link.Scheme() != "http" && link.Scheme() != "https" {
		return false
	}

	if strings.HasPrefix(link.String(), "javascript:") {
		return false
	}

	return true
}

func leakychan[T any](in chan T, rate time.Duration) chan T {
	ticker := time.NewTicker(rate)
	sem := make(chan struct{}, 1)
	c := make(chan T)

	go func() {
		for range ticker.C {
			sem <- struct{}{}
		}
	}()

	go func() {
		for v := range in {
			<-sem
			c <- v
		}
		ticker.Stop()
		close(c)
	}()

	return c
}