summaryrefslogtreecommitdiff
path: root/modules/cache/cache.go
blob: 716450672bb43b59518b38a4804c3cffcb3c772c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package cache

import (
	"bufio"
	"bytes"
	"net/http"
	"net/http/httputil"
	"path/filepath"
	"strings"

	"github.com/philippta/flyscrape"
)

func init() {
	flyscrape.RegisterModule(Module{})
}

// Module is the flyscrape HTTP response cache. Successful responses are
// stored in a Store keyed by request method and URL, and are replayed for
// later requests with the same key.
type Module struct {
	// Cache selects the backing store: "file" or "file:<path>" (see Provision).
	Cache string `json:"cache"`

	// store is the provisioned backend; nil means caching is disabled.
	store Store
}

// ModuleInfo identifies this module to flyscrape as "cache" and supplies a
// constructor that returns a fresh, unconfigured *Module.
func (Module) ModuleInfo() flyscrape.ModuleInfo {
	newModule := func() flyscrape.Module {
		return &Module{}
	}
	return flyscrape.ModuleInfo{ID: "cache", New: newModule}
}

func (m *Module) Provision(ctx flyscrape.Context) {
	switch {
	case m.Cache == "file":
		file := replaceExt(ctx.ScriptName(), ".cache")
		m.store = NewBoltStore(file)

	case strings.HasPrefix(m.Cache, "file:"):
		m.store = NewBoltStore(strings.TrimPrefix(m.Cache, "file:"))
	}
}

// AdaptTransport wraps t with a read-through response cache.
//
// If no store was provisioned, t is returned unchanged. Otherwise every
// request is keyed by "<METHOD> <URL>":
//   - requests marked with flyscrape.HeaderBypassCache go straight to t;
//   - a stored entry is decoded and returned without calling t;
//   - on a miss the request is sent through t, and any 2xx response is
//     serialized (headers and body) and stored under the key.
//
// NOTE(review): the key ignores the request body and headers, so e.g. two
// POSTs to the same URL with different bodies share one cache entry.
func (m *Module) AdaptTransport(t http.RoundTripper) http.RoundTripper {
	if m.store == nil {
		return t
	}
	return flyscrape.RoundTripFunc(func(r *http.Request) (*http.Response, error) {
		if nocache(r) {
			return t.RoundTrip(r)
		}

		key := r.Method + " " + r.URL.String()
		if b, ok := m.store.Get(key); ok {
			// An entry that fails to parse is treated as a cache miss.
			if resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(b)), r); err == nil {
				return resp, nil
			}
		}

		resp, err := t.RoundTrip(r)
		if err != nil {
			return resp, err
		}

		// Avoid caching when running into rate limits or
		// when the page errored: only 2xx responses are stored.
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			return resp, nil
		}

		// DumpResponse reads the whole body; on success it leaves resp.Body
		// replaced by an in-memory copy, so the caller can still read it.
		encoded, err := httputil.DumpResponse(resp, true)
		if err != nil {
			// The body may already be partially consumed, so resp is no longer
			// usable. A RoundTripper must not return a response together with a
			// non-nil error, and a caller honouring that contract would never
			// close this body, so release it here and report only the error.
			if resp.Body != nil {
				resp.Body.Close()
			}
			return nil, err
		}

		m.store.Set(key, encoded)
		return resp, nil
	})
}

// Finalize closes the backing store if it exposes a Close method.
//
// Both common shapes are accepted: a plain Close() and the io.Closer-style
// Close() error. A nil store matches neither case and is skipped.
func (m *Module) Finalize() {
	switch s := m.store.(type) {
	case interface{ Close() }:
		s.Close()
	case interface{ Close() error }:
		// Finalize has no way to report failures; closing is best-effort.
		_ = s.Close()
	}
}

// nocache reports whether r carries a non-empty flyscrape.HeaderBypassCache
// header. When it does, the header is deleted from r before returning true,
// so the marker is not sent on with the request.
//
// NOTE(review): this mutates the caller's request headers in place, whereas
// the http.RoundTripper contract asks implementations not to modify the
// request. Consider cloning the request before stripping the header.
func nocache(r *http.Request) bool {
	hdr := r.Header
	if hdr.Get(flyscrape.HeaderBypassCache) == "" {
		return false
	}
	hdr.Del(flyscrape.HeaderBypassCache)
	return true
}

// replaceExt swaps the extension of filePath (as reported by filepath.Ext)
// for newExt, or appends newExt when filePath has no extension.
// For example, replaceExt("scrape.js", ".cache") returns "scrape.cache".
func replaceExt(filePath string, newExt string) string {
	// filepath.Ext is always a suffix of filePath (possibly empty), so
	// trimming it yields the path without its extension.
	stem := strings.TrimSuffix(filePath, filepath.Ext(filePath))
	return stem + newExt
}

// Store is the byte-oriented key/value backend that holds serialized
// HTTP responses for the cache module.
type Store interface {
	// Get returns the value saved under key and whether one was found.
	Get(key string) (value []byte, found bool)

	// Set saves value under key.
	Set(key string, value []byte)
}