Diffstat (limited to 'scrape.go')
 scrape.go | 85 ++++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 58 insertions(+), 27 deletions(-)
diff --git a/scrape.go b/scrape.go
index 3ecfe7b..6ff92dc 100644
--- a/scrape.go
+++ b/scrape.go
@@ -4,6 +4,7 @@ import (
"log"
"strings"
"sync"
+ "time"
"github.com/PuerkitoBio/goquery"
"github.com/cornelk/hashmap"
@@ -19,6 +20,7 @@ type ScrapeOptions struct {
URL string `json:"url"`
AllowedDomains []string `json:"allowedDomains"`
Depth int `json:"depth"`
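+ // Rate is the maximum number of requests per second; 0 selects the default (100).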
+ Rate float64 `json:"rate"`
}

type ScrapeResult struct {
@@ -62,6 +64,9 @@ func (s *Scraper) Scrape() <-chan ScrapeResult {
if s.FetchFunc == nil {
s.FetchFunc = Fetch()
}
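+ // default to 100 requests per second when no rate is configured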
+ if s.ScrapeOptions.Rate == 0 {
+ s.ScrapeOptions.Rate = 100
+ }
if len(s.ScrapeOptions.AllowedDomains) == 0 {
u, err := url.Parse(s.ScrapeOptions.URL)
if err == nil {
@@ -75,11 +80,10 @@ func (s *Scraper) Scrape() <-chan ScrapeResult {
s.visited = hashmap.New[string, struct{}]()
s.wg = &sync.WaitGroup{}
- for i := 0; i < s.Concurrency; i++ {
- go s.worker(i, jobs, results)
- }
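+ // a single dispatcher goroutine now drains the queue at the configured
+ // rate, spawning one goroutine per job instead of a fixed worker pool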
+ go s.worker(jobs, results)

s.wg.Add(1)
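+ // mark the seed URL as visited before it is queued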
+ s.visited.Set(s.ScrapeOptions.URL, struct{}{})
jobs <- target{url: s.ScrapeOptions.URL, depth: s.ScrapeOptions.Depth}

go func() {
@@ -103,33 +107,37 @@ func (s *Scraper) Scrape() <-chan ScrapeResult {
return scraperesults
}

-func (s *Scraper) worker(id int, jobs chan target, results chan<- result) {
- for j := range jobs {
- res := s.process(j)
-
- if j.depth > 0 {
- for _, l := range res.links {
- if _, ok := s.visited.Get(l); ok {
- continue
- }
-
- if !s.isURLAllowed(l) {
- continue
- }
-
- s.wg.Add(1)
- select {
- case jobs <- target{url: l, depth: j.depth - 1}:
- s.visited.Set(l, struct{}{})
- default:
- log.Println("queue is full, can't add url:", l)
- s.wg.Done()
+func (s *Scraper) worker(jobs chan target, results chan<- result) {
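+ // convert requests-per-second into the minimum interval between jobs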
+ rate := time.Duration(float64(time.Second) / s.ScrapeOptions.Rate)
+ for j := range leakychan(jobs, rate) {
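+ // re-declare j so each goroutine captures its own copy
+ // (needed for per-iteration loop semantics before Go 1.22)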
+ j := j
+ go func() {
+ res := s.process(j)
+
+ if j.depth > 0 {
+ for _, l := range res.links {
+ if _, ok := s.visited.Get(l); ok {
+ continue
+ }
+
+ if !s.isURLAllowed(l) {
+ continue
+ }
+
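+ // reserve a WaitGroup slot before enqueueing; release it if the
+ // send is dropped because the queue is full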
+ s.wg.Add(1)
+ select {
+ case jobs <- target{url: l, depth: j.depth - 1}:
+ s.visited.Set(l, struct{}{})
+ default:
+ log.Println("queue is full, can't add url:", l)
+ s.wg.Done()
+ }
}
}
- }
- results <- res
- s.wg.Done()
+ results <- res
+ s.wg.Done()
+ }()
}
}
@@ -208,3 +216,26 @@ func isValidLink(link *url.URL) bool {
return true
}
+
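+// leakychan forwards values from in on a new channel, releasing at most
+// one value per rate interval (a simple leaky-bucket rate limiter).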
+func leakychan[T any](in chan T, rate time.Duration) chan T {
+ ticker := time.NewTicker(rate)
+ sem := make(chan struct{}, 1)
+ c := make(chan T)
+
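+ // refill: add one token to sem per tick; the buffer of one caps any burst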
+ go func() {
+ for range ticker.C {
+ sem <- struct{}{}
+ }
+ }()
+
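+ // forward: each value must first take a token, which paces the output;
+ // once in is closed and drained, stop the ticker and close the output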
+ go func() {
+ for v := range in {
+ <-sem
+ c <- v
+ }
+ ticker.Stop()
+ close(c)
+ }()
+
+ return c
+}