// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package flyscrape

import (
	"io"
	"log"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strings"
	"sync"

	"github.com/PuerkitoBio/goquery"
	"github.com/cornelk/hashmap"
)

// ScrapeParams carries a fetched page into a ScrapeFunc: the raw HTML and
// the URL it was fetched from.
type ScrapeParams struct {
	HTML string
	URL  string
}

// ScrapeFunc extracts structured data from a fetched page.
type ScrapeFunc func(ScrapeParams) (any, error)

// FetchFunc fetches the document at url and returns its body.
type FetchFunc func(url string) (string, error)

// Visitor is the part of Scraper that modules use to schedule URLs and to
// mark them as already seen.
type Visitor interface {
	Visit(url string)
	MarkVisited(url string)
}

// Request describes an outgoing HTTP request; OnRequest handlers may
// modify it before it is sent.
type Request struct {
	Method  string
	URL     string
	Headers http.Header
	Cookies http.CookieJar
	Depth   int
}

// Response is the outcome of a request, passed to OnResponse handlers.
// Visit schedules a follow-up URL one level deeper than this request.
type Response struct {
	StatusCode int
	Headers    http.Header
	Body       []byte
	Data       any
	Error      error
	Request    *Request

	Visit func(url string)
}

// target is a queued URL together with its link depth.
type target struct {
	url   string
	depth int
}

// Scraper crawls pages concurrently, running the registered handlers
// around every request and response.
type Scraper struct {
	ScrapeFunc ScrapeFunc

	cfg       Config
	wg        sync.WaitGroup
	jobs      chan target
	visited   *hashmap.Map[string, struct{}]
	modules   *hashmap.Map[string, Module]
	cookieJar *cookiejar.Jar

	canRequestHandlers []func(url string, depth int) bool
	onRequestHandlers  []func(*Request)
	onResponseHandlers []func(*Response)
	onCompleteHandlers []func()
	transport          func(*http.Request) (*http.Response, error)
}

// NewScraper returns a Scraper with a buffered job queue, an in-memory
// cookie jar and a default transport that sets the flyscrape User-Agent.
func NewScraper() *Scraper {
	jar, _ := cookiejar.New(nil) // cookiejar.New(nil) cannot fail.
	s := &Scraper{
		jobs:    make(chan target, 1024),
		visited: hashmap.New[string, struct{}](),
		modules: hashmap.New[string, Module](),
		transport: func(r *http.Request) (*http.Response, error) {
			r.Header.Set("User-Agent", "flyscrape/0.1")
			return http.DefaultClient.Do(r)
		},
		cookieJar: jar,
	}
	return s
}
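
// A minimal usage sketch (the scrape function and start URL here are
// hypothetical):
//
//	s := NewScraper()
//	s.ScrapeFunc = func(p ScrapeParams) (any, error) {
//		return map[string]any{"url": p.URL, "bytes": len(p.HTML)}, nil
//	}
//	s.Visit("https://example.com")
//	s.Run()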

// LoadModule registers a module, wiring up each hook interface the module
// implements (Transport, CanRequest, OnRequest, OnResponse, OnLoad, OnComplete).
func (s *Scraper) LoadModule(mod Module) {
	if v, ok := mod.(Transport); ok {
		s.SetTransport(v.Transport)
	}

	if v, ok := mod.(CanRequest); ok {
		s.CanRequest(v.CanRequest)
	}

	if v, ok := mod.(OnRequest); ok {
		s.OnRequest(v.OnRequest)
	}

	if v, ok := mod.(OnResponse); ok {
		s.OnResponse(v.OnResponse)
	}

	if v, ok := mod.(OnLoad); ok {
		v.OnLoad(s)
	}

	if v, ok := mod.(OnComplete); ok {
		s.OnComplete(v.OnComplete)
	}
}
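
// A sketch of a hypothetical module wired in via LoadModule. It implements
// only the OnRequest hook; the Module and hook interfaces are defined
// elsewhere in this package, and the sketch assumes Module itself imposes
// no further requirements:
//
//	type headerModule struct{}
//
//	func (headerModule) OnRequest(req *Request) {
//		req.Headers.Set("Accept-Language", "en")
//	}
//
//	s.LoadModule(headerModule{})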

// Visit enqueues url at depth 0.
func (s *Scraper) Visit(url string) {
	s.enqueueJob(url, 0)
}

// MarkVisited records url so it will not be fetched again.
func (s *Scraper) MarkVisited(url string) {
	s.visited.Insert(url, struct{}{})
}

// SetTransport replaces the function used to execute HTTP requests.
func (s *Scraper) SetTransport(f func(r *http.Request) (*http.Response, error)) {
	s.transport = f
}
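
// A sketch of a custom transport that routes requests through a shared
// http.Client with a timeout instead of http.DefaultClient (the client and
// timeout are assumptions, and "time" must be imported):
//
//	client := &http.Client{Timeout: 10 * time.Second}
//	s.SetTransport(func(r *http.Request) (*http.Response, error) {
//		r.Header.Set("User-Agent", "flyscrape/0.1")
//		return client.Do(r)
//	})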

// CanRequest registers a filter; a URL is fetched only if every registered
// filter returns true for it.
func (s *Scraper) CanRequest(f func(url string, depth int) bool) {
	s.canRequestHandlers = append(s.canRequestHandlers, f)
}
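
// A sketch of a CanRequest filter that caps the crawl depth and keeps the
// crawl on one host (maxDepth and allowedHost are hypothetical values):
//
//	s.CanRequest(func(url string, depth int) bool {
//		return depth <= maxDepth && strings.Contains(url, allowedHost)
//	})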

// OnRequest registers a handler that runs before each request is sent.
func (s *Scraper) OnRequest(f func(req *Request)) {
	s.onRequestHandlers = append(s.onRequestHandlers, f)
}

// OnResponse registers a handler that runs after each response, including
// failed requests (check resp.Error).
func (s *Scraper) OnResponse(f func(resp *Response)) {
	s.onResponseHandlers = append(s.onResponseHandlers, f)
}

// OnComplete registers a handler that runs once the queue has drained.
func (s *Scraper) OnComplete(f func()) {
	s.onCompleteHandlers = append(s.onCompleteHandlers, f)
}

// Run processes queued jobs until all pending work is done, then fires the
// OnComplete handlers. Call Visit at least once before Run, or Run returns
// immediately.
func (s *Scraper) Run() {
	go s.worker()
	s.wg.Wait()
	close(s.jobs)

	for _, handler := range s.onCompleteHandlers {
		handler()
	}
}

// worker consumes jobs from the queue, applying the CanRequest filters and
// processing each job in its own goroutine.
func (s *Scraper) worker() {
	for job := range s.jobs {
		go func(job target) {
			defer s.wg.Done()

			for _, handler := range s.canRequestHandlers {
				if !handler(job.url, job.depth) {
					return
				}
			}

			s.process(job.url, job.depth)
		}(job)
	}
}

// process performs one request/response cycle for url, invoking the
// registered handlers and, if set, the ScrapeFunc.
func (s *Scraper) process(url string, depth int) {
	request := &Request{
		Method:  http.MethodGet,
		URL:     url,
		Headers: http.Header{},
		Cookies: s.cookieJar,
		Depth:   depth,
	}

	response := &Response{
		Request: request,
		Visit: func(url string) {
			s.enqueueJob(url, depth+1)
		},
	}

	defer func() {
		for _, handler := range s.onResponseHandlers {
			handler(response)
		}
	}()

	req, err := http.NewRequest(request.Method, request.URL, nil)
	if err != nil {
		response.Error = err
		return
	}
	req.Header = request.Headers

	for _, handler := range s.onRequestHandlers {
		handler(request)
	}

	resp, err := s.transport(req)
	if err != nil {
		response.Error = err
		return
	}
	defer resp.Body.Close()

	response.StatusCode = resp.StatusCode
	response.Headers = resp.Header

	response.Body, err = io.ReadAll(resp.Body)
	if err != nil {
		response.Error = err
		return
	}

	if s.ScrapeFunc != nil {
		response.Data, err = s.ScrapeFunc(ScrapeParams{HTML: string(response.Body), URL: request.URL})
		if err != nil {
			response.Error = err
			return
		}
	}
}
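
// A sketch of the follow-links pattern that ties process and ParseLinks
// together: an OnResponse handler re-queues every link found on the page,
// and the visited set keeps the crawl from looping.
//
//	s.OnResponse(func(resp *Response) {
//		if resp.Error != nil {
//			return
//		}
//		for _, link := range ParseLinks(string(resp.Body), resp.Request.URL) {
//			resp.Visit(link)
//		}
//	})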

// enqueueJob schedules url at the given depth unless it was already
// visited; the check and the later insert are not atomic, so deduplication
// is best-effort under concurrency. If the queue is full the job is
// dropped with a log message rather than blocking.
func (s *Scraper) enqueueJob(url string, depth int) {
	if _, ok := s.visited.Get(url); ok {
		return
	}

	s.wg.Add(1)
	select {
	case s.jobs <- target{url: url, depth: depth}:
		s.MarkVisited(url)
	default:
		log.Println("queue is full, can't add url:", url)
		s.wg.Done()
	}
}

// ParseLinks extracts the unique links from html, resolving relative hrefs
// against origin. It returns nil if html or origin cannot be parsed.
func ParseLinks(html string, origin string) []string {
	var links []string
	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		return nil
	}

	originurl, err := url.Parse(origin)
	if err != nil {
		return nil
	}

	uniqueLinks := make(map[string]bool)
	doc.Find("a").Each(func(i int, s *goquery.Selection) {
		link, ok := s.Attr("href")
		if !ok || link == "" {
			return // skip anchors without an href, which would resolve to origin itself
		}

		parsedLink, err := originurl.Parse(link)
		if err != nil || !isValidLink(parsedLink) {
			return
		}

		absLink := parsedLink.String()

		if !uniqueLinks[absLink] {
			links = append(links, absLink)
			uniqueLinks[absLink] = true
		}
	})

	return links
}
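
// A small usage sketch: relative hrefs resolve against origin, so "/about"
// on https://example.com becomes https://example.com/about.
//
//	links := ParseLinks(`<a href="/about">About</a>`, "https://example.com")
//	// links == []string{"https://example.com/about"}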

// isValidLink accepts relative links and absolute http(s) links, rejecting
// other schemes such as mailto: or javascript:.
func isValidLink(link *url.URL) bool {
	if link.Scheme != "" && link.Scheme != "http" && link.Scheme != "https" {
		return false
	}

	return true
}