Skip to content

Commit

Permalink
feat: add bandwidth limiter in Route
Browse files Browse the repository at this point in the history
  • Loading branch information
zakuwaki committed Jun 22, 2023
1 parent 1d1db62 commit 681bad9
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 10 deletions.
7 changes: 7 additions & 0 deletions adapter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,20 @@ func RouterFromContext(ctx context.Context) Router {
return metadata.(Router)
}

type LimiterInfo struct {
Global bool
Download uint64
Upload uint64
}

type Rule interface {
Service
Type() string
UpdateGeosite() error
Match(metadata *InboundContext) bool
Outbound() string
String() string
LimiterInfo() *LimiterInfo
}

type DNSRule interface {
Expand Down
77 changes: 77 additions & 0 deletions experimental/limiter/bandwidth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package limiter

import (
"strconv"
"strings"

E "github.com/sagernet/sing/common/exceptions"
)

const (
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
)

type Bandwidth struct {
s string // KB MB GB
i uint64 // bytes
}

func NewBandwidth(s string) (bw Bandwidth, err error) {
err = bw.Parse(s)
if err != nil {
return
}
return
}

func (bw *Bandwidth) Equal(other *Bandwidth) bool {
if bw == nil && other == nil {
return true
}
if bw != nil && other != nil {
return bw.i == other.i
}
return false
}

func (bw *Bandwidth) Bytes() uint64 {
return bw.i
}

func (bw *Bandwidth) String() string {
return bw.s
}

func (bw *Bandwidth) Parse(s string) (err error) {
s = strings.TrimSpace(s)
if s == "" {
return
}

var (
unit uint64
cstr string
)
switch {
case strings.HasSuffix(s, "KB"):
unit = KB
cstr = strings.TrimSuffix(s, "KB")
case strings.HasSuffix(s, "MB"):
unit = MB
cstr = strings.TrimSuffix(s, "MB")
case strings.HasSuffix(s, "GB"):
unit = GB
cstr = strings.TrimSuffix(s, "GB")
default:
return E.New("invalid bandwidth value: ", s)
}
cnt, err := strconv.ParseUint(cstr, 10, 64)
if err != nil {
return
}
bw.s = s
bw.i = cnt * unit
return
}
95 changes: 95 additions & 0 deletions experimental/limiter/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package limiter

import (
"context"
"net"
"sync"

"golang.org/x/time/rate"
)

var m sync.Map

type limiter struct {
downloadLimiter *rate.Limiter
uploadLimiter *rate.Limiter
}

func newLimiter(download, upload uint64) *limiter {
var downloadLimiter, uploadLimiter *rate.Limiter
if download > 0 {
downloadLimiter = rate.NewLimiter(rate.Limit(float64(download)), int(download))
}
if upload > 0 {
uploadLimiter = rate.NewLimiter(rate.Limit(float64(upload)), int(upload))
}
return &limiter{downloadLimiter: downloadLimiter, uploadLimiter: uploadLimiter}
}

type connWithLimiter struct {
net.Conn
limiter *limiter
ctx context.Context
}

func NewConnWithLimiter(ctx context.Context, conn net.Conn, key string, global bool, download, upload uint64) net.Conn {
var l *limiter
if !global {
l = newLimiter(download, upload)
} else {
if v, ok := m.Load(key); ok {
l = v.(*limiter)
} else {
l = newLimiter(download, upload)
m.Store(key, l)
}
}
return &connWithLimiter{Conn: conn, limiter: l, ctx: ctx}
}

func (conn *connWithLimiter) Read(p []byte) (n int, err error) {
if conn.limiter == nil || conn.limiter.downloadLimiter == nil {
return conn.Conn.Read(p)
}
b := conn.limiter.downloadLimiter.Burst()
if b < len(p) {
p = p[:b]
}
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.downloadLimiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}

func (conn *connWithLimiter) Write(p []byte) (n int, err error) {
if conn.limiter == nil || conn.limiter.uploadLimiter == nil {
return conn.Conn.Write(p)
}
var nn int
b := conn.limiter.uploadLimiter.Burst()
for {
end := len(p)
if end == 0 {
break
}
if b < len(p) {
end = b
}
err = conn.limiter.uploadLimiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.Conn.Write(p[:end])
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
14 changes: 11 additions & 3 deletions option/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (
)

type _Rule struct {
Type string `json:"type,omitempty"`
DefaultOptions DefaultRule `json:"-"`
LogicalOptions LogicalRule `json:"-"`
Type string `json:"type,omitempty"`
DefaultOptions DefaultRule `json:"-"`
LogicalOptions LogicalRule `json:"-"`
LimiterOptions *LimiterRule `json:"limiter,omitempty"`
}

type Rule _Rule
Expand Down Expand Up @@ -99,3 +100,10 @@ type LogicalRule struct {
func (r LogicalRule) IsValid() bool {
return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid)
}

type LimiterRule struct {
Enabled bool `json:"enabled"`
Global bool `json:"global"`
DownloadBandwidth string `json:"download_bandwidth,omitempty"`
UploadBandwidth string `json:"upload_bandwidth,omitempty"`
}
7 changes: 7 additions & 0 deletions route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/experimental/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/ntp"
"github.com/sagernet/sing-box/option"
Expand Down Expand Up @@ -688,6 +689,12 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
if !common.Contains(detour.Network(), N.NetworkTCP) {
return E.New("missing supported outbound, closing connection")
}
if matchedRule != nil {
if li := matchedRule.LimiterInfo(); li != nil {
key := matchedRule.String()

Check failure on line 694 in route/router.go

View workflow job for this annotation

GitHub Actions / Build

File is not `gci`-ed with --skip-generated -s standard -s prefix(github.com/sagernet/) -s default --custom-order (gci)
conn = limiter.NewConnWithLimiter(ctx, conn, key, li.Global, li.Download, li.Upload)
}
}
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule)
defer tracker.Leave()
Expand Down
2 changes: 1 addition & 1 deletion route/router_geo_resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (r *Router) LoadGeosite(code string) (adapter.Rule, error) {
if err != nil {
return nil, err
}
rule, err = NewDefaultRule(r, nil, geosite.Compile(items))
rule, err = NewDefaultRule(r, nil, geosite.Compile(items), nil)
if err != nil {
return nil, err
}
Expand Down
46 changes: 40 additions & 6 deletions route/rule_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,34 @@ package route
import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)

func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) {
func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (rule adapter.Rule, err error) {
var limiterInfo *adapter.LimiterInfo
if lo := options.LimiterOptions; lo != nil && lo.Enabled {
var download, upload limiter.Bandwidth
if len(lo.DownloadBandwidth) > 0 {
download, err = limiter.NewBandwidth(lo.DownloadBandwidth)
if err != nil {
return
}
}
if len(lo.UploadBandwidth) > 0 {
upload, err = limiter.NewBandwidth(lo.UploadBandwidth)
if err != nil {
return
}
}
limiterInfo = &adapter.LimiterInfo{
Global: lo.Global,
Download: download.Bytes(),
Upload: upload.Bytes()}

Check failure on line 31 in route/rule_default.go

View workflow job for this annotation

GitHub Actions / Build

File is not `gofumpt`-ed (gofumpt)
}

switch options.Type {
case "", C.RuleTypeDefault:
if !options.DefaultOptions.IsValid() {
Expand All @@ -17,15 +39,15 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul
if options.DefaultOptions.Outbound == "" {
return nil, E.New("missing outbound field")
}
return NewDefaultRule(router, logger, options.DefaultOptions)
return NewDefaultRule(router, logger, options.DefaultOptions, limiterInfo)
case C.RuleTypeLogical:
if !options.LogicalOptions.IsValid() {
return nil, E.New("missing conditions")
}
if options.LogicalOptions.Outbound == "" {
return nil, E.New("missing outbound field")
}
return NewLogicalRule(router, logger, options.LogicalOptions)
return NewLogicalRule(router, logger, options.LogicalOptions, limiterInfo)
default:
return nil, E.New("unknown rule type: ", options.Type)
}
Expand All @@ -35,19 +57,25 @@ var _ adapter.Rule = (*DefaultRule)(nil)

type DefaultRule struct {
abstractDefaultRule
limiterInfo *adapter.LimiterInfo
}

func (r *DefaultRule) LimiterInfo() *adapter.LimiterInfo {
return r.limiterInfo
}

type RuleItem interface {
Match(metadata *adapter.InboundContext) bool
String() string
}

func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule, limiterInfo *adapter.LimiterInfo) (*DefaultRule, error) {
rule := &DefaultRule{
abstractDefaultRule{
invert: options.Invert,
outbound: options.Outbound,
},
limiterInfo,
}
if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound)
Expand Down Expand Up @@ -191,15 +219,21 @@ var _ adapter.Rule = (*LogicalRule)(nil)

type LogicalRule struct {
abstractLogicalRule
limiterInfo *adapter.LimiterInfo
}

func (r *LogicalRule) LimiterInfo() *adapter.LimiterInfo {
return r.limiterInfo
}

func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule, limiterInfo *adapter.LimiterInfo) (*LogicalRule, error) {
r := &LogicalRule{
abstractLogicalRule{
rules: make([]adapter.Rule, len(options.Rules)),
invert: options.Invert,
outbound: options.Outbound,
},
limiterInfo,
}
switch options.Mode {
case C.LogicalTypeAnd:
Expand All @@ -210,7 +244,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule)
rule, err := NewDefaultRule(router, logger, subRule, nil)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
Expand Down
8 changes: 8 additions & 0 deletions route/rule_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ type DefaultDNSRule struct {
rewriteTTL *uint32
}

func (r *DefaultDNSRule) LimiterInfo() *adapter.LimiterInfo {
return nil
}

func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
rule := &DefaultDNSRule{
abstractDefaultRule: abstractDefaultRule{
Expand Down Expand Up @@ -199,6 +203,10 @@ type LogicalDNSRule struct {
rewriteTTL *uint32
}

func (r *LogicalDNSRule) LimiterInfo() *adapter.LimiterInfo {
return nil
}

func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
abstractLogicalRule: abstractLogicalRule{
Expand Down

0 comments on commit 681bad9

Please sign in to comment.