summary | refs | log | tree | commit | diff
path: root/modules/ratelimit
diff options
context:
space:
mode:
author: Philipp Tanlak <philipp.tanlak@gmail.com> 2024-02-07 23:20:55 +0100
committer: GitHub <noreply@github.com> 2024-02-07 23:20:55 +0100
commit: 0d6494d164cc490d62473eae0fbd79d5573bb380 (patch)
tree: 7a4586e89920b6abd4f6c7724f42634d66cf5f68 /modules/ratelimit
parent: 60139e7de275473332b560b4139a6a01c3da184c (diff)
Add retry module and change rate to requests per minute (#37) [tag: v0.7.0]
Diffstat (limited to 'modules/ratelimit')
-rw-r--r-- modules/ratelimit/ratelimit.go | 62
-rw-r--r-- modules/ratelimit/ratelimit_test.go | 48
2 files changed, 80 insertions, 30 deletions
diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go
index b23cd7a..152c6fd 100644
--- a/modules/ratelimit/ratelimit.go
+++ b/modules/ratelimit/ratelimit.go
@@ -5,6 +5,7 @@
package ratelimit
import (
+ "math"
"net/http"
"time"
@@ -16,10 +17,12 @@ func init() {
}
type Module struct {
- Rate float64 `json:"rate"`
+ Rate int `json:"rate"`
+ Concurrency int `json:"concurrency"`
- ticker *time.Ticker
- semaphore chan struct{}
+ ticker *time.Ticker
+ ratelimit chan struct{}
+ concurrency chan struct{}
}
func (Module) ModuleInfo() flyscrape.ModuleInfo {
@@ -30,41 +33,54 @@ func (Module) ModuleInfo() flyscrape.ModuleInfo {
}
func (m *Module) Provision(v flyscrape.Context) {
- if m.disabled() {
- return
- }
-
- rate := time.Duration(float64(time.Second) / m.Rate)
+ if m.rateLimitEnabled() {
+ rate := time.Duration(float64(time.Minute) / float64(m.Rate))
+ m.ticker = time.NewTicker(rate)
+ m.ratelimit = make(chan struct{}, int(math.Max(float64(m.Rate)/10, 1)))
- m.ticker = time.NewTicker(rate)
- m.semaphore = make(chan struct{}, 1)
+ go func() {
+ m.ratelimit <- struct{}{}
+ for range m.ticker.C {
+ m.ratelimit <- struct{}{}
+ }
+ }()
+ }
- go func() {
- for range m.ticker.C {
- m.semaphore <- struct{}{}
+ if m.concurrencyEnabled() {
+ m.concurrency = make(chan struct{}, m.Concurrency)
+ for i := 0; i < m.Concurrency; i++ {
+ m.concurrency <- struct{}{}
}
- }()
+ }
}
func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
- if m.disabled() {
- return t
- }
return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
- <-m.semaphore
+ if m.rateLimitEnabled() {
+ <-m.ratelimit
+ }
+
+ if m.concurrencyEnabled() {
+ <-m.concurrency
+ defer func() { m.concurrency <- struct{}{} }()
+ }
+
return t.RoundTrip(r)
})
}
func (m *Module) Finalize() {
- if m.disabled() {
- return
+ if m.rateLimitEnabled() {
+ m.ticker.Stop()
}
- m.ticker.Stop()
}
-func (m *Module) disabled() bool {
- return m.Rate == 0
+func (m *Module) rateLimitEnabled() bool {
+ return m.Rate != 0
+}
+
+func (m *Module) concurrencyEnabled() bool {
+ return m.Concurrency > 0
}
var (
diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go
index 7be29a1..23cc8c8 100644
--- a/modules/ratelimit/ratelimit_test.go
+++ b/modules/ratelimit/ratelimit_test.go
@@ -32,7 +32,7 @@ func TestRatelimit(t *testing.T) {
},
},
&ratelimit.Module{
- Rate: 100,
+ Rate: 240,
},
}
@@ -41,12 +41,46 @@ func TestRatelimit(t *testing.T) {
scraper.Modules = mods
scraper.Run()
- first := times[0].Add(-10 * time.Millisecond)
- second := times[1].Add(-20 * time.Millisecond)
+ first := times[0].Add(-250 * time.Millisecond)
+ second := times[1].Add(-500 * time.Millisecond)
- require.Less(t, first.Sub(start), 2*time.Millisecond)
- require.Less(t, second.Sub(start), 2*time.Millisecond)
+ require.Less(t, first.Sub(start), 250*time.Millisecond)
+ require.Less(t, second.Sub(start), 250*time.Millisecond)
- require.Less(t, start.Sub(first), 2*time.Millisecond)
- require.Less(t, start.Sub(second), 2*time.Millisecond)
+ require.Less(t, start.Sub(first), 250*time.Millisecond)
+ require.Less(t, start.Sub(second), 250*time.Millisecond)
+}
+
+func TestRatelimitConcurrency(t *testing.T) {
+ var times []time.Time
+
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ times = append(times, time.Now())
+ time.Sleep(10 * time.Millisecond)
+ return flyscrape.MockResponse(200, `
+ <a href="foo"></a>
+ <a href="bar"></a>
+ <a href="baz"></a>
+ <a href="qux"></a>
+ `)
+ })
+ },
+ },
+ &ratelimit.Module{
+ Concurrency: 2,
+ },
+ }
+
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
+ scraper.Run()
+
+ require.Len(t, times, 5)
+ require.Less(t, times[2].Sub(times[1]), time.Millisecond)
+ require.Less(t, times[4].Sub(times[3]), time.Millisecond)
}