diff --git a/Sources/ToolsProtocolsSwiftExtensions/AsyncQueue.swift b/Sources/ToolsProtocolsSwiftExtensions/AsyncQueue.swift index 163420e87..4a1024bf2 100644 --- a/Sources/ToolsProtocolsSwiftExtensions/AsyncQueue.swift +++ b/Sources/ToolsProtocolsSwiftExtensions/AsyncQueue.swift @@ -131,7 +131,7 @@ public final class AsyncQueue: Sendable { // No dependency continue } - if metadata.isDependency(of: metadata), let lastPendingTask = pendingTasks.last { + if pendingMetadata.isDependency(of: pendingMetadata), let lastPendingTask = pendingTasks.last { // This kind of task depends on all other tasks of the same kind finishing. It is sufficient to just wait on // the last task with this metadata, it will have all the other tasks with the same metadata as transitive // dependencies. @@ -152,20 +152,20 @@ public final class AsyncQueue: Sendable { // operation. Otherwise the assumption that the task will never throw // if `operation` does not throw, which we are making in `async` does // not hold anymore. - for dependency in dependencies { - await dependency.task.waitForCompletion() + defer { + pendingTasks.withLock { tasksByMetadata in + tasksByMetadata[metadata, default: []].removeAll(where: { $0.id == id }) + if tasksByMetadata[metadata]?.isEmpty ?? false { + tasksByMetadata[metadata] = nil + } + } } - let result = try await operation() - - pendingTasks.withLock { tasksByMetadata in - tasksByMetadata[metadata, default: []].removeAll(where: { $0.id == id }) - if tasksByMetadata[metadata]?.isEmpty ?? false { - tasksByMetadata[metadata] = nil - } + for dependency in dependencies { + await dependency.task.waitForCompletion() } - return result + return try await operation() } tasksByMetadata[metadata, default: []].append(PendingTask(task: task, id: id)) diff --git a/Tests/ToolsProtocolsSwiftExtensionsTests/AsyncQueueTests.swift b/Tests/ToolsProtocolsSwiftExtensionsTests/AsyncQueueTests.swift new file mode 100644 index 000000000..83ed415c3 --- /dev/null +++ b/Tests/ToolsProtocolsSwiftExtensionsTests/AsyncQueueTests.swift @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import Testing +@_spi(SourceKitLSP) import ToolsProtocolsSwiftExtensions + +struct AsyncQueueTests { + /// Two metadata kinds where: + /// - `.concurrent` is *not* self-serializing (concurrent with itself) + /// - `.serial` is self-serializing + /// - `.concurrent` is a dependency of `.serial + /// + /// In this configuration, a `.serial` task depends on a bucket whose + /// entries do not depend on each other, so the dependency list cannot + /// collapse to just the last entry — every concurrent task in the bucket + /// must be awaited individually. + private enum Meta: Hashable, Sendable, DependencyTracker { + case concurrent + case serial + + func isDependency(of other: Meta) -> Bool { + switch (self, other) { + case (.concurrent, .concurrent): return false + case (.concurrent, .serial): return true + case (.serial, .concurrent): return false + case (.serial, .serial): return true + } + } + } + + /// A task depending on a non-self-serializing bucket must wait on every + /// task in that bucket, not just the last one. + @Test func serialTaskWaitsForAllConcurrentDependencies() async throws { + let queue = AsyncQueue() + + // Three concurrent tasks held until we yield to their respective streams. + let (stream1, cont1) = AsyncStream.makeStream() + let (stream2, cont2) = AsyncStream.makeStream() + let (stream3, cont3) = AsyncStream.makeStream() + let (startedStream, startedCont) = AsyncStream.makeStream() + + for stream in [stream1, stream2, stream3] { + queue.async(metadata: .concurrent) { + startedCont.yield() + for await _ in stream {} + } + } + + // Wait for all three concurrent tasks to be in flight before scheduling + // the serial dependent — otherwise the bucket might not have all three + // entries when the serial task computes its dependencies. + var startCount = 0 + for await _ in startedStream { + startCount += 1 + if startCount == 3 { break } + } + + let serialRan = ThreadSafeBox(initialValue: false) + let serialTask = queue.async(metadata: .serial) { + serialRan.value = true + } + + // Release only the last concurrent task. The serial task must still wait + // for the first two before running. + cont3.finish() + + // Give the serial task time to (incorrectly) run. The first two + // concurrent tasks are still blocked, so the serial task must not have + // run yet. + try await Task.sleep(for: .milliseconds(200)) + #expect( + !serialRan.value, + "Serial task ran before all concurrent dependencies completed" + ) + + // Release the remaining concurrent tasks; the serial task should now run. + cont1.finish() + cont2.finish() + await serialTask.value + #expect(serialRan.value) + } +}