summaryrefslogblamecommitdiff
path: root/modules/retry/retry.go
blob: 09cbdbd8ebe2c225ed09a69d87e5c9c67ecb6d6b (plain) (tree)

































































                                                                                      
                                     














































































                                                                            
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package retry

import (
	"errors"
	"io"
	"net"
	"net/http"
	"slices"
	"strconv"
	"time"

	"github.com/philippta/flyscrape"
)

func init() {
	flyscrape.RegisterModule(Module{})
}

// Module retries HTTP requests whose outcome is judged retryable by
// shouldRetry (selected status codes, network errors, unexpected EOF).
//
// RetryDelays holds the default wait before each retry, in order; its length
// is therefore the maximum number of retries. A Retry-After header sent by the
// server takes precedence over the corresponding entry. When RetryDelays is
// nil, Provision fills in defaultRetryDelays.
type Module struct {
	RetryDelays []time.Duration
}

// ModuleInfo describes this module to flyscrape: it is identified as "retry"
// and new instances are created as a zero-valued *Module.
func (Module) ModuleInfo() flyscrape.ModuleInfo {
	info := flyscrape.ModuleInfo{ID: "retry"}
	info.New = func() flyscrape.Module { return &Module{} }
	return info
}

// Provision applies defaults before the module is used. When RetryDelays was
// left nil, it is set to a private copy of defaultRetryDelays, so later edits
// to m.RetryDelays cannot leak into the shared package default. An explicitly
// empty (non-nil) slice is kept as-is, which results in no retries.
func (m *Module) Provision(flyscrape.Context) {
	if m.RetryDelays == nil {
		m.RetryDelays = slices.Clone(defaultRetryDelays)
	}
}

// AdaptTransport wraps t so that a request whose outcome shouldRetry deems
// retryable is sent again, at most once per entry in m.RetryDelays. Before
// each retry the previous response body is drained, and the code waits for
// the server's Retry-After duration, or the matching RetryDelays entry when
// there is none. The last response and error are returned once an attempt is
// no longer retryable or the delays are exhausted.
//
// If the request's context ends while waiting, the context's error is
// returned immediately. A request whose body cannot be recreated (Body set
// but GetBody nil) is not retried, because the transport has already consumed
// that body on the first attempt.
func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
	return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
		resp, err := t.RoundTrip(r)
		if !shouldRetry(resp, err) || !canReplay(r) {
			return resp, err
		}

		for _, delay := range m.RetryDelays {
			drainBody(resp, err)

			if werr := waitForRetry(r, retryAfter(resp, delay)); werr != nil {
				return nil, werr
			}

			// A RoundTripper must not modify r, so each retry sends a copy
			// carrying a freshly created body.
			req, rerr := replayRequest(r)
			if rerr != nil {
				return nil, rerr
			}

			resp, err = t.RoundTrip(req)
			if !shouldRetry(resp, err) {
				break
			}
		}

		return resp, err
	})
}

// canReplay reports whether r can be sent again: it has no body, or its body
// can be recreated through GetBody.
func canReplay(r *http.Request) bool {
	return r.Body == nil || r.Body == http.NoBody || r.GetBody != nil
}

// replayRequest returns a copy of r with a new body obtained from GetBody,
// leaving r itself untouched.
func replayRequest(r *http.Request) (*http.Request, error) {
	req := r.Clone(r.Context())
	if r.GetBody != nil {
		body, err := r.GetBody()
		if err != nil {
			return nil, err
		}
		req.Body = body
	}
	return req, nil
}

// waitForRetry pauses for d before the next attempt of r. It returns early
// with the context's error if r's context is already done or finishes first.
func waitForRetry(r *http.Request, d time.Duration) error {
	ctx := r.Context()
	if err := ctx.Err(); err != nil {
		return err
	}

	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// shouldRetry reports whether the outcome of a round trip warrants another
// attempt. It returns true when resp carries one of the status codes listed
// below, when err is or wraps a net.Error, or when err is or wraps
// io.ErrUnexpectedEOF. Every other outcome returns false.
func shouldRetry(resp *http.Response, err error) bool {
	statusCodes := []int{
		http.StatusForbidden,
		http.StatusRequestTimeout,
		http.StatusTooEarly,
		http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout,
	}

	if resp != nil && slices.Contains(statusCodes, resp.StatusCode) {
		return true
	}
	if err == nil {
		return false
	}

	// Errors may arrive wrapped (e.g. via fmt.Errorf("...: %w", err)), which a
	// direct type assertion would miss; errors.As walks the chain.
	var netErr net.Error
	if errors.As(err, &netErr) {
		return true
	}
	return errors.Is(err, io.ErrUnexpectedEOF)
}

// drainBody reads whatever remains of resp's body into io.Discard and then
// closes it, so the previous response is fully consumed before a retry. It is
// a no-op when the round trip returned an error or there is no body.
func drainBody(resp *http.Response, err error) {
	if err != nil || resp == nil || resp.Body == nil {
		return
	}

	body := resp.Body
	io.Copy(io.Discard, body)
	body.Close()
}

// retryAfter returns how long to wait before the next attempt. It honours the
// response's Retry-After header, which may hold either a non-negative number
// of seconds or a date. Besides the HTTP date format, several other date
// layouts are accepted for lenience. A date in the past yields zero.
//
// fallback is returned when resp is nil, the header is absent, or its value
// cannot be parsed. Unparseable values include signed numbers and second
// counts too large to fit in a time.Duration.
func retryAfter(resp *http.Response, fallback time.Duration) time.Duration {
	if resp == nil {
		return fallback
	}

	value := resp.Header.Get("Retry-After")
	if value == "" {
		return fallback
	}

	// delta-seconds is defined as 1*DIGIT (RFC 9110, section 10.2.3), so sign
	// prefixes are rejected. Counts beyond maxSeconds would overflow when
	// multiplied into a time.Duration, so they are treated as unparseable.
	if seconds, err := strconv.ParseUint(value, 10, 64); err == nil {
		const maxSeconds = uint64(time.Duration(1<<63-1) / time.Second)
		if seconds > maxSeconds {
			return fallback
		}
		return time.Duration(seconds) * time.Second
	}

	formats := []string{
		time.RFC1123, // HTTP Spec
		time.RFC1123Z,
		time.ANSIC,
		time.UnixDate,
		time.RubyDate,
		time.RFC822,
		time.RFC822Z,
		time.RFC850,
		time.RFC3339,
	}
	for _, format := range formats {
		if t, err := time.Parse(format, value); err == nil {
			// A date that has already passed means "retry now".
			return max(time.Until(t), 0)
		}
	}

	return fallback
}

// defaultRetryDelays is the schedule Provision installs when RetryDelays is
// left nil: up to four retries, waiting 1s, 2s, 5s and then 10s.
var defaultRetryDelays = []time.Duration{time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second}

// Compile-time checks that *Module satisfies flyscrape's TransportAdapter and
// Provisioner interfaces.
var _ flyscrape.TransportAdapter = (*Module)(nil)
var _ flyscrape.Provisioner = (*Module)(nil)