Skip to content

Commit

Permalink
feat: support websocket sink (#2985)
Browse files Browse the repository at this point in the history
Signed-off-by: Song Gao <[email protected]>
  • Loading branch information
Yisaer committed Jul 5, 2024
1 parent 135bc30 commit d237422
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 20 deletions.
2 changes: 1 addition & 1 deletion internal/binder/io/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func init() {
modules.RegisterSink("memory", func() api.Sink { return memory.GetSink() })
modules.RegisterSink("neuron", neuron.GetSink)
modules.RegisterSink("file", file.GetSink)
// modules.RegisterSink("websocket", func() api.Sink { return &websocket.WebSocketSink{} })
modules.RegisterSink("websocket", func() api.Sink { return websocket.GetSink() })

modules.RegisterLookupSource("memory", memory.GetLookupSource)
// modules.RegisterLookupSource("httppull", func() api.LookupSource { return http.GetLookUpSource() })
Expand Down
1 change: 1 addition & 0 deletions internal/io/http/httpserver/data_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

type GlobalServerManager struct {
sync.RWMutex
instanceID int
endpoint map[string]string
server *http.Server
router *mux.Router
Expand Down
6 changes: 4 additions & 2 deletions internal/io/http/httpserver/websocketConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

type WebsocketConnection struct {
RecvTopic string
SendTopic string
cfg *connectionCfg
}

Expand All @@ -46,12 +47,13 @@ func createWebsocketServerConnection(ctx api.StreamContext, props map[string]any
if err := cast.MapToStruct(props, cfg); err != nil {
return nil, err
}
recvTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource)
rTopic, sTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource)
if err != nil {
return nil, err
}
return &WebsocketConnection{
RecvTopic: recvTopic,
RecvTopic: rTopic,
SendTopic: sTopic,
cfg: cfg,
}, nil
}
77 changes: 67 additions & 10 deletions internal/io/http/httpserver/websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,20 @@ const (
WebsocketTopicPrefix = "$$websocket/"
)

func recvTopic(endpoint string) string {
return fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint)
}

func sendTopic(endpoint string) string {
return fmt.Sprintf("send/%s/%s", WebsocketTopicPrefix, endpoint)
}

type websocketEndpointContext struct {
wg *sync.WaitGroup
conns map[*websocket.Conn]context.CancelFunc
}

func RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, error) {
func RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, string, error) {
return manager.RegisterWebSocketEndpoint(ctx, endpoint)
}

Expand All @@ -49,14 +57,56 @@ func UnRegisterWebSocketEndpoint(endpoint string) {
}
}

func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string, c *websocket.Conn, wg *sync.WaitGroup) {
func (m *GlobalServerManager) handleProcess(ctx api.StreamContext, endpoint string, instanceID int, c *websocket.Conn, cancel context.CancelFunc, parWg *sync.WaitGroup) {
defer func() {
m.CloseEndpointConnection(endpoint, c)
conf.Log.Infof("websocket endpoint %v stop recvProcess", endpoint)
parWg.Done()
}()
subWg := &sync.WaitGroup{}
subWg.Add(2)
go m.recvProcess(ctx, endpoint, c, cancel, subWg)
go m.sendProcess(ctx, endpoint, instanceID, c, cancel, subWg)
subWg.Wait()
}

func (m *GlobalServerManager) sendProcess(ctx api.StreamContext, endpoint string, instanceID int, c *websocket.Conn, cancel context.CancelFunc, wg *sync.WaitGroup) {
conf.Log.Infof("websocket endpoint %v start sendProcess", endpoint)
topic := sendTopic(endpoint)
sourceID := fmt.Sprintf("ws/send/%v", instanceID)
defer func() {
pubsub.CloseSourceConsumerChannel(topic, sourceID)
cancel()
c.Close()
wg.Done()
conf.Log.Infof("websocket send endpoint %v stop sendProcess", endpoint)
}()
ch := pubsub.CreateSub(topic, nil, sourceID, 1024)
for {
select {
case <-ctx.Done():
return
case d := <-ch:
data := d.([]byte)
if err := c.WriteMessage(websocket.TextMessage, data); err != nil {
if websocket.IsCloseError(err) || strings.Contains(err.Error(), "close") {
conf.Log.Infof("websocket endpoint %s connection get closed: %v", endpoint, err)
return
}
conf.Log.Warnf("websocket endpoint %v send data meet error: %v", endpoint, err)
}
}
}
}

func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string, c *websocket.Conn, cancel context.CancelFunc, wg *sync.WaitGroup) {
defer func() {
cancel()
c.Close()
wg.Done()
conf.Log.Infof("websocket recv endpoint %v stop recvProcess", endpoint)
}()
conf.Log.Infof("websocket endpoint %v start recvProcess", endpoint)
topic := fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint)
topic := recvTopic(endpoint)
for {
select {
case <-ctx.Done():
Expand All @@ -81,12 +131,13 @@ func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string
}
}

