From d23742287106ed3e44e5f40b73b040c0c2eb2b43 Mon Sep 17 00:00:00 2001 From: Song Gao Date: Fri, 5 Jul 2024 14:21:56 +0800 Subject: [PATCH] feat: support websocket sink (#2985) Signed-off-by: Song Gao --- internal/binder/io/builtin.go | 2 +- internal/io/http/httpserver/data_server.go | 1 + internal/io/http/httpserver/websocketConn.go | 6 +- .../io/http/httpserver/websocket_server.go | 77 ++++++++++++++--- .../http/httpserver/websocket_server_test.go | 47 +++++++++-- internal/io/websocket/websocket_sink.go | 83 +++++++++++++++++++ internal/io/websocket/websocket_sink_test.go | 68 +++++++++++++++ internal/io/websocket/websocket_source.go | 4 +- 8 files changed, 268 insertions(+), 20 deletions(-) create mode 100644 internal/io/websocket/websocket_sink.go create mode 100644 internal/io/websocket/websocket_sink_test.go diff --git a/internal/binder/io/builtin.go b/internal/binder/io/builtin.go index e38ed8c7d..eb7cc1137 100644 --- a/internal/binder/io/builtin.go +++ b/internal/binder/io/builtin.go @@ -47,7 +47,7 @@ func init() { modules.RegisterSink("memory", func() api.Sink { return memory.GetSink() }) modules.RegisterSink("neuron", neuron.GetSink) modules.RegisterSink("file", file.GetSink) - // modules.RegisterSink("websocket", func() api.Sink { return &websocket.WebSocketSink{} }) + modules.RegisterSink("websocket", func() api.Sink { return websocket.GetSink() }) modules.RegisterLookupSource("memory", memory.GetLookupSource) // modules.RegisterLookupSource("httppull", func() api.LookupSource { return http.GetLookUpSource() }) diff --git a/internal/io/http/httpserver/data_server.go b/internal/io/http/httpserver/data_server.go index 8db68fc83..d2b22b6a2 100644 --- a/internal/io/http/httpserver/data_server.go +++ b/internal/io/http/httpserver/data_server.go @@ -33,6 +33,7 @@ import ( type GlobalServerManager struct { sync.RWMutex + instanceID int endpoint map[string]string server *http.Server router *mux.Router diff --git a/internal/io/http/httpserver/websocketConn.go b/internal/io/http/httpserver/websocketConn.go index 22c98c518..c106b9745 100644 --- a/internal/io/http/httpserver/websocketConn.go +++ b/internal/io/http/httpserver/websocketConn.go @@ -22,6 +22,7 @@ import ( type WebsocketConnection struct { RecvTopic string + SendTopic string cfg *connectionCfg } @@ -46,12 +47,13 @@ func createWebsocketServerConnection(ctx api.StreamContext, props map[string]any if err := cast.MapToStruct(props, cfg); err != nil { return nil, err } - recvTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource) + rTopic, sTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource) if err != nil { return nil, err } return &WebsocketConnection{ - RecvTopic: recvTopic, + RecvTopic: rTopic, + SendTopic: sTopic, cfg: cfg, }, nil } diff --git a/internal/io/http/httpserver/websocket_server.go b/internal/io/http/httpserver/websocket_server.go index d2137bdf5..fbef1d61c 100644 --- a/internal/io/http/httpserver/websocket_server.go +++ b/internal/io/http/httpserver/websocket_server.go @@ -32,12 +32,20 @@ const ( WebsocketTopicPrefix = "$$websocket/" ) +func recvTopic(endpoint string) string { + return fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint) +} + +func sendTopic(endpoint string) string { + return fmt.Sprintf("send/%s/%s", WebsocketTopicPrefix, endpoint) +} + type websocketEndpointContext struct { wg *sync.WaitGroup conns map[*websocket.Conn]context.CancelFunc } -func RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, error) { +func RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, string, error) { return manager.RegisterWebSocketEndpoint(ctx, endpoint) } @@ -49,14 +57,56 @@ func UnRegisterWebSocketEndpoint(endpoint string) { } } -func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string, c *websocket.Conn, wg *sync.WaitGroup) { +func (m *GlobalServerManager) handleProcess(ctx api.StreamContext, endpoint string, instanceID int, c *websocket.Conn, cancel context.CancelFunc, parWg *sync.WaitGroup) { defer func() { m.CloseEndpointConnection(endpoint, c) - conf.Log.Infof("websocket endpoint %v stop recvProcess", endpoint) + parWg.Done() + }() + subWg := &sync.WaitGroup{} + subWg.Add(2) + go m.recvProcess(ctx, endpoint, c, cancel, subWg) + go m.sendProcess(ctx, endpoint, instanceID, c, cancel, subWg) + subWg.Wait() +} + +func (m *GlobalServerManager) sendProcess(ctx api.StreamContext, endpoint string, instanceID int, c *websocket.Conn, cancel context.CancelFunc, wg *sync.WaitGroup) { + conf.Log.Infof("websocket endpoint %v start sendProcess", endpoint) + topic := sendTopic(endpoint) + sourceID := fmt.Sprintf("ws/send/%v", instanceID) + defer func() { + pubsub.CloseSourceConsumerChannel(topic, sourceID) + cancel() + c.Close() + wg.Done() + conf.Log.Infof("websocket send endpoint %v stop sendProcess", endpoint) + }() + ch := pubsub.CreateSub(topic, nil, sourceID, 1024) + for { + select { + case <-ctx.Done(): + return + case d := <-ch: + data := d.([]byte) + if err := c.WriteMessage(websocket.TextMessage, data); err != nil { + if websocket.IsCloseError(err) || strings.Contains(err.Error(), "close") { + conf.Log.Infof("websocket endpoint %s connection get closed: %v", endpoint, err) + return + } + conf.Log.Warnf("websocket endpoint %v send data meet error: %v", endpoint, err) + } + } + } +} + +func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string, c *websocket.Conn, cancel context.CancelFunc, wg *sync.WaitGroup) { + defer func() { + cancel() + c.Close() wg.Done() + conf.Log.Infof("websocket recv endpoint %v stop recvProcess", endpoint) }() conf.Log.Infof("websocket endpoint %v start recvProcess", endpoint) - topic := fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint) + topic := recvTopic(endpoint) for { select { case <-ctx.Done(): @@ -81,12 +131,13 @@ func (m *GlobalServerManager) recvProcess(ctx api.StreamContext, endpoint string } } -func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, error) { +func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, endpoint string) (string, string, error) { conf.Log.Infof("websocket endpoint %v register", endpoint) m.Lock() defer m.Unlock() - recvTopic := fmt.Sprintf("recv/%s/%s", WebsocketTopicPrefix, endpoint) - pubsub.CreatePub(recvTopic) + rTopic := recvTopic(endpoint) + sTopic := sendTopic(endpoint) + pubsub.CreatePub(rTopic) m.router.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) { c, err := m.upgrader.Upgrade(w, r, nil) if err != nil { @@ -95,16 +146,16 @@ func (m *GlobalServerManager) RegisterWebSocketEndpoint(ctx api.StreamContext, e } subCtx, cancel := ctx.WithCancel() wg := m.AddEndpointConnection(endpoint, c, cancel) + go m.handleProcess(subCtx, endpoint, m.FetchInstanceID(), c, cancel, wg) conf.Log.Infof("websocket endpint %v create connection", endpoint) - wg.Add(1) - go m.recvProcess(subCtx, endpoint, c, wg) }) conf.Log.Infof("websocker endpoint %v registered success", endpoint) - return recvTopic, nil + return rTopic, sTopic, nil } func (m *GlobalServerManager) UnRegisterWebSocketEndpoint(endpoint string) *websocketEndpointContext { conf.Log.Infof("websocket endpoint %v unregister", endpoint) + pubsub.RemovePub(recvTopic(endpoint)) m.Lock() defer m.Unlock() wctx, ok := m.websocketEndpoint[endpoint] @@ -140,6 +191,7 @@ func (m *GlobalServerManager) AddEndpointConnection(endpoint string, c *websocke return wctx.wg } wg := &sync.WaitGroup{} + wg.Add(1) m.websocketEndpoint[endpoint] = &websocketEndpointContext{ wg: wg, conns: map[*websocket.Conn]context.CancelFunc{ @@ -149,6 +201,11 @@ func (m *GlobalServerManager) AddEndpointConnection(endpoint string, c *websocke return wg } +func (m *GlobalServerManager) FetchInstanceID() int { + m.instanceID++ + return m.instanceID +} + // getEndpointConnections only for unit test func (m *GlobalServerManager) getEndpointConnections(endpoint string) *websocketEndpointContext { m.RLock() diff --git a/internal/io/http/httpserver/websocket_server_test.go b/internal/io/http/httpserver/websocket_server_test.go index 5117350f9..02f47afd1 100644 --- a/internal/io/http/httpserver/websocket_server_test.go +++ b/internal/io/http/httpserver/websocket_server_test.go @@ -16,6 +16,7 @@ package httpserver import ( "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -32,14 +33,16 @@ func TestWebsocketServerRecvData(t *testing.T) { defer ShutDown() ctx := mockContext.NewMockContext("1", "2") endpint := "/e1" - recvTopic, err := RegisterWebSocketEndpoint(ctx, endpint) + rTopic, _, err := RegisterWebSocketEndpoint(ctx, endpint) require.NoError(t, err) - subCh := pubsub.CreateSub(recvTopic, nil, "test", 1024) - defer pubsub.CloseSourceConsumerChannel(recvTopic, "test") + subCh := pubsub.CreateSub(rTopic, nil, "test", 1024) + defer pubsub.CloseSourceConsumerChannel(rTopic, "test") conn, err := testx.CreateWebsocketClient(ip, port, endpint) require.NoError(t, err) defer conn.Close() data := []byte("123") + // wait goroutine process started + time.Sleep(10 * time.Millisecond) require.NoError(t, conn.WriteMessage(websocket.TextMessage, data)) recvData := <-subCh require.Equal(t, data, recvData.([]byte)) @@ -53,7 +56,7 @@ func TestWebsocketServerRecvDataCancel(t *testing.T) { defer ShutDown() ctx := mockContext.NewMockContext("1", "2") endpint := "/e1" - _, err := RegisterWebSocketEndpoint(ctx, endpint) + _, _, err := RegisterWebSocketEndpoint(ctx, endpint) require.NoError(t, err) UnRegisterWebSocketEndpoint(endpint) } @@ -65,13 +68,47 @@ func TestWebsocketServerRecvDataOther(t *testing.T) { defer ShutDown() ctx := mockContext.NewMockContext("1", "2") endpint := "/e1" - _, err := RegisterWebSocketEndpoint(ctx, endpint) + _, _, err := RegisterWebSocketEndpoint(ctx, endpint) require.NoError(t, err) conn, err := testx.CreateWebsocketClient(ip, port, endpint) require.NoError(t, err) + defer conn.Close() + // wait goroutine process started + time.Sleep(10 * time.Millisecond) require.NoError(t, conn.WriteMessage(websocket.PingMessage, []byte("123"))) require.NoError(t, conn.WriteMessage(websocket.CloseMessage, []byte("123"))) wctx := manager.getEndpointConnections(endpint) wctx.wg.Wait() require.Equal(t, 0, len(wctx.conns)) } + +func TestWebsocketServerSendData(t *testing.T) { + endpoint := "/e1" + topic := sendTopic(endpoint) + pubsub.CreatePub(topic) + ip := "127.0.0.1" + port := 10085 + InitGlobalServerManager(ip, port, nil) + defer ShutDown() + ctx := mockContext.NewMockContext("1", "2") + _, sTopic, err := RegisterWebSocketEndpoint(ctx, endpoint) + require.NoError(t, err) + require.Equal(t, topic, sTopic) + conn, err := testx.CreateWebsocketClient(ip, port, endpoint) + require.NoError(t, err) + defer conn.Close() + // wait goroutine process started + time.Sleep(10 * time.Millisecond) + assertCh := make(chan struct{}) + go func() { + msgTyp, data, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.TextMessage, msgTyp) + require.Equal(t, []byte("123"), data) + assertCh <- struct{}{} + }() + time.Sleep(10 * time.Millisecond) + pubsub.ProduceAny(ctx, topic, []byte("123")) + <-assertCh + UnRegisterWebSocketEndpoint(endpoint) +} diff --git a/internal/io/websocket/websocket_sink.go b/internal/io/websocket/websocket_sink.go new file mode 100644 index 000000000..613a21d50 --- /dev/null +++ b/internal/io/websocket/websocket_sink.go @@ -0,0 +1,83 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package websocket + +import ( + "fmt" + "strings" + + "github.com/lf-edge/ekuiper/contract/v2/api" + "github.com/lf-edge/ekuiper/v2/internal/io/http/httpserver" + "github.com/lf-edge/ekuiper/v2/internal/io/memory/pubsub" + "github.com/lf-edge/ekuiper/v2/pkg/cast" + "github.com/lf-edge/ekuiper/v2/pkg/connection" +) + +type WebsocketSink struct { + cfg *WebsocketConfig + props map[string]any + topic string +} + +func (w *WebsocketSink) Provision(ctx api.StreamContext, configs map[string]any) error { + cfg := &WebsocketConfig{} + if err := cast.MapToStruct(configs, cfg); err != nil { + return err + } + if !strings.HasPrefix(cfg.Endpoint, "/") { + return fmt.Errorf("websocket endpoint should start with /") + } + w.cfg = cfg + w.props = configs + return nil +} + +func (w *WebsocketSink) Close(ctx api.StreamContext) error { + pubsub.RemovePub(w.topic) + return connection.DetachConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), w.props) +} + +func (w *WebsocketSink) Connect(ctx api.StreamContext) error { + conn, err := connection.FetchConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), "websocket", w.props) + if err != nil { + return err + } + c, ok := conn.(*httpserver.WebsocketConnection) + if !ok { + return fmt.Errorf("should use websocket connection") + } + w.topic = c.SendTopic + pubsub.CreatePub(w.topic) + return nil +} + +func (w *WebsocketSink) Collect(ctx api.StreamContext, item api.RawTuple) error { + return w.collect(ctx, item.Raw()) +} + +func (w *WebsocketSink) collect(ctx api.StreamContext, data []byte) error { + pubsub.ProduceAny(ctx, w.topic, data) + return nil +} + +func GetSink() api.Sink { + return &WebsocketSink{} +} + +var _ api.BytesCollector = &WebsocketSink{} + +func buildWebsocketEpID(endpoint string) string { + return fmt.Sprintf("$$ws/%s", endpoint) +} diff --git a/internal/io/websocket/websocket_sink_test.go b/internal/io/websocket/websocket_sink_test.go new file mode 100644 index 000000000..55b9cd5ed --- /dev/null +++ b/internal/io/websocket/websocket_sink_test.go @@ -0,0 +1,68 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package websocket + +import ( + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + + "github.com/lf-edge/ekuiper/v2/internal/io/http/httpserver" + "github.com/lf-edge/ekuiper/v2/internal/testx" + "github.com/lf-edge/ekuiper/v2/pkg/connection" + mockContext "github.com/lf-edge/ekuiper/v2/pkg/mock/context" + "github.com/lf-edge/ekuiper/v2/pkg/modules" +) + +func init() { + modules.RegisterConnection("websocket", httpserver.CreateWebsocketConnection) +} + +func TestWebsocketSink(t *testing.T) { + connection.InitConnectionManager4Test() + ip := "127.0.0.1" + port := 10081 + endpoint := "/e1" + httpserver.InitGlobalServerManager(ip, port, nil) + defer httpserver.ShutDown() + ctx := mockContext.NewMockContext("1", "2") + ws := &WebsocketSink{} + props := map[string]any{ + "datasource": endpoint, + } + require.Error(t, ws.Provision(ctx, map[string]any{ + "datasource": "", + })) + require.NoError(t, ws.Provision(ctx, props)) + require.NoError(t, ws.Connect(ctx)) + expData := []byte("123") + assertCh := make(chan struct{}) + conn, err := testx.CreateWebsocketClient(ip, port, endpoint) + require.NoError(t, err) + go func() { + msgTyp, data, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.TextMessage, msgTyp) + require.Equal(t, expData, data) + assertCh <- struct{}{} + }() + // wait goroutine start + time.Sleep(10 * time.Millisecond) + require.NoError(t, ws.collect(ctx, expData)) + <-assertCh + ws.Close(ctx) +} diff --git a/internal/io/websocket/websocket_source.go b/internal/io/websocket/websocket_source.go index e2f689cf3..5dc6d2a18 100644 --- a/internal/io/websocket/websocket_source.go +++ b/internal/io/websocket/websocket_source.go @@ -53,11 +53,11 @@ func (w *WebsocketSource) Provision(ctx api.StreamContext, configs map[string]an func (w *WebsocketSource) Close(ctx api.StreamContext) error { pubsub.CloseSourceConsumerChannel(w.topic, w.sourceID) - return connection.DetachConnection(ctx, w.cfg.Endpoint, w.props) + return connection.DetachConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), w.props) } func (w *WebsocketSource) Connect(ctx api.StreamContext) error { - conn, err := connection.FetchConnection(ctx, w.cfg.Endpoint, "websocket", w.props) + conn, err := connection.FetchConnection(ctx, buildWebsocketEpID(w.cfg.Endpoint), "websocket", w.props) if err != nil { return err }