Skip to content

Commit

Permalink
SocketPool + IdentifiableContinuation
Browse files Browse the repository at this point in the history
  • Loading branch information
swhitty committed Apr 14, 2024
1 parent ea36d7a commit a8b22e4
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 70 deletions.
108 changes: 62 additions & 46 deletions FlyingSocks/Sources/SocketPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {

public func suspendSocket(_ socket: Socket, untilReadyFor events: Socket.Events) async throws {
guard state == .running || state == .ready else { throw Error("Not Ready") }
let continuation = Continuation()
defer { removeContinuation(continuation, for: socket.file) }
try appendContinuation(continuation, for: socket.file, events: events)
return try await continuation.value
return try await withIdentifiableThrowingContinuation(isolation: self) {
appendContinuation($0, for: socket.file, events: events)
} onCancel: { id in
Task {
await self.resumeContinuation(id: id, with: .failure(CancellationError()), for: socket.file)
}
}
}

private func getNotifications() async throws -> [EventNotification] {
Expand All @@ -128,19 +131,8 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {
}

private func processNotification(_ notification: EventNotification) {
let continuations = waiting.continuations(
for: notification.file,
events: notification.events
)

if notification.errors.isEmpty {
for c in continuations {
c.resume()
}
} else {
for c in continuations {
c.resume(throwing: .disconnected)
}
for id in waiting.continuationIDs(for: notification.file, events: notification.events) {
resumeContinuation(id: id, with: notification.result, for: notification.file)
}
}

Expand All @@ -160,8 +152,8 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {
loop = nil
}

typealias Continuation = CancellingContinuation<Void, SocketError>
private var loop: IdentifiableContinuation<Void, any Swift.Error>?
typealias Continuation = IdentifiableContinuation<Void, any Swift.Error>
private var loop: Continuation?
private var waiting = Waiting() {
didSet {
if !waiting.isEmpty, let continuation = loop {
Expand All @@ -179,26 +171,41 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {
}
}

private func cancelLoopContinuation(with id: IdentifiableContinuation<Void, any Swift.Error>.ID) {
private func cancelLoopContinuation(with id: Continuation.ID) {
if loop?.id == id {
loop?.resume(throwing: CancellationError())
loop = nil
}
}

private func appendContinuation(_ continuation: Continuation,
for socket: Socket.FileDescriptor,
events: Socket.Events) throws {
let events = waiting.appendContinuation(continuation,
for: socket,
events: events)
try queue.addEvents(events, for: socket)
private func appendContinuation(
_ continuation: Continuation,
for socket: Socket.FileDescriptor,
events: Socket.Events
) {
let events = waiting.appendContinuation(continuation, for: socket, events: events)
do {
try queue.addEvents(events, for: socket)
} catch {
resumeContinuation(
id: continuation.id,
with: .failure(error),
for: socket
)
}
}

private func removeContinuation(_ continuation: Continuation,
for socket: Socket.FileDescriptor) {
let events = waiting.removeContinuation(continuation, for: socket)
try? queue.removeEvents(events, for: socket)
private func resumeContinuation(
id: Continuation.ID,
with result: Result<Void, any Swift.Error>,
for socket: Socket.FileDescriptor
) {
do {
let events = waiting.resumeContinuation(id: id, with: result, for: socket)
try queue.removeEvents(events, for: socket)
} catch {
logger.logError("resumeContinuation queue.removeEvents: \(error.localizedDescription)")
}
}

private struct Error: LocalizedError {
Expand All @@ -210,7 +217,7 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {
}

struct Waiting {
private var storage: [Socket.FileDescriptor: [Continuation: Socket.Events]] = [:]
private var storage: [Socket.FileDescriptor: [Continuation.ID: (continuation: Continuation, events: Socket.Events)]] = [:]

var isEmpty: Bool { storage.isEmpty }

Expand All @@ -219,41 +226,50 @@ public final actor SocketPool<Queue: EventQueue>: AsyncSocketPool {
for socket: Socket.FileDescriptor,
events: Socket.Events) -> Socket.Events {
var entries = storage[socket] ?? [:]
entries[continuation] = events
entries[continuation.id] = (continuation, events)
storage[socket] = entries
return entries.values.reduce(Socket.Events()) {
$0.union($1)
$0.union($1.events)
}
}

// Removes continuation returning any events that are no longer being waited
mutating func removeContinuation(_ continuation: Continuation,
// Resumes and removes continuation, returning any events that are no longer being waited
mutating func resumeContinuation(id: Continuation.ID,
with result: Result<Void, any Swift.Error>,
for socket: Socket.FileDescriptor) -> Socket.Events {
var entries = storage[socket] ?? [:]
guard let events = entries[continuation] else { return [] }
entries[continuation] = nil
guard let (continuation, events) = entries.removeValue(forKey: id) else { return [] }
continuation.resume(with: result)
storage[socket] = entries.isEmpty ? nil : entries
let remaining = entries.values.reduce(Socket.Events()) {
$0.union($1)
$0.union($1.events)
}
return events.filter { !remaining.contains($0) }
}

func continuations(for socket: Socket.FileDescriptor, events: Socket.Events) -> [Continuation] {
func continuationIDs(for socket: Socket.FileDescriptor, events: Socket.Events) -> [Continuation.ID] {
let entries = storage[socket] ?? [:]
return entries.compactMap { c, ev in
if events.intersection(ev).isEmpty {
return entries.compactMap { id, ev in
if events.intersection(ev.events).isEmpty {
return nil
} else {
return c
return id
}
}
}

func cancellAll() {
for continuation in storage.values.flatMap(\.keys) {
continuation.cancel()
mutating func cancellAll() {
let continuations = storage.values.flatMap(\.values).map(\.continuation)
storage = [:]
for continuation in continuations {
continuation.resume(throwing: CancellationError())
}
}
}
}

private extension EventNotification {
var result: Result<Void, any Swift.Error> {
errors.isEmpty ? .success(()) : .failure(SocketError.disconnected)
}
}
70 changes: 46 additions & 24 deletions FlyingSocks/Tests/SocketPoolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
//

@testable import FlyingSocks
@_spi(Private) import struct FlyingSocks.IdentifiableContinuation
@_spi(Private) import func FlyingSocks.withIdentifiableThrowingContinuation
import XCTest

final class SocketPoolTests: XCTestCase {

typealias Continuation = CancellingContinuation<Void, SocketError>
typealias Continuation = IdentifiableContinuation<Void, any Swift.Error>
typealias Waiting = SocketPool<MockEventQueue>.Waiting

#if canImport(Darwin)
Expand Down Expand Up @@ -161,24 +163,24 @@ final class SocketPoolTests: XCTestCase {
) { XCTAssertEqual($0, .disconnected) }
}

func testWaiting_IsEmpty() {
let cn = Continuation()
func testWaiting_IsEmpty() async {
let cn = await Continuation.make()

var waiting = Waiting()
XCTAssertTrue(waiting.isEmpty)

_ = waiting.appendContinuation(cn, for: .validMock, events: .read)
XCTAssertFalse(waiting.isEmpty)

_ = waiting.removeContinuation(cn, for: .validMock)
_ = waiting.resumeContinuation(id: cn.id, with: .success(()), for: .validMock)
XCTAssertTrue(waiting.isEmpty)
}

func testWaitingEvents() {
func testWaitingEvents() async {
var waiting = Waiting()
let cnRead = Continuation()
let cnRead1 = Continuation()
let cnWrite = Continuation()
let cnRead = await Continuation.make()
let cnRead1 = await Continuation.make()
let cnWrite = await Continuation.make()

XCTAssertEqual(
waiting.appendContinuation(cnRead, for: .validMock, events: .read),
Expand All @@ -193,51 +195,51 @@ final class SocketPoolTests: XCTestCase {
[.read, .write]
)
XCTAssertEqual(
waiting.removeContinuation(.init(), for: .validMock),
waiting.resumeContinuation(id: .init(), with: .success(()), for: .validMock),
[]
)
XCTAssertEqual(
waiting.removeContinuation(cnWrite, for: .validMock),
waiting.resumeContinuation(id: cnWrite.id, with: .success(()), for: .validMock),
[.write]
)
XCTAssertEqual(
waiting.removeContinuation(cnRead, for: .validMock),
waiting.resumeContinuation(id: cnRead.id, with: .success(()), for: .validMock),
[]
)
XCTAssertEqual(
waiting.removeContinuation(cnRead1, for: .validMock),
waiting.resumeContinuation(id: cnRead1.id, with: .success(()), for: .validMock),
[.read]
)
}

func testWaitingContinuations() {
func testWaitingContinuations() async {
var waiting = Waiting()
let cnRead = Continuation()
let cnRead1 = Continuation()
let cnWrite = Continuation()
let cnRead = await Continuation.make()
let cnRead1 = await Continuation.make()
let cnWrite = await Continuation.make()

_ = waiting.appendContinuation(cnRead, for: .validMock, events: .read)
_ = waiting.appendContinuation(cnRead1, for: .validMock, events: .read)
_ = waiting.appendContinuation(cnWrite, for: .validMock, events: .write)

XCTAssertEqual(
Set(waiting.continuations(for: .validMock, events: .read)),
[cnRead1, cnRead]
Set(waiting.continuationIDs(for: .validMock, events: .read)),
[cnRead1.id, cnRead.id]
)
XCTAssertEqual(
Set(waiting.continuations(for: .validMock, events: .write)),
[cnWrite]
Set(waiting.continuationIDs(for: .validMock, events: .write)),
[cnWrite.id]
)
XCTAssertEqual(
Set(waiting.continuations(for: .validMock, events: .connection)),
[cnRead1, cnRead, cnWrite]
Set(waiting.continuationIDs(for: .validMock, events: .connection)),
[cnRead1.id, cnRead.id, cnWrite.id]
)
XCTAssertEqual(
Set(waiting.continuations(for: .validMock, events: [])),
Set(waiting.continuationIDs(for: .validMock, events: [])),
[]
)
XCTAssertEqual(
Set(waiting.continuations(for: .invalid, events: .connection)),
Set(waiting.continuationIDs(for: .invalid, events: .connection)),
[]
)
}
Expand Down Expand Up @@ -309,3 +311,23 @@ final class MockEventQueue: EventQueue, @unchecked Sendable {
}
}
}


extension IdentifiableContinuation {
static func make() async -> IdentifiableContinuation<T, any Error> {
await Host().makeThrowingContinuation()
}

private actor Host {
func makeThrowingContinuation() async -> IdentifiableContinuation<T, any Error> {
await withCheckedContinuation { outer in
Task {
try? await withIdentifiableThrowingContinuation(isolation: self) {
outer.resume(returning: $0)
} onCancel: { _ in }
}
}
}
}
}

0 comments on commit a8b22e4

Please sign in to comment.