summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorPhilipp Tanlak <philipp.tanlak@gmail.com>2024-02-07 23:20:55 +0100
committerGitHub <noreply@github.com>2024-02-07 23:20:55 +0100
commit0d6494d164cc490d62473eae0fbd79d5573bb380 (patch)
tree7a4586e89920b6abd4f6c7724f42634d66cf5f68 /modules
parent60139e7de275473332b560b4139a6a01c3da184c (diff)
Add retry module and change rate to requests per minute (#37)v0.7.0
Diffstat (limited to 'modules')
-rw-r--r--modules/ratelimit/ratelimit.go62
-rw-r--r--modules/ratelimit/ratelimit_test.go48
-rw-r--r--modules/retry/retry.go145
-rw-r--r--modules/retry/retry_test.go141
4 files changed, 366 insertions, 30 deletions
diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go
index b23cd7a..152c6fd 100644
--- a/modules/ratelimit/ratelimit.go
+++ b/modules/ratelimit/ratelimit.go
@@ -5,6 +5,7 @@
package ratelimit
import (
+ "math"
"net/http"
"time"
@@ -16,10 +17,12 @@ func init() {
}
type Module struct {
- Rate float64 `json:"rate"`
+ Rate int `json:"rate"`
+ Concurrency int `json:"concurrency"`
- ticker *time.Ticker
- semaphore chan struct{}
+ ticker *time.Ticker
+ ratelimit chan struct{}
+ concurrency chan struct{}
}
func (Module) ModuleInfo() flyscrape.ModuleInfo {
@@ -30,41 +33,54 @@ func (Module) ModuleInfo() flyscrape.ModuleInfo {
}
func (m *Module) Provision(v flyscrape.Context) {
- if m.disabled() {
- return
- }
-
- rate := time.Duration(float64(time.Second) / m.Rate)
+ if m.rateLimitEnabled() {
+ rate := time.Duration(float64(time.Minute) / float64(m.Rate))
+ m.ticker = time.NewTicker(rate)
+ m.ratelimit = make(chan struct{}, int(math.Max(float64(m.Rate)/10, 1)))
- m.ticker = time.NewTicker(rate)
- m.semaphore = make(chan struct{}, 1)
+ go func() {
+ m.ratelimit <- struct{}{}
+ for range m.ticker.C {
+ m.ratelimit <- struct{}{}
+ }
+ }()
+ }
- go func() {
- for range m.ticker.C {
- m.semaphore <- struct{}{}
+ if m.concurrencyEnabled() {
+ m.concurrency = make(chan struct{}, m.Concurrency)
+ for i := 0; i < m.Concurrency; i++ {
+ m.concurrency <- struct{}{}
}
- }()
+ }
}
func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
- if m.disabled() {
- return t
- }
return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
- <-m.semaphore
+ if m.rateLimitEnabled() {
+ <-m.ratelimit
+ }
+
+ if m.concurrencyEnabled() {
+ <-m.concurrency
+ defer func() { m.concurrency <- struct{}{} }()
+ }
+
return t.RoundTrip(r)
})
}
func (m *Module) Finalize() {
- if m.disabled() {
- return
+ if m.rateLimitEnabled() {
+ m.ticker.Stop()
}
- m.ticker.Stop()
}
-func (m *Module) disabled() bool {
- return m.Rate == 0
+func (m *Module) rateLimitEnabled() bool {
+ return m.Rate != 0
+}
+
+func (m *Module) concurrencyEnabled() bool {
+ return m.Concurrency > 0
}
var (
diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go
index 7be29a1..23cc8c8 100644
--- a/modules/ratelimit/ratelimit_test.go
+++ b/modules/ratelimit/ratelimit_test.go
@@ -32,7 +32,7 @@ func TestRatelimit(t *testing.T) {
},
},
&ratelimit.Module{
- Rate: 100,
+ Rate: 240,
},
}
@@ -41,12 +41,46 @@ func TestRatelimit(t *testing.T) {
scraper.Modules = mods
scraper.Run()
- first := times[0].Add(-10 * time.Millisecond)
- second := times[1].Add(-20 * time.Millisecond)
+ first := times[0].Add(-250 * time.Millisecond)
+ second := times[1].Add(-500 * time.Millisecond)
- require.Less(t, first.Sub(start), 2*time.Millisecond)
- require.Less(t, second.Sub(start), 2*time.Millisecond)
+ require.Less(t, first.Sub(start), 250*time.Millisecond)
+ require.Less(t, second.Sub(start), 250*time.Millisecond)
- require.Less(t, start.Sub(first), 2*time.Millisecond)
- require.Less(t, start.Sub(second), 2*time.Millisecond)
+ require.Less(t, start.Sub(first), 250*time.Millisecond)
+ require.Less(t, start.Sub(second), 250*time.Millisecond)
+}
+
+func TestRatelimitConcurrency(t *testing.T) {
+ var times []time.Time
+
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ times = append(times, time.Now())
+ time.Sleep(10 * time.Millisecond)
+ return flyscrape.MockResponse(200, `
+ <a href="foo"></a>
+ <a href="bar"></a>
+ <a href="baz"></a>
+ <a href="qux"></a>
+ `)
+ })
+ },
+ },
+ &ratelimit.Module{
+ Concurrency: 2,
+ },
+ }
+
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
+ scraper.Run()
+
+ require.Len(t, times, 5)
+ require.Less(t, times[2].Sub(times[1]), time.Millisecond)
+ require.Less(t, times[4].Sub(times[3]), time.Millisecond)
}
diff --git a/modules/retry/retry.go b/modules/retry/retry.go
new file mode 100644
index 0000000..9c00275
--- /dev/null
+++ b/modules/retry/retry.go
@@ -0,0 +1,145 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package retry
+
+import (
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "slices"
+ "strconv"
+ "time"
+
+ "github.com/philippta/flyscrape"
+)
+
+func init() {
+ flyscrape.RegisterModule(Module{})
+}
+
+type Module struct {
+ ticker *time.Ticker
+ semaphore chan struct{}
+
+ RetryDelays []time.Duration
+}
+
+func (Module) ModuleInfo() flyscrape.ModuleInfo {
+ return flyscrape.ModuleInfo{
+ ID: "retry",
+ New: func() flyscrape.Module { return new(Module) },
+ }
+}
+
+func (m *Module) Provision(flyscrape.Context) {
+ if m.RetryDelays == nil {
+ m.RetryDelays = defaultRetryDelays
+ }
+}
+
+func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ resp, err := t.RoundTrip(r)
+ if !shouldRetry(resp, err) {
+ return resp, err
+ }
+
+ for _, delay := range m.RetryDelays {
+ drainBody(resp, err)
+
+ time.Sleep(retryAfter(resp, delay))
+
+ resp, err = t.RoundTrip(r)
+ if !shouldRetry(resp, err) {
+ break
+ }
+ }
+
+ return resp, err
+ })
+}
+
+func shouldRetry(resp *http.Response, err error) bool {
+ statusCodes := []int{
+ http.StatusRequestTimeout,
+ http.StatusTooEarly,
+ http.StatusTooManyRequests,
+ http.StatusInternalServerError,
+ http.StatusBadGateway,
+ http.StatusServiceUnavailable,
+ http.StatusGatewayTimeout,
+ }
+
+ if resp != nil {
+ if slices.Contains(statusCodes, resp.StatusCode) {
+ return true
+ }
+ }
+ if err == nil {
+ return false
+ }
+ if _, ok := err.(net.Error); ok {
+ return true
+ }
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ return true
+ }
+
+ return false
+}
+
+func drainBody(resp *http.Response, err error) {
+ if err == nil && resp != nil && resp.Body != nil {
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ }
+}
+
+func retryAfter(resp *http.Response, fallback time.Duration) time.Duration {
+ if resp == nil {
+ return fallback
+ }
+
+ timeexp := resp.Header.Get("Retry-After")
+ if timeexp == "" {
+ return fallback
+ }
+
+ if seconds, err := strconv.Atoi(timeexp); err == nil {
+ return time.Duration(seconds) * time.Second
+ }
+
+ formats := []string{
+ time.RFC1123, // HTTP Spec
+ time.RFC1123Z,
+ time.ANSIC,
+ time.UnixDate,
+ time.RubyDate,
+ time.RFC822,
+ time.RFC822Z,
+ time.RFC850,
+ time.RFC3339,
+ }
+ for _, format := range formats {
+ if t, err := time.Parse(format, timeexp); err == nil {
+ return t.Sub(time.Now())
+ }
+ }
+
+ return fallback
+}
+
+var defaultRetryDelays = []time.Duration{
+ 1 * time.Second,
+ 2 * time.Second,
+ 5 * time.Second,
+ 10 * time.Second,
+}
+
+var (
+ _ flyscrape.TransportAdapter = (*Module)(nil)
+ _ flyscrape.Provisioner = (*Module)(nil)
+)
diff --git a/modules/retry/retry_test.go b/modules/retry/retry_test.go
new file mode 100644
index 0000000..b979320
--- /dev/null
+++ b/modules/retry/retry_test.go
@@ -0,0 +1,141 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package retry_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/philippta/flyscrape"
+ "github.com/philippta/flyscrape/modules/followlinks"
+ "github.com/philippta/flyscrape/modules/hook"
+ "github.com/philippta/flyscrape/modules/retry"
+ "github.com/philippta/flyscrape/modules/starturl"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRetry(t *testing.T) {
+ t.Parallel()
+ var count int
+
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ count++
+ return flyscrape.MockResponse(http.StatusServiceUnavailable, "service unavailable")
+ })
+ },
+ },
+ &retry.Module{
+ RetryDelays: []time.Duration{
+ 100 * time.Millisecond,
+ 200 * time.Millisecond,
+ },
+ },
+ }
+
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
+ scraper.Run()
+
+ require.Equal(t, 3, count)
+}
+
+func TestRetryStatusCodes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ statusCode int
+ retry bool
+ }{
+ {statusCode: http.StatusBadGateway, retry: true},
+ {statusCode: http.StatusTooManyRequests, retry: true},
+ {statusCode: http.StatusBadRequest, retry: false},
+ {statusCode: http.StatusOK, retry: false},
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s_%t", http.StatusText(test.statusCode), test.retry), func(t *testing.T) {
+ t.Parallel()
+ var count int
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ count++
+ return flyscrape.MockResponse(test.statusCode, http.StatusText(test.statusCode))
+ })
+ },
+ },
+ &retry.Module{
+ RetryDelays: []time.Duration{
+ 100 * time.Millisecond,
+ 200 * time.Millisecond,
+ },
+ },
+ }
+
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
+ scraper.Run()
+
+ if test.retry {
+ require.NotEqual(t, 1, count)
+ } else {
+ require.Equal(t, 1, count)
+ }
+ })
+ }
+}
+
+func TestRetryErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ error error
+ }{
+ {error: &net.OpError{}},
+ {error: io.ErrUnexpectedEOF},
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%T", test.error), func(t *testing.T) {
+ t.Parallel()
+ var count int
+ mods := []flyscrape.Module{
+ &starturl.Module{URL: "http://www.example.com"},
+ &followlinks.Module{},
+ hook.Module{
+ AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper {
+ return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
+ return nil, test.error
+ })
+ },
+ },
+ &retry.Module{
+ RetryDelays: []time.Duration{
+ 100 * time.Millisecond,
+ 200 * time.Millisecond,
+ },
+ },
+ }
+
+ scraper := flyscrape.NewScraper()
+ scraper.Modules = mods
+ scraper.Run()
+
+ require.NotEqual(t, 1, count)
+ })
+ }
+}