diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index d89332e255..37731b6f37 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -97,7 +97,9 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { finishOnDeinit: closeOnDeinit, delegate: .init(handler: handler) ) + handler.sink = writer.sink + handler.writer = writer.writer try channel.pipeline.syncOperations.addHandler(handler) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index 59fad5e3e1..daa35fd63e 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -37,6 +37,18 @@ internal final class NIOAsyncChannelOutboundWriterHandler @usableFromInline var sink: Sink? + /// The writer of the ``NIOAsyncWriter``. + /// + /// The reference is retained until `channelActive` is fired. This avoids situations + /// where `deinit` is called on the unfinished writer because the `Channel` was never returned + /// to the caller (e.g. because a connect failed or or happy-eyeballs created multiple + /// channels). + /// + /// Effectively `channelActive` is used at the point in time at which NIO cedes ownership of + /// the writer to the caller. + @usableFromInline + var writer: Writer? + /// The channel handler context. @usableFromInline var context: ChannelHandlerContext? @@ -126,6 +138,14 @@ internal final class NIOAsyncChannelOutboundWriterHandler func handlerRemoved(context: ChannelHandlerContext) { self.context = nil self.sink?.finish() + self.writer = nil + } + + @inlinable + func channelActive(context: ChannelHandlerContext) { + // Drop the writer ref, the caller is responsible for it now. + self.writer = nil + context.fireChannelActive() } @inlinable diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 1fb1bdd810..91941cb8ae 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -542,9 +542,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .unknown ) - await XCTAssertThrowsError( + await XCTAssertThrowsError { try await failedProtocolNegotiation.get() - ) + } // Let's check that we can still open a new connection let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( @@ -575,6 +575,26 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + func testClientBootstrap_connectFails() async throws { + // Beyond verifying the connect throws, this test allows us to check that 'NIOAsyncChannel' + // doesn't crash on deinit when we never return it to the user. + await XCTAssertThrowsError { + try await ClientBootstrap( + group: .singletonMultiThreadedEventLoopGroup + ).connect(unixDomainSocketPath: "testClientBootstrapConnectFails") { channel in + return channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + } + } + // MARK: Datagram Bootstrap func testDatagramBootstrap_withAsyncChannel_andHostPort() async throws { @@ -663,6 +683,26 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + func testDatagramBootstrap_connectFails() async throws { + // Beyond verifying the connect throws, this test allows us to check that 'NIOAsyncChannel' + // doesn't crash on deinit when we never return it to the user. + await XCTAssertThrowsError { + try await DatagramBootstrap( + group: .singletonMultiThreadedEventLoopGroup + ).connect(unixDomainSocketPath: "testDatagramBootstrapConnectFails") { channel in + return channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: AddressedEnvelope.self, + outboundType: AddressedEnvelope.self + ) + ) + } + } + } + } + // MARK: - Pipe Bootstrap func testPipeBootstrap() async throws { diff --git a/Tests/NIOPosixTests/NIOThreadPoolTest.swift b/Tests/NIOPosixTests/NIOThreadPoolTest.swift index 687ac8ea3d..0884a45538 100644 --- a/Tests/NIOPosixTests/NIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/NIOThreadPoolTest.swift @@ -190,6 +190,6 @@ class NIOThreadPoolTest: XCTestCase { let future = threadPool.runIfActive(eventLoop: eventLoop) { XCTFail("This shouldn't run because the pool is shutdown.") } - await XCTAssertThrowsError(try await future.get()) + await XCTAssertThrowsError { try await future.get() } } } diff --git a/Tests/NIOPosixTests/XCTest+AsyncAwait.swift b/Tests/NIOPosixTests/XCTest+AsyncAwait.swift index 810d9185a9..f1d4689994 100644 --- a/Tests/NIOPosixTests/XCTest+AsyncAwait.swift +++ b/Tests/NIOPosixTests/XCTest+AsyncAwait.swift @@ -44,7 +44,7 @@ import XCTest @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) internal func XCTAssertThrowsError( - _ expression: @autoclosure () async throws -> T, + _ expression: () async throws -> T, file: StaticString = #filePath, line: UInt = #line, verify: (Error) -> Void = { _ in }