func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, error) {
func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, string, error) {
conf.Log.Infof("websocket endpoint %v register", endpoint)
m.Lock()
defer m.Unlock()
recvTopic := fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint)
pubsub.CreatePub(recvTopic)
rTopic := recvTopic(endpoint)
sTopic := sendTopic(endpoint)
pubsub.CreatePub(rTopic)
m.router.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
c, err := m.upgrader.Upgrade(w, r, nil)
if err != nil {
Expand All @@ -95,16 +146,16 @@ func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, e
}
subCtx, cancel := ctx.WithCancel()
wg := m.AddEndpointConnection(endpoint, c, cancel)
go m.handleProcess(subCtx, endpoint, m.FetchInstanceID(), c, cancel, wg)
conf.Log.Infof("websocket endpint %v create connection", endpoint)
wg.Add(1)
go m.recvProcess(subCtx, endpoint, c, wg)
})
conf.Log.Infof("websocker endpoint %v registered success", endpoint)
return recvTopic, nil
return rTopic, sTopic, nil
}

func (m *GlobalServerManager) UnRegisterWebSocketEndpoint(endpoint string) *websocketEndpointContext {
conf.Log.Infof("websocket endpoint %v unregister", endpoint)
pubsub.RemovePub(recvTopic(endpoint))
m.Lock()
defer m.Unlock()
wctx, ok := m.websocketEndpoint[endpoint]
Expand Down Expand Up @@ -140,6 +191,7 @@ func (m *GlobalServerManager) AddEndpointConnection(endpoint string, c *websocke
return wctx.wg
}
wg := &sync.WaitGroup{}
wg.Add(1)
m.websocketEndpoint[endpoint] = &websocketEndpointContext{
wg: wg,
conns: map[*websocket.Conn]context.CancelFunc{
Expand All @@ -149,6 +201,11 @@ func (m *GlobalServerManager) AddEndpointConnection(endpoint string, c *websocke
return wg
}

func (m *GlobalServerManager) FetchInstanceID() int {
m.instanceID++
return m.instanceID
}

// getEndpointConnections only for unit test
func (m *GlobalServerManager) getEndpointConnections(endpoint string) *websocketEndpointContext {
m.RLock()
Expand Down
47 changes: 42 additions & 5 deletions internal/io/http/httpserver/websocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package httpserver

import (
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
Expand All @@ -32,14 +33,16 @@ func TestWebsocketServerRecvData(t *testing.T) {
defer ShutDown()
ctx := mockContext.NewMockContext("1", "2")
endpint := "/e1"
recvTopic, err := RegisterWebSocketEndpoint(ctx, endpint)
rTopic, _, err := RegisterWebSocketEndpoint(ctx, endpint)
require.NoError(t, err)
subCh := pubsub.CreateSub(recvTopic, nil, "test", 1024)
defer pubsub.CloseSourceConsumerChannel(recvTopic, "test")
subCh := pubsub.CreateSub(rTopic, nil, "test", 1024)
defer pubsub.CloseSourceConsumerChannel(rTopic, "test")
conn, err := testx.CreateWebsocketClient(ip, port, endpint)
require.NoError(t, err)
defer conn.Close()
data := []byte("123")
// wait goroutine process started
time.Sleep(10 * time.Millisecond)
require.NoError(t, conn.WriteMessage(websocket.TextMessage, data))
recvData := <-subCh
require.Equal(t, data, recvData.([]byte))
Expand All @@ -53,7 +56,7 @@ func TestWebsocketServerRecvDataCancel(t *testing.T) {
defer ShutDown()
ctx := mockContext.NewMockContext("1", "2")
endpint := "/e1"
_, err := RegisterWebSocketEndpoint(ctx, endpint)
_, _, err := RegisterWebSocketEndpoint(ctx, endpint)
require.NoError(t, err)
UnRegisterWebSocketEndpoint(endpint)
}
Expand All @@ -65,13 +68,47 @@ func TestWebsocketServerRecvDataOther(t *testing.T) {
defer ShutDown()
ctx := mockContext.NewMockContext("1", "2")
endpint := "/e1"
_, err := RegisterWebSocketEndpoint(ctx, endpint)
_, _, err := RegisterWebSocketEndpoint(ctx, endpint)
require.NoError(t, err)
conn, err := testx.CreateWebsocketClient(ip, port, endpint)
require.NoError(t, err)
defer conn.Close()
// wait goroutine process started
time.Sleep(10 * time.Millisecond)
require.NoError(t, conn.WriteMessage(websocket.PingMessage, []byte("123")))
require.NoError(t, conn.WriteMessage(websocket.CloseMessage, []byte("123")))
wctx := manager.getEndpointConnections(endpint)
wctx.wg.Wait()
require.Equal(t, 0, len(wctx.conns))
}

func TestWebsocketServerSendData(t *testing.T) {
endpoint := "/e1"
topic := sendTopic(endpoint)
pubsub.CreatePub(topic)
ip := "127.0.0.1"
port := 10085
InitGlobalServerManager(ip, port, nil)
defer ShutDown()
ctx := mockContext.NewMockContext("1", "2")
_, sTopic, err := RegisterWebSocketEndpoint(ctx, endpoint)
require.NoError(t, err)
require.Equal(t, topic, sTopic)
conn, err := testx.CreateWebsocketClient(ip, port, endpoint)
require.NoError(t, err)
defer conn.Close()
// wait goroutine process started
time.Sleep(10 * time.Millisecond)
assertCh := make(chan struct{})
go func() {
msgTyp, data, err := conn.ReadMessage()
require.NoError(t, err)
require.Equal(t, websocket.TextMessage, msgTyp)
require.Equal(t, []byte("123"), data)
assertCh <- struct{}{}
}()
time.Sleep(10 * time.Millisecond)
pubsub.ProduceAny(ctx, topic, []byte("123"))
<-assertCh
UnRegisterWebSocketEndpoint(endpoint)
}
83 changes: 83 additions & 0 deletions internal/io/websocket/websocket_sink.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package websocket

import (
"fmt"
"strings"

"github.com/lf-edge/ekuiper/contract/v2/api"
"github.com/lf-edge/ekuiper/v2/internal/io/http/httpserver"
"github.com/lf-edge/ekuiper/v2/internal/io/memory/pubsub"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
"github.com/lf-edge/ekuiper/v2/pkg/connection"
)

type WebsocketSink struct {
cfg *WebsocketConfig
props map[string]any
topic string
}

func (w *WebsocketSink) Provision(ctx api.StreamContext, configs map[string]any) error {
cfg := &WebsocketConfig{}
if err := cast.MapToStruct(configs, cfg); err != nil {
return err
}
if !strings.HasPrefix(cfg.Endpoint, "/") {
return fmt.Errorf("websocket endpoint should start with /")
}
w.cfg = cfg
w.props = configs
return nil
}

func (w *WebsocketSink) Close(ctx api.StreamContext) error {
pubsub.RemovePub(w.topic)
return connection.DetachConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), w.props)
}

func (w *WebsocketSink) Connect(ctx api.StreamContext) error {
conn, err := connection.FetchConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), "websocket", w.props)
if err != nil {
return err
}
c, ok := conn.(*httpserver.WebsocketConnection)
if !ok {
return fmt.Errorf("should use websocket connection")
}
w.topic = c.SendTopic
pubsub.CreatePub(w.topic)
return nil
}

func (w *WebsocketSink) Collect(ctx api.StreamContext, item api.RawTuple) error {
return w.collect(ctx, item.Raw())
}

func (w *WebsocketSink) collect(ctx api.StreamContext, data []byte) error {
pubsub.ProduceAny(ctx, w.topic, data)
return nil
}

func GetSink() api.Sink {
return &WebsocketSink{}
}

var _ api.BytesCollector = &WebsocketSink{}

func buildWebsocketEpID(endpoint string) string {
return fmt.Sprintf("$$ws/%s", endpoint)
}
Loading

0 comments on commit d237422

Please sign in to comment.