diff options
Diffstat (limited to 'modules')
| -rw-r--r-- | modules/ratelimit/ratelimit.go | 62 | ||||
| -rw-r--r-- | modules/ratelimit/ratelimit_test.go | 48 | ||||
| -rw-r--r-- | modules/retry/retry.go | 145 | ||||
| -rw-r--r-- | modules/retry/retry_test.go | 141 |
4 files changed, 366 insertions, 30 deletions
diff --git a/modules/ratelimit/ratelimit.go b/modules/ratelimit/ratelimit.go index b23cd7a..152c6fd 100644 --- a/modules/ratelimit/ratelimit.go +++ b/modules/ratelimit/ratelimit.go @@ -5,6 +5,7 @@ package ratelimit import ( + "math" "net/http" "time" @@ -16,10 +17,12 @@ func init() { } type Module struct { - Rate float64 `json:"rate"` + Rate int `json:"rate"` + Concurrency int `json:"concurrency"` - ticker *time.Ticker - semaphore chan struct{} + ticker *time.Ticker + ratelimit chan struct{} + concurrency chan struct{} } func (Module) ModuleInfo() flyscrape.ModuleInfo { @@ -30,41 +33,54 @@ func (Module) ModuleInfo() flyscrape.ModuleInfo { } func (m *Module) Provision(v flyscrape.Context) { - if m.disabled() { - return - } - - rate := time.Duration(float64(time.Second) / m.Rate) + if m.rateLimitEnabled() { + rate := time.Duration(float64(time.Minute) / float64(m.Rate)) + m.ticker = time.NewTicker(rate) + m.ratelimit = make(chan struct{}, int(math.Max(float64(m.Rate)/10, 1))) - m.ticker = time.NewTicker(rate) - m.semaphore = make(chan struct{}, 1) + go func() { + m.ratelimit <- struct{}{} + for range m.ticker.C { + m.ratelimit <- struct{}{} + } + }() + } - go func() { - for range m.ticker.C { - m.semaphore <- struct{}{} + if m.concurrencyEnabled() { + m.concurrency = make(chan struct{}, m.Concurrency) + for i := 0; i < m.Concurrency; i++ { + m.concurrency <- struct{}{} } - }() + } } func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { - if m.disabled() { - return t - } return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - <-m.semaphore + if m.rateLimitEnabled() { + <-m.ratelimit + } + + if m.concurrencyEnabled() { + <-m.concurrency + defer func() { m.concurrency <- struct{}{} }() + } + return t.RoundTrip(r) }) } func (m *Module) Finalize() { - if m.disabled() { - return + if m.rateLimitEnabled() { + m.ticker.Stop() } - m.ticker.Stop() } -func (m *Module) disabled() bool { - return m.Rate == 0 +func (m *Module) rateLimitEnabled() bool { + return m.Rate != 0 +} + +func (m *Module) concurrencyEnabled() bool { + return m.Concurrency > 0 } var ( diff --git a/modules/ratelimit/ratelimit_test.go b/modules/ratelimit/ratelimit_test.go index 7be29a1..23cc8c8 100644 --- a/modules/ratelimit/ratelimit_test.go +++ b/modules/ratelimit/ratelimit_test.go @@ -32,7 +32,7 @@ func TestRatelimit(t *testing.T) { }, }, &ratelimit.Module{ - Rate: 100, + Rate: 240, }, } @@ -41,12 +41,46 @@ func TestRatelimit(t *testing.T) { scraper.Modules = mods scraper.Run() - first := times[0].Add(-10 * time.Millisecond) - second := times[1].Add(-20 * time.Millisecond) + first := times[0].Add(-250 * time.Millisecond) + second := times[1].Add(-500 * time.Millisecond) - require.Less(t, first.Sub(start), 2*time.Millisecond) - require.Less(t, second.Sub(start), 2*time.Millisecond) + require.Less(t, first.Sub(start), 250*time.Millisecond) + require.Less(t, second.Sub(start), 250*time.Millisecond) - require.Less(t, start.Sub(first), 2*time.Millisecond) - require.Less(t, start.Sub(second), 2*time.Millisecond) + require.Less(t, start.Sub(first), 250*time.Millisecond) + require.Less(t, start.Sub(second), 250*time.Millisecond) +} + +func TestRatelimitConcurrency(t *testing.T) { + var times []time.Time + + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + times = append(times, time.Now()) + time.Sleep(10 * time.Millisecond) + return flyscrape.MockResponse(200, ` + <a href="foo"></a> + <a href="bar"></a> + <a href="baz"></a> + <a href="qux"></a> + `) + }) + }, + }, + &ratelimit.Module{ + Concurrency: 2, + }, + } + + scraper := flyscrape.NewScraper() + scraper.Modules = mods + scraper.Run() + + require.Len(t, times, 5) + require.Less(t, times[2].Sub(times[1]), time.Millisecond) + require.Less(t, times[4].Sub(times[3]), time.Millisecond) } diff --git a/modules/retry/retry.go b/modules/retry/retry.go new file mode 100644 index 0000000..9c00275 --- /dev/null +++ b/modules/retry/retry.go @@ -0,0 +1,145 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package retry + +import ( + "errors" + "io" + "net" + "net/http" + "slices" + "strconv" + "time" + + "github.com/philippta/flyscrape" +) + +func init() { + flyscrape.RegisterModule(Module{}) +} + +type Module struct { + ticker *time.Ticker + semaphore chan struct{} + + RetryDelays []time.Duration +} + +func (Module) ModuleInfo() flyscrape.ModuleInfo { + return flyscrape.ModuleInfo{ + ID: "retry", + New: func() flyscrape.Module { return new(Module) }, + } +} + +func (m *Module) Provision(flyscrape.Context) { + if m.RetryDelays == nil { + m.RetryDelays = defaultRetryDelays + } +} + +func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + resp, err := t.RoundTrip(r) + if !shouldRetry(resp, err) { + return resp, err + } + + for _, delay := range m.RetryDelays { + drainBody(resp, err) + + time.Sleep(retryAfter(resp, delay)) + + resp, err = t.RoundTrip(r) + if !shouldRetry(resp, err) { + break + } + } + + return resp, err + }) +} + +func shouldRetry(resp *http.Response, err error) bool { + statusCodes := []int{ + http.StatusRequestTimeout, + http.StatusTooEarly, + http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, + } + + if resp != nil { + if slices.Contains(statusCodes, resp.StatusCode) { + return true + } + } + if err == nil { + return false + } + if _, ok := err.(net.Error); ok { + return true + } + if errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + return false +} + +func drainBody(resp *http.Response, err error) { + if err == nil && resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } +} + +func retryAfter(resp *http.Response, fallback time.Duration) time.Duration { + if resp == nil { + return fallback + } + + timeexp := resp.Header.Get("Retry-After") + if timeexp == "" { + return fallback + } + + if seconds, err := strconv.Atoi(timeexp); err == nil { + return time.Duration(seconds) * time.Second + } + + formats := []string{ + time.RFC1123, // HTTP Spec + time.RFC1123Z, + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.RFC822, + time.RFC822Z, + time.RFC850, + time.RFC3339, + } + for _, format := range formats { + if t, err := time.Parse(format, timeexp); err == nil { + return t.Sub(time.Now()) + } + } + + return fallback +} + +var defaultRetryDelays = []time.Duration{ + 1 * time.Second, + 2 * time.Second, + 5 * time.Second, + 10 * time.Second, +} + +var ( + _ flyscrape.TransportAdapter = (*Module)(nil) + _ flyscrape.Provisioner = (*Module)(nil) +) diff --git a/modules/retry/retry_test.go b/modules/retry/retry_test.go new file mode 100644 index 0000000..b979320 --- /dev/null +++ b/modules/retry/retry_test.go @@ -0,0 +1,141 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package retry_test + +import ( + "fmt" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/philippta/flyscrape" + "github.com/philippta/flyscrape/modules/followlinks" + "github.com/philippta/flyscrape/modules/hook" + "github.com/philippta/flyscrape/modules/retry" + "github.com/philippta/flyscrape/modules/starturl" + "github.com/stretchr/testify/require" +) + +func TestRetry(t *testing.T) { + t.Parallel() + var count int + + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + count++ + return flyscrape.MockResponse(http.StatusServiceUnavailable, "service unavailable") + }) + }, + }, + &retry.Module{ + RetryDelays: []time.Duration{ + 100 * time.Millisecond, + 200 * time.Millisecond, + }, + }, + } + + scraper := flyscrape.NewScraper() + scraper.Modules = mods + scraper.Run() + + require.Equal(t, 3, count) +} + +func TestRetryStatusCodes(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + retry bool + }{ + {statusCode: http.StatusBadGateway, retry: true}, + {statusCode: http.StatusTooManyRequests, retry: true}, + {statusCode: http.StatusBadRequest, retry: false}, + {statusCode: http.StatusOK, retry: false}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s_%t", http.StatusText(test.statusCode), test.retry), func(t *testing.T) { + t.Parallel() + var count int + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + count++ + return flyscrape.MockResponse(test.statusCode, http.StatusText(test.statusCode)) + }) + }, + }, + &retry.Module{ + RetryDelays: []time.Duration{ + 100 * time.Millisecond, + 200 * time.Millisecond, + }, + }, + } + + scraper := flyscrape.NewScraper() + scraper.Modules = mods + scraper.Run() + + if test.retry { + require.NotEqual(t, 1, count) + } else { + require.Equal(t, 1, count) + } + }) + } +} + +func TestRetryErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + error error + }{ + {error: &net.OpError{}}, + {error: io.ErrUnexpectedEOF}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%T", test.error), func(t *testing.T) { + t.Parallel() + var count int + mods := []flyscrape.Module{ + &starturl.Module{URL: "http://www.example.com"}, + &followlinks.Module{}, + hook.Module{ + AdaptTransportFn: func(rt http.RoundTripper) http.RoundTripper { + return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) { + return nil, test.error + }) + }, + }, + &retry.Module{ + RetryDelays: []time.Duration{ + 100 * time.Millisecond, + 200 * time.Millisecond, + }, + }, + } + + scraper := flyscrape.NewScraper() + scraper.Modules = mods + scraper.Run() + + require.NotEqual(t, 1, count) + }) + } +} |