-
Notifications
You must be signed in to change notification settings - Fork 1
/
timing_wheel.go
128 lines (112 loc) · 2.76 KB
/
timing_wheel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package timing_wheel
import (
"time"
"errors"
"sync"
"sync/atomic"
"unsafe"
)
// TimingWheel is one level of a hierarchical timing wheel. The root wheel
// (level 0) is driven by the ticker started in Run and triggers due timers.
// Coarser levels are created on demand by addTimer. Each coarser level is
// advanced by the next finer wheel whenever that wheel's slot index wraps to 0,
// and at that point it moves its timers back down into the finer wheel.
type TimingWheel struct {
	// NOTE(review): not referenced by any code in this file; tick (below) is the
	// ticker actually used. Candidate for removal once no other file in the
	// package is confirmed to use it.
	ticker time.Ticker

	// Milliseconds covered by one slot at this level.
	tickMs int

	// Number of slots in this wheel.
	wheelSize int

	// The wheel's slots, indexed 0..wheelSize-1.
	slots []*slot

	// Index of the slot that the next call to advance processes.
	// NOTE(review): advance writes this from the Run goroutine while
	// AfterFunc -> addTimer -> _addTimer reads it from the caller's goroutine,
	// with no lock in this file. That is a data race unless synchronization
	// exists elsewhere; also confirm whether slot.add is internally locked.
	slotIndex int

	// Ticker created by Run; only set on the wheel that Run was called on.
	tick *time.Ticker

	// Next finer wheel (nil for level 0); coarser levels cascade timers into it.
	prevWheel *TimingWheel

	// Next coarser wheel, a *TimingWheel stored as unsafe.Pointer. It is created
	// lazily and always read and written through sync/atomic.
	nextWheel unsafe.Pointer

	// Depth in the hierarchy; 0 for the root wheel.
	level int

	// Closed by Stop to make the Run goroutine exit.
	exit chan struct{}

	// Tracks the Run goroutine so Stop can wait for it. Because it is embedded,
	// Add/Done/Wait are also exported as methods of TimingWheel.
	sync.WaitGroup
}
// NewTimingWheel builds a root (level 0) timing wheel with wheelSize slots.
// Each slot spans tick, truncated to whole milliseconds.
// It returns an error when wheelSize is not positive or when tick truncates to
// zero or fewer milliseconds.
func NewTimingWheel(tick time.Duration, wheelSize int) (*TimingWheel, error) {
	switch tickMs := int(tick / time.Millisecond); {
	case wheelSize <= 0:
		return nil, errors.New("wheel size should > 0")
	case tickMs <= 0:
		return nil, errors.New("tick should > 0")
	default:
		return newTimingWheel(tickMs, wheelSize, 0, nil), nil
	}
}
// newTimingWheel allocates a wheel at the given level with wheelSize fresh
// slots, each covering tick milliseconds. prev is the next finer wheel (nil for
// the root). The coarser wheel is left unset and is created on demand.
func newTimingWheel(tick, wheelSize, level int, prev *TimingWheel) *TimingWheel {
	tw := &TimingWheel{
		tickMs:    tick,
		wheelSize: wheelSize,
		slots:     make([]*slot, wheelSize),
		prevWheel: prev,
		level:     level,
		exit:      make(chan struct{}),
	}
	for i := range tw.slots {
		tw.slots[i] = newSlot()
	}
	return tw
}
// Run starts the goroutine that drives this wheel. Every tickMs milliseconds it
// calls advance, which moves the wheel forward one slot and may in turn advance
// coarser wheels. The goroutine runs until Stop closes tw.exit. It is registered
// on the embedded WaitGroup so Stop can wait for it to finish.
func (tw *TimingWheel) Run() {
	ticker := time.NewTicker(time.Millisecond * time.Duration(tw.tickMs))
	tw.tick = ticker
	tw.Add(1)
	go func() {
		defer tw.Done()
		// Release the ticker once the loop ends so it does not outlive the wheel.
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				tw.advance()
			case <-tw.exit:
				return
			}
		}
	}()
}
// advance processes the slot at the current index and moves the index forward
// by one, modulo wheelSize.
//   - On the root wheel (level 0) the slot's timers are triggered.
//   - On a coarser wheel they are taken out of the slot and re-inserted into the
//     finer wheel, using the time remaining until each timer's expiredTime.
//
// When the index wraps back to 0, the next coarser wheel (if one exists) is
// advanced as well.
func (tw *TimingWheel) advance() {
	cur := tw.slotIndex
	// Move the index before processing the slot. Anything inserted while the
	// slot is processed is then placed relative to the next slot, not into the
	// slot being emptied. That includes cascaded timers pushed back up to this
	// wheel, and timers added from a callback if slot.trigger runs callbacks
	// synchronously (TODO confirm).
	tw.slotIndex = (cur + 1) % tw.wheelSize

	s := tw.slots[cur]
	if tw.level == 0 {
		s.trigger()
	} else {
		pending := s.getClear()
		nowMs := time.Now().UnixNano() / int64(time.Millisecond)
		for e := pending.Front(); e != nil; e = e.Next() {
			t := e.Value.(*timer)
			// Use addTimer, not _addTimer. A remaining duration that does not fit
			// the finer wheel is then pushed back to a coarser wheel instead of
			// wrapping modulo wheelSize and firing early.
			tw.prevWheel.addTimer(int(t.expiredTime-nowMs), t)
		}
	}

	if tw.slotIndex == 0 {
		if next := atomic.LoadPointer(&tw.nextWheel); next != nil {
			(*TimingWheel)(next).advance()
		}
	}
}
// _addTimer places t into the slot that lies duration/tickMs slots after the
// current index on this wheel. A non-positive duration selects the current
// slot. It does not check that the duration fits the wheel: a duration of
// wheelSize*tickMs or more wraps around modulo wheelSize.
func (tw *TimingWheel) _addTimer(duration int, t *timer) {
	offset := 0
	if duration > 0 {
		offset = duration / tw.tickMs
	}
	tw.slots[(tw.slotIndex+offset)%tw.wheelSize].add(t)
}
// addTimer inserts t, due in duration milliseconds, into the wheel hierarchy
// starting at tw.
//   - If duration is less than this wheel's span (wheelSize*tickMs), the timer
//     goes into one of tw's slots.
//   - Otherwise it is handed to the next coarser wheel, created on first use,
//     with the span subtracted.
func (tw *TimingWheel) addTimer(duration int, t *timer) {
	span := tw.wheelSize * tw.tickMs
	// A duration equal to span must overflow too. In _addTimer it would map to
	// index (span/tickMs + slotIndex) % wheelSize == slotIndex, which is the slot
	// processed on this wheel's next advance, instead of one full rotation later.
	if duration < span {
		tw._addTimer(duration, t)
		return
	}
	next := atomic.LoadPointer(&tw.nextWheel)
	if next == nil {
		created := newTimingWheel(span, tw.wheelSize, tw.level+1, tw)
		// Another goroutine may install its wheel first; whichever wins the CAS
		// is the one used, so reload after the attempt.
		atomic.CompareAndSwapPointer(&tw.nextWheel, nil, unsafe.Pointer(created))
		next = atomic.LoadPointer(&tw.nextWheel)
	}
	(*TimingWheel)(next).addTimer(duration-span, t)
}
// AfterFunc schedules callback to be triggered after duration has elapsed.
// The duration is truncated to whole milliseconds. The timer also records its
// absolute expiry in Unix milliseconds, so a coarser wheel can recompute the
// remaining time when it moves the timer to a finer wheel.
func (tw *TimingWheel) AfterFunc(duration time.Duration, callback func()) {
	delayMs := int(duration / time.Millisecond)
	nowMs := time.Now().UnixNano() / int64(time.Millisecond)
	tw.addTimer(delayMs, &timer{nowMs + int64(delayMs), callback})
}
// Stop signals the Run goroutine to exit and waits for it to finish. Calling
// Stop again after a previous call has returned is a no-op; without the guard
// it would panic when closing tw.exit a second time.
// NOTE(review): concurrent calls to Stop can still both reach close and panic.
// Callers must serialize them, or a sync.Once field must be added to the struct.
func (tw *TimingWheel) Stop() {
	select {
	case <-tw.exit:
		// Already stopped.
	default:
		close(tw.exit)
	}
	tw.Wait()
}