From e75f9ff8ec4c7d810f1eb0e6b9f127d864a9b2f7 Mon Sep 17 00:00:00 2001 From: v-byte-cpu <65545655+v-byte-cpu@users.noreply.github.com> Date: Sat, 26 Jun 2021 20:58:43 +0300 Subject: [PATCH] feature: randomized port iterator (#90) --- pkg/ip/ip.go | 15 -- pkg/ip/ip_test.go | 50 ------ pkg/scan/mock_request_test.go | 74 ++------- pkg/scan/request.go | 57 +++++-- pkg/scan/request_test.go | 301 +++++++++++++++++++++------------- 5 files changed, 248 insertions(+), 249 deletions(-) diff --git a/pkg/ip/ip.go b/pkg/ip/ip.go index 1880a1f..09766d2 100644 --- a/pkg/ip/ip.go +++ b/pkg/ip/ip.go @@ -8,21 +8,6 @@ import ( var ErrInvalidAddr = errors.New("invalid IP subnet/host") -func Inc(ip net.IP) { - for j := len(ip) - 1; j >= 0; j-- { - ip[j]++ - if ip[j] > 0 { - break - } - } -} - -func DupIP(ip net.IP) net.IP { - dup := make([]byte, 4) - copy(dup, ip.To4()) - return dup -} - func ParseIPNet(subnet string) (*net.IPNet, error) { _, result, err := net.ParseCIDR(subnet) if err == nil { diff --git a/pkg/ip/ip_test.go b/pkg/ip/ip_test.go index 4e6a95a..66bf3d8 100644 --- a/pkg/ip/ip_test.go +++ b/pkg/ip/ip_test.go @@ -7,56 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestInc(t *testing.T) { - t.Parallel() - tests := []struct { - name string - input net.IP - expected net.IP - }{ - { - name: "ZeroNet", - input: net.IPv4(0, 0, 0, 0), - expected: net.IPv4(0, 0, 0, 1), - }, - { - name: "Inc3rd", - input: net.IPv4(1, 1, 0, 255), - expected: net.IPv4(1, 1, 1, 0), - }, - { - name: "Inc2nd", - input: net.IPv4(1, 1, 255, 255), - expected: net.IPv4(1, 2, 0, 0), - }, - { - name: "Inc1st", - input: net.IPv4(1, 255, 255, 255), - expected: net.IPv4(2, 0, 0, 0), - }, - } - - for _, vtt := range tests { - tt := vtt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - Inc(tt.input) - assert.Equal(t, tt.expected, tt.input) - }) - } -} - -func TestDupIP(t *testing.T) { - t.Parallel() - ipAddr := net.IPv4(192, 168, 0, 1).To4() - - dupAddr := DupIP(ipAddr) - assert.Equal(t, ipAddr, dupAddr) - - dupAddr[3]++ - assert.Equal(t, net.IPv4(192, 168, 0, 1).To4(), ipAddr) -} - func TestParseIPNetWithError(t *testing.T) { t.Parallel() _, err := ParseIPNet("") diff --git a/pkg/scan/mock_request_test.go b/pkg/scan/mock_request_test.go index 6fd05cb..a27922e 100644 --- a/pkg/scan/mock_request_test.go +++ b/pkg/scan/mock_request_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: request.go +// Source: github.com/v-byte-cpu/sx/pkg/scan (interfaces: PortGenerator,IPGenerator,RequestGenerator,IPContainer) // Package scan is a generated GoMock package. package scan @@ -36,56 +36,18 @@ func (m *MockPortGenerator) EXPECT() *MockPortGeneratorMockRecorder { } // Ports mocks base method. -func (m *MockPortGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error) { +func (m *MockPortGenerator) Ports(arg0 context.Context, arg1 *Range) (<-chan PortGetter, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Ports", ctx, r) - ret0, _ := ret[0].(<-chan uint16) + ret := m.ctrl.Call(m, "Ports", arg0, arg1) + ret0, _ := ret[0].(<-chan PortGetter) ret1, _ := ret[1].(error) return ret0, ret1 } // Ports indicates an expected call of Ports. -func (mr *MockPortGeneratorMockRecorder) Ports(ctx, r interface{}) *gomock.Call { +func (mr *MockPortGeneratorMockRecorder) Ports(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ports", reflect.TypeOf((*MockPortGenerator)(nil).Ports), ctx, r) -} - -// MockIPGetter is a mock of IPGetter interface. -type MockIPGetter struct { - ctrl *gomock.Controller - recorder *MockIPGetterMockRecorder -} - -// MockIPGetterMockRecorder is the mock recorder for MockIPGetter. -type MockIPGetterMockRecorder struct { - mock *MockIPGetter -} - -// NewMockIPGetter creates a new mock instance. -func NewMockIPGetter(ctrl *gomock.Controller) *MockIPGetter { - mock := &MockIPGetter{ctrl: ctrl} - mock.recorder = &MockIPGetterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockIPGetter) EXPECT() *MockIPGetterMockRecorder { - return m.recorder -} - -// GetIP mocks base method. -func (m *MockIPGetter) GetIP() (net.IP, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetIP") - ret0, _ := ret[0].(net.IP) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetIP indicates an expected call of GetIP. -func (mr *MockIPGetterMockRecorder) GetIP() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIP", reflect.TypeOf((*MockIPGetter)(nil).GetIP)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ports", reflect.TypeOf((*MockPortGenerator)(nil).Ports), arg0, arg1) } // MockIPGenerator is a mock of IPGenerator interface. @@ -112,18 +74,18 @@ func (m *MockIPGenerator) EXPECT() *MockIPGeneratorMockRecorder { } // IPs mocks base method. -func (m *MockIPGenerator) IPs(ctx context.Context, r *Range) (<-chan IPGetter, error) { +func (m *MockIPGenerator) IPs(arg0 context.Context, arg1 *Range) (<-chan IPGetter, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IPs", ctx, r) + ret := m.ctrl.Call(m, "IPs", arg0, arg1) ret0, _ := ret[0].(<-chan IPGetter) ret1, _ := ret[1].(error) return ret0, ret1 } // IPs indicates an expected call of IPs. -func (mr *MockIPGeneratorMockRecorder) IPs(ctx, r interface{}) *gomock.Call { +func (mr *MockIPGeneratorMockRecorder) IPs(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IPs", reflect.TypeOf((*MockIPGenerator)(nil).IPs), ctx, r) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IPs", reflect.TypeOf((*MockIPGenerator)(nil).IPs), arg0, arg1) } // MockRequestGenerator is a mock of RequestGenerator interface. @@ -150,18 +112,18 @@ func (m *MockRequestGenerator) EXPECT() *MockRequestGeneratorMockRecorder { } // GenerateRequests mocks base method. -func (m *MockRequestGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { +func (m *MockRequestGenerator) GenerateRequests(arg0 context.Context, arg1 *Range) (<-chan *Request, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GenerateRequests", ctx, r) + ret := m.ctrl.Call(m, "GenerateRequests", arg0, arg1) ret0, _ := ret[0].(<-chan *Request) ret1, _ := ret[1].(error) return ret0, ret1 } // GenerateRequests indicates an expected call of GenerateRequests. -func (mr *MockRequestGeneratorMockRecorder) GenerateRequests(ctx, r interface{}) *gomock.Call { +func (mr *MockRequestGeneratorMockRecorder) GenerateRequests(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateRequests", reflect.TypeOf((*MockRequestGenerator)(nil).GenerateRequests), ctx, r) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateRequests", reflect.TypeOf((*MockRequestGenerator)(nil).GenerateRequests), arg0, arg1) } // MockIPContainer is a mock of IPContainer interface. @@ -188,16 +150,16 @@ func (m *MockIPContainer) EXPECT() *MockIPContainerMockRecorder { } // Contains mocks base method. -func (m *MockIPContainer) Contains(ip net.IP) (bool, error) { +func (m *MockIPContainer) Contains(arg0 net.IP) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Contains", ip) + ret := m.ctrl.Call(m, "Contains", arg0) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // Contains indicates an expected call of Contains. -func (mr *MockIPContainerMockRecorder) Contains(ip interface{}) *gomock.Call { +func (mr *MockIPContainerMockRecorder) Contains(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockIPContainer)(nil).Contains), ip) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockIPContainer)(nil).Contains), arg0) } diff --git a/pkg/scan/request.go b/pkg/scan/request.go index d671834..5d4301c 100644 --- a/pkg/scan/request.go +++ b/pkg/scan/request.go @@ -1,4 +1,4 @@ -//go:generate mockgen -package scan -destination=mock_request_test.go -source request.go +//go:generate mockgen -package scan -destination=mock_request_test.go . PortGenerator,IPGenerator,RequestGenerator,IPContainer //go:generate easyjson -output_filename request_easyjson.go request.go package scan @@ -31,30 +31,52 @@ type Request struct { Err error } +type PortGetter interface { + GetPort() (uint16, error) +} + +type WrapPort uint16 + +func (p WrapPort) GetPort() (uint16, error) { + return uint16(p), nil +} + +type portError struct { + error +} + +func (err *portError) GetPort() (uint16, error) { + return 0, err +} + type PortGenerator interface { - Ports(ctx context.Context, r *Range) (<-chan uint16, error) + Ports(ctx context.Context, r *Range) (<-chan PortGetter, error) } func NewPortGenerator() PortGenerator { return &portGenerator{} } -// TODO randomizedPortGenerator type portGenerator struct{} -func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error) { +func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan PortGetter, error) { if err := validatePorts(r.Ports); err != nil { return nil, err } - out := make(chan uint16, 100) + out := make(chan PortGetter, 100) go func() { defer close(out) for _, portRange := range r.Ports { - for port := int(portRange.StartPort); port <= int(portRange.EndPort); port++ { - select { - case <-ctx.Done(): - return - case out <- uint16(port): + it, err := newRangeIterator(int64(portRange.EndPort) - int64(portRange.StartPort) + 1) + if err != nil { + writePort(ctx, out, &portError{err}) + continue + } + basePort := int64(portRange.StartPort) - 1 + for { + writePort(ctx, out, WrapPort(basePort+it.Int().Int64())) + if !it.Next() { + break } } } @@ -62,6 +84,14 @@ func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error return out, nil } +func writePort(ctx context.Context, out chan<- PortGetter, port PortGetter) { + select { + case <-ctx.Done(): + return + case out <- port: + } +} + func validatePorts(ports []*PortRange) error { if len(ports) == 0 { return ErrPortRange @@ -153,7 +183,12 @@ func (rg *ipPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-ch out := make(chan *Request, 100) go func() { defer close(out) - for port := range ports { + for p := range ports { + port, err := p.GetPort() + if err != nil { + writeRequest(ctx, out, &Request{Err: err}) + continue + } for ipaddr := range ips { dstip, err := ipaddr.GetIP() writeRequest(ctx, out, &Request{ diff --git a/pkg/scan/request_test.go b/pkg/scan/request_test.go index 9d1bd73..cd32b95 100644 --- a/pkg/scan/request_test.go +++ b/pkg/scan/request_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "io/ioutil" + "math/big" "net" "sort" "strings" @@ -82,7 +83,7 @@ func withError(err error) scanRequestOption { } } -func chanPortToGeneric(in <-chan uint16) <-chan interface{} { +func chanPortToGeneric(in <-chan PortGetter) <-chan interface{} { out := make(chan interface{}, cap(in)) go func() { defer close(out) @@ -139,7 +140,7 @@ func TestPortGenerator(t *testing.T) { EndPort: 22, }, })), - expected: []interface{}{uint16(22)}, + expected: []interface{}{WrapPort(22)}, }, { name: "TwoPorts", @@ -149,7 +150,7 @@ func TestPortGenerator(t *testing.T) { EndPort: 23, }, })), - expected: []interface{}{uint16(22), uint16(23)}, + expected: []interface{}{WrapPort(22), WrapPort(23)}, }, { name: "ThreePorts", @@ -159,7 +160,7 @@ func TestPortGenerator(t *testing.T) { EndPort: 27, }, })), - expected: []interface{}{uint16(25), uint16(26), uint16(27)}, + expected: []interface{}{WrapPort(25), WrapPort(26), WrapPort(27)}, }, { name: "OnePortOverflow", @@ -169,7 +170,7 @@ func TestPortGenerator(t *testing.T) { EndPort: 65535, }, })), - expected: []interface{}{uint16(65535)}, + expected: []interface{}{WrapPort(65535)}, }, { name: "TwoRangesOnePort", @@ -183,21 +184,32 @@ func TestPortGenerator(t *testing.T) { EndPort: 27, }, })), - expected: []interface{}{uint16(25), uint16(27)}, + expected: []interface{}{WrapPort(25), WrapPort(27)}, }, { name: "TwoRangesTwoPorts", scanRange: newScanRange(withPorts([]*PortRange{ { - StartPort: 21, - EndPort: 22, + StartPort: 20, + EndPort: 21, }, { - StartPort: 26, + StartPort: 23, EndPort: 27, }, })), - expected: []interface{}{uint16(21), uint16(22), uint16(26), uint16(27)}, + expected: []interface{}{WrapPort(20), WrapPort(21), WrapPort(23), + WrapPort(24), WrapPort(25), WrapPort(26), WrapPort(27)}, + }, + { + name: "ZeroPort", + scanRange: newScanRange(withPorts([]*PortRange{ + { + StartPort: 0, + EndPort: 1, + }, + })), + expected: []interface{}{WrapPort(0), WrapPort(1)}, }, } @@ -217,6 +229,9 @@ func TestPortGenerator(t *testing.T) { } require.NoError(t, err) result := chanToSlice(t, chanPortToGeneric(ports), len(tt.expected)) + sort.Slice(result, func(i, j int) bool { + return uint16(result[i].(WrapPort)) < uint16(result[j].(WrapPort)) + }) require.Equal(t, tt.expected, result) }() waitDone(t, done) @@ -224,6 +239,41 @@ func TestPortGenerator(t *testing.T) { } } +func TestPortGeneratorFullRange(t *testing.T) { + t.Parallel() + done := make(chan interface{}) + go func() { + defer close(done) + portgen := NewPortGenerator() + ports, err := portgen.Ports(context.Background(), newScanRange(withPorts([]*PortRange{ + { + StartPort: 1, + EndPort: 65535, + }, + }))) + require.NoError(t, err) + + bitset := big.NewInt(0) + cnt := 0 + for p := range ports { + cnt++ + port, err := p.GetPort() + require.NoError(t, err) + i := int(port) + if bitset.Bit(i) == 1 { + require.Fail(t, "number has already been visited", "number %d", i) + } + bitset.SetBit(bitset, i, 1) + } + for i := 1; i <= 65535; i++ { + require.Equal(t, uint(1), bitset.Bit(i), + "number %d is not visited", i) + } + require.Equal(t, 65535, cnt, "count is not valid") + }() + waitDone(t, done) +} + func chanIPToGeneric(in <-chan IPGetter) <-chan interface{} { out := make(chan interface{}, cap(in)) go func() { @@ -324,123 +374,85 @@ func TestIPPortGenerator(t *testing.T) { tests := []struct { name string - input *Range + ips []IPGetter + ports []PortGetter expected []interface{} - err bool }{ { - name: "InvalidPortRange", - input: newScanRange( - withPorts([]*PortRange{ - { - StartPort: 5000, - EndPort: 2000, - }, - }), - ), - err: true, - }, - { - name: "NilSubnet", - input: newScanRange(withSubnet(nil)), - err: true, + name: "OneIpOnePort", + ips: []IPGetter{WrapIP(net.IPv4(192, 168, 0, 1))}, + ports: []PortGetter{WrapPort(888)}, + expected: []interface{}{ + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(888)), + }, }, { - name: "OneIpOnePort", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(32, 32)}), - withPorts([]*PortRange{ - { - StartPort: 888, - EndPort: 888, - }, - }), - ), + name: "OneIpTwoPorts", + ips: []IPGetter{WrapIP(net.IPv4(192, 168, 0, 1))}, + ports: []PortGetter{WrapPort(888), WrapPort(889)}, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(889)), }, }, { - name: "OneIpTwoPorts", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(32, 32)}), - withPorts([]*PortRange{ - { - StartPort: 888, - EndPort: 889, - }, - }), - ), + name: "ThreeIpsOnePort", + ips: []IPGetter{ + WrapIP(net.IPv4(192, 168, 0, 1)), + WrapIP(net.IPv4(192, 168, 0, 2)), + WrapIP(net.IPv4(192, 168, 0, 3)), + }, + ports: []PortGetter{WrapPort(888)}, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(889)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 2)), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 3)), withDstPort(888)), }, }, { - name: "TwoIpsOnePort", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(31, 32)}), - withPorts([]*PortRange{ - { - StartPort: 888, - EndPort: 888, - }, - }), - ), + name: "TwoIpsTwoPorts", + ips: []IPGetter{ + WrapIP(net.IPv4(192, 168, 0, 1)), + WrapIP(net.IPv4(192, 168, 0, 2)), + }, + ports: []PortGetter{WrapPort(888), WrapPort(889)}, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 2)), withDstPort(888)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(889)), + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 2)), withDstPort(889)), }, }, { - name: "FourIpsOnePort", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(30, 32)}), - withPorts([]*PortRange{ - { - StartPort: 888, - EndPort: 888, - }, - }), - ), + name: "IPError", + ips: []IPGetter{ + &ipError{errors.New("ip error")}, + }, + ports: []PortGetter{WrapPort(888)}, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 2).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 3).To4()), withDstPort(888)), + newScanRequest(withDstIP(nil), withDstPort(888), withError(&ipError{errors.New("ip error")})), }, }, { - name: "TwoIpsTwoPorts", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(31, 32)}), - withPorts([]*PortRange{ - { - StartPort: 888, - EndPort: 889, - }, - }), - ), + name: "PortError", + ips: []IPGetter{WrapIP(net.IPv4(192, 168, 0, 1))}, + ports: []PortGetter{ + &portError{errors.New("port error")}, + }, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(889)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(888)), - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(889)), + &Request{Err: &portError{errors.New("port error")}}, }, }, { - name: "OneIpPortOverflow", - input: newScanRange( - withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1).To4(), Mask: net.CIDRMask(32, 32)}), - withPorts([]*PortRange{ - { - StartPort: 65535, - EndPort: 65535, - }, - }), - ), + name: "ValidPortAfterPortError", + ips: []IPGetter{WrapIP(net.IPv4(192, 168, 0, 1))}, + ports: []PortGetter{ + &portError{errors.New("port error")}, + WrapPort(888), + }, expected: []interface{}{ - newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(65535)), + &Request{Err: &portError{errors.New("port error")}}, + newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1)), withDstPort(888)), }, }, } @@ -454,25 +466,34 @@ func TestIPPortGenerator(t *testing.T) { go func() { defer close(done) - reqgen := NewIPPortGenerator(NewIPGenerator(), NewPortGenerator()) - pairs, err := reqgen.GenerateRequests(context.Background(), tt.input) - if tt.err { - require.Error(t, err) - return + ctrl := gomock.NewController(t) + ipgen := NewMockIPGenerator(ctrl) + + ctx := context.Background() + scanRange := newScanRange() + ipgen.EXPECT().IPs(ctx, scanRange). + DoAndReturn(func(ctx context.Context, r *Range) (<-chan IPGetter, error) { + ips := make(chan IPGetter, len(tt.ips)) + for _, ip := range tt.ips { + ips <- ip + } + close(ips) + return ips, nil + }).AnyTimes() + + ports := make(chan PortGetter, len(tt.ports)) + for _, port := range tt.ports { + ports <- port } + close(ports) + + portgen := NewMockPortGenerator(ctrl) + portgen.EXPECT().Ports(ctx, scanRange).Return(ports, nil) + + reqgen := NewIPPortGenerator(ipgen, portgen) + pairs, err := reqgen.GenerateRequests(ctx, scanRange) require.NoError(t, err) result := chanToSlice(t, chanPairToGeneric(pairs), len(tt.expected)) - sort.Slice(result, func(i, j int) bool { - req1 := result[i].(*Request) - req2 := result[j].(*Request) - switch bytes.Compare([]byte(req1.DstIP), []byte(req2.DstIP)) { - case -1: - return true - case 1: - return false - } - return req1.DstPort < req2.DstPort - }) require.Equal(t, tt.expected, result) }() waitDone(t, done) @@ -480,6 +501,52 @@ func TestIPPortGenerator(t *testing.T) { } } +func TestIPPortGeneratorError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ipsError error + portsError error + }{ + { + name: "IPGeneratorError", + ipsError: errors.New("ipgen error"), + }, + { + name: "PortGeneratorError", + ipsError: errors.New("portgen error"), + }, + } + + for _, vtt := range tests { + tt := vtt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + ipgen := NewMockIPGenerator(ctrl) + + ctx := context.Background() + scanRange := newScanRange() + ipgen.EXPECT().IPs(ctx, scanRange).Return(nil, tt.ipsError).AnyTimes() + + portgen := NewMockPortGenerator(ctrl) + portgen.EXPECT().Ports(ctx, scanRange).Return(nil, tt.portsError).AnyTimes() + + reqgen := NewIPPortGenerator(ipgen, portgen) + _, err := reqgen.GenerateRequests(ctx, scanRange) + require.Error(t, err) + }() + waitDone(t, done) + }) + } +} + func TestIPRequestGenerator(t *testing.T) { t.Parallel()