📦 go-chi / httprate

📄 local_counter.go · 99 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99package httprate

import (
	"sync"
	"time"

	"github.com/zeebo/xxh3"
)

// NewLocalLimitCounter returns an in-memory LimitCounter that keeps per-key
// counts for two consecutive windows of length windowLength: the window that
// contains "now" (in UTC, truncated to windowLength) and the one before it.
//
// Every method of the returned counter always returns a nil error.
func NewLocalLimitCounter(windowLength time.Duration) *localCounter {
	now := time.Now().UTC()

	c := &localCounter{windowLength: windowLength}
	c.latestWindow = now.Truncate(windowLength)
	c.latestCounters = make(map[uint64]int)
	c.previousCounters = make(map[uint64]int)

	return c
}

// Compile-time check that *localCounter implements LimitCounter.
var _ LimitCounter = (*localCounter)(nil)

// localCounter is an in-memory LimitCounter that tracks per-key counts for
// two consecutive fixed windows. Keys are stored as 64-bit hashes produced by
// limitCounterKey.
type localCounter struct {
	// windowLength is the duration of one counting window.
	windowLength time.Duration
	// latestWindow is the start of the window whose counts are held in
	// latestCounters; previousCounters holds the window immediately before it.
	latestWindow time.Time
	// Per-key request counts, indexed by limitCounterKey(key).
	latestCounters, previousCounters map[uint64]int
	// mu guards the window bookkeeping and both counter maps.
	mu sync.RWMutex
}

// IncrementBy adds amount to the count for key in the window starting at
// currentWindow. It always returns nil.
//
// A currentWindow newer than the tracked window advances the counter first
// (see evict). A currentWindow OLDER than the tracked window must not rewind
// the counter. Rewinding would clear both maps and discard every key's
// counts. That can happen when a caller computed its window just before a
// boundary but acquired the lock after a request from the next window. Such
// a late increment is credited to the previous window when it matches it.
// Otherwise it is dropped, because only the latest and previous windows are
// ever stored.
func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	hkey := limitCounterKey(key)

	if currentWindow.Before(c.latestWindow) {
		if currentWindow.Equal(c.latestWindow.Add(-c.windowLength)) {
			c.previousCounters[hkey] += amount
		}
		return nil
	}

	c.evict(currentWindow)
	c.latestCounters[hkey] += amount

	return nil
}

// Get returns the counts recorded for key in currentWindow and in
// previousWindow, in that order. The error is always nil.
//
// The counter only remembers the tracked window (latestWindow) and the one
// before it:
//   - if the tracked window is currentWindow, both counts are returned;
//   - if the tracked window is previousWindow (nothing has advanced the
//     counter into currentWindow yet), its counts are reported as the
//     previous window's and the current count is 0;
//   - any other combination yields 0, 0.
func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
	// Hash once, outside the critical section.
	hkey := limitCounterKey(key)

	c.mu.RLock()
	defer c.mu.RUnlock()

	switch {
	case c.latestWindow.Equal(currentWindow):
		return c.latestCounters[hkey], c.previousCounters[hkey], nil
	case c.latestWindow.Equal(previousWindow):
		return 0, c.latestCounters[hkey], nil
	default:
		return 0, 0, nil
	}
}

// Config sets the window length and re-anchors the tracked window to the
// start of the current window (UTC, truncated to windowLength). requestLimit
// is not used by the in-memory counter.
//
// The write lock is required because IncrementBy and Get read these fields
// under c.mu. Writing them unlocked is a data race.
//
// NOTE(review): existing counts are kept and become attributed to the new
// latestWindow. Confirm Config is only called before the counter is used.
func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.windowLength = windowLength
	c.latestWindow = time.Now().UTC().Truncate(windowLength)
}

// Increment records a single hit for key in the window starting at
// currentWindow; it is shorthand for IncrementBy with an amount of 1.
func (c *localCounter) Increment(key string, currentWindow time.Time) error {
	const single = 1
	err := c.IncrementBy(key, currentWindow, single)
	return err
}

// evict makes currentWindow the tracked window, discarding counts that can
// no longer be queried. Callers must hold c.mu for writing.
//
// Cases:
//   - currentWindow is the tracked window, or older than it: nothing changes.
//     Letting a late caller rewind the counter would clear every key's counts.
//   - currentWindow directly follows the tracked window: the latest counts
//     become the previous counts and the latest map starts empty.
//   - otherwise (a gap of more than one window): both maps are emptied.
func (c *localCounter) evict(currentWindow time.Time) {
	if !currentWindow.After(c.latestWindow) {
		return
	}

	adjacent := c.latestWindow.Equal(currentWindow.Add(-c.windowLength))
	c.latestWindow = currentWindow

	if adjacent {
		// Shift the windows without map re-allocation.
		clear(c.previousCounters)
		c.latestCounters, c.previousCounters = c.previousCounters, c.latestCounters
		return
	}

	clear(c.previousCounters)
	clear(c.latestCounters)
}

// limitCounterKey maps a rate-limit key to the 64-bit XXH3 hash used to index
// the counter maps, so map keys stay fixed-size regardless of key length.
//
// The one-shot xxh3.HashString avoids constructing a streaming Hasher on every
// call. The hash only has to be consistent within this process, because the
// counters live in memory and every lookup goes through this function.
func limitCounterKey(key string) uint64 {
	return xxh3.HashString(key)
}