diff --git a/datachannel.go b/datachannel.go
index e975a04f7b..6bdee24221 100644
--- a/datachannel.go
+++ b/datachannel.go
@@ -40,6 +40,7 @@ type DataChannel struct {
 	readyState                 atomic.Value // DataChannelState
 	bufferedAmountLowThreshold uint64
 	detachCalled               bool
+	readLoopActive             chan struct{}
 
 	// The binaryType represents attribute MUST, on getting, return the value to
 	// which it was last set. On setting, if the new value is either the string
@@ -327,6 +328,7 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread
 	defer d.mu.Unlock()
 
 	if !d.api.settingEngine.detach.DataChannels {
+		d.readLoopActive = make(chan struct{})
 		go d.readLoop()
 	}
 }
@@ -350,6 +352,7 @@ func (d *DataChannel) onError(err error) {
 }
 
 func (d *DataChannel) readLoop() {
+	defer close(d.readLoopActive)
 	buffer := make([]byte, dataChannelBufferSize)
 	for {
 		n, isString, err := d.dataChannel.ReadDataChannel(buffer)
@@ -465,6 +468,12 @@ func (d *DataChannel) Close() error {
 	return d.dataChannel.Close()
 }
 
+func (d *DataChannel) waitForReadLoopDone() {
+	if d.readLoopActive != nil {
+		<-d.readLoopActive
+	}
+}
+
 // Label represents a label that can be used to distinguish this
 // DataChannel object from other DataChannel objects. Scripts are
 // allowed to create multiple DataChannel objects with the same label.
diff --git a/peerconnection.go b/peerconnection.go
index 3ae8b732c6..1eb6a7aed2 100644
--- a/peerconnection.go
+++ b/peerconnection.go
@@ -56,6 +56,7 @@ type PeerConnection struct {
 	idpLoginURL *string
 
 	isClosed                                *atomicBool
+	isClosedDone                            chan struct{}
 	isNegotiationNeeded                     *atomicBool
 	updateNegotiationNeededFlagOnEmptyChain *atomicBool
 
@@ -116,6 +117,7 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
 			ICECandidatePoolSize: 0,
 		},
 		isClosed:                                &atomicBool{},
+		isClosedDone:                            make(chan struct{}),
 		isNegotiationNeeded:                     &atomicBool{},
 		updateNegotiationNeededFlagOnEmptyChain: &atomicBool{},
 		lastOffer:                               "",
@@ -2044,14 +2046,31 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes
 	return pc.dtlsTransport.WriteRTCP(pkts)
 }
 
-// Close ends the PeerConnection
+// Close ends the PeerConnection.
+// It will make a best effort to wait for all underlying goroutines it spawned to finish,
+// except for cases that would cause deadlocks with itself.
 func (pc *PeerConnection) Close() error {
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1)
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2)
 	if pc.isClosed.swap(true) {
+		// someone else got here first but may still be closing (e.g. via DTLS close_notify)
+		<-pc.isClosedDone
 		return nil
 	}
+	defer close(pc.isClosedDone)
 
+	// Try closing everything and collect the errors
+	// Shutdown strategy:
+	// 1. Close all data channels.
+	// 2. All Conn close by closing their underlying Conn.
+	// 3. A Mux stops this chain. It won't close the underlying
+	//    Conn if one of the endpoints is closed down. To
+	//    continue the chain the Mux has to be closed.
+	pc.sctpTransport.lock.Lock()
+	closeErrs := make([]error, 0, 4+len(pc.sctpTransport.dataChannels))
+	pc.sctpTransport.lock.Unlock()
+
+	// canon steps
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3)
 	pc.signalingState.Set(SignalingStateClosed)
 
@@ -2061,7 +2080,6 @@ func (pc *PeerConnection) Close() error {
 	// 2. A Mux stops this chain. It won't close the underlying
 	//    Conn if one of the endpoints is closed down. To
 	//    continue the chain the Mux has to be closed.
-	closeErrs := make([]error, 4)
 
 	closeErrs = append(closeErrs, pc.api.interceptor.Close())
 
@@ -2088,7 +2106,6 @@ func (pc *PeerConnection) Close() error {
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7)
 	closeErrs = append(closeErrs, pc.dtlsTransport.Stop())
 
-	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10)
 	if pc.iceTransport != nil {
 		closeErrs = append(closeErrs, pc.iceTransport.Stop())
 	}
@@ -2097,6 +2114,14 @@ func (pc *PeerConnection) Close() error {
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11)
 	pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
 
+	// non-canon steps
+	pc.sctpTransport.lock.Lock()
+	for _, d := range pc.sctpTransport.dataChannels {
+		closeErrs = append(closeErrs, d.Close())
+		d.waitForReadLoopDone()
+	}
+	pc.sctpTransport.lock.Unlock()
+
 	return util.FlattenErrs(closeErrs)
 }
 
@@ -2268,8 +2293,11 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
 	}
 
 	pc.dtlsTransport.internalOnCloseHandler = func() {
-		pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
+		if pc.isClosed.get() {
+			return
+		}
 
+		pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
 		go func() {
 			if pcClosErr := pc.Close(); pcClosErr != nil {
 				pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
diff --git a/peerconnection_close_test.go b/peerconnection_close_test.go
index 5360d701fc..ac3b23f704 100644
--- a/peerconnection_close_test.go
+++ b/peerconnection_close_test.go
@@ -7,6 +7,8 @@ package webrtc
 
 import (
+	"runtime"
+	"strings"
 	"testing"
 	"time"
 
@@ -179,3 +181,103 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) {
 		t.Error("pcOffer.Close() Timeout")
 	}
 }
+
+func TestPeerConnection_CloseWithIncomingMessages(t *testing.T) {
+	// Limit runtime in case of deadlocks
+	lim := test.TimeOut(time.Second * 20)
+	defer lim.Stop()
+
+	report := CheckRoutinesIntolerant(t)
+	defer report()
+
+	pcOffer, pcAnswer, err := newPair()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var dcAnswer *DataChannel
+	answerDataChannelOpened := make(chan struct{})
+	pcAnswer.OnDataChannel(func(d *DataChannel) {
+		// Make sure this is the data channel we were looking for. (Not the one
+		// created in signalPair).
+		if d.Label() != "data" {
+			return
+		}
+		dcAnswer = d
+		close(answerDataChannelOpened)
+	})
+
+	dcOffer, err := pcOffer.CreateDataChannel("data", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	offerDataChannelOpened := make(chan struct{})
+	dcOffer.OnOpen(func() {
+		close(offerDataChannelOpened)
+	})
+
+	err = signalPair(pcOffer, pcAnswer)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	<-offerDataChannelOpened
+	<-answerDataChannelOpened
+
+	msgNum := 0
+	dcOffer.OnMessage(func(_ DataChannelMessage) {
+		t.Log("msg", msgNum)
+		msgNum++
+	})
+
+	// send 50 messages, then close pcOffer, and then send another 50
+	for i := 0; i < 100; i++ {
+		if i == 50 {
+			err = pcOffer.Close()
+			if err != nil {
+				t.Fatal(err)
+			}
+		}
+		_ = dcAnswer.Send([]byte("hello!"))
+	}
+
+	err = pcAnswer.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// CheckRoutinesIntolerant is used to check for leaked go-routines.
+// It differs from test.CheckRoutines in that it won't wait at all
+// for lingering goroutines. This is helpful for tests that need
+// to ensure clean closure of resources.
+func CheckRoutinesIntolerant(t *testing.T) func() {
+	return func() {
+		routines := getRoutines()
+		if len(routines) == 0 {
+			return
+		}
+		t.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint
+	}
+}
+
+func getRoutines() []string {
+	buf := make([]byte, 2<<20)
+	buf = buf[:runtime.Stack(buf, true)]
+	return filterRoutines(strings.Split(string(buf), "\n\n"))
+}
+
+func filterRoutines(routines []string) []string {
+	result := []string{}
+	for _, stack := range routines {
+		if stack == "" || // Empty
+			strings.Contains(stack, "testing.Main(") || // Tests
+			strings.Contains(stack, "testing.(*T).Run(") || // Test run
+			strings.Contains(stack, "getRoutines(") { // This routine
+			continue
+		}
+		result = append(result, stack)
+	}
+	return result
+}
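
A minimal usage sketch of the new Close semantics (not part of the patch; the pion/webrtc v3 module path is an assumption, the API calls are the existing public ones): once Close returns, the read loops spawned for non-detached data channels have exited, so no further OnMessage callbacks fire and a strict goroutine check such as CheckRoutinesIntolerant can pass immediately.

package main

import (
	"fmt"

	"github.com/pion/webrtc/v3"
)

func main() {
	pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
	if err != nil {
		panic(err)
	}

	dc, err := pc.CreateDataChannel("data", nil)
	if err != nil {
		panic(err)
	}
	dc.OnMessage(func(msg webrtc.DataChannelMessage) {
		fmt.Println("received", len(msg.Data), "bytes")
	})

	// ... signaling and message traffic elided ...

	// With this change, Close blocks until the data channel read loops have
	// finished (waitForReadLoopDone), and a concurrent Close triggered by a
	// remote DTLS close_notify waits on isClosedDone instead of returning early.
	if err := pc.Close(); err != nil {
		panic(err)
	}
	// From here on, no further OnMessage callbacks will be delivered.
}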