Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 90 additions & 94 deletions Sources/ToolsProtocolsSwiftExtensions/AsyncUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,127 +204,123 @@ extension Collection where Self: Sendable, Element: Sendable {
}
}

/// Executes `body`. If it doesn't finish after `duration`, throws a `TimeoutError` and cancels `body`.
@_spi(SourceKitLSP)
public enum WithTimeoutResult<T: Sendable>: Sendable {
case result(T)
case timedOut
}

/// Executes `body` with a `duration` timeout.
///
/// `TimeoutError` is thrown immediately an the function does not wait for `body` to honor the cancellation.
/// Returns `.result(value)` if `body` finishes within `duration`, otherwise `.timedOut`.
///
/// If a `handle` is passed in and this `withTimeout` call times out, the thrown `TimeoutError` contains this handle.
/// This way a caller can identify whether this call to `withTimeout` timed out or if a nested call timed out.
package func withTimeout<T: Sendable>(
_ duration: Duration,
handle: TimeoutHandle? = nil,
_ body: @escaping @Sendable () async throws -> T
) async throws -> T {
// Get the priority with which to launch the body task here so that we can pass the same priority as the initial
// priority to `withTaskPriorityChangedHandler`. Otherwise, we can get into a race condition where bodyTask gets
// launched with a low priority, then the priority gets elevated before we call with `withTaskPriorityChangedHandler`,
// we thus don't receive a `taskPriorityChanged` and hence never increase the priority of `bodyTask`.
/// On timeout: if `resultReceivedAfterTimeout` is provided, `body` keeps running and its
/// eventual result is passed to that callback. Otherwise, `body` is cancelled.
@_spi(SourceKitLSP)
public func withTimeoutResult<T: Sendable>(
_ timeout: Duration,
body: @escaping @Sendable () async throws -> T,
resultReceivedAfterTimeout: (@Sendable (_ result: T) async -> Void)? = nil
) async rethrows -> WithTimeoutResult<T> {
// Capture the priority here so it stays consistent across `bodyTask`, timeoutTask`,
// and `withTaskPriorityChangedHandler`'s initial state.
let priority = Task.currentPriority
var mutableTasks: [Task<Void, Error>] = []
let stream = AsyncThrowingStream<T, Error> { continuation in
let bodyTask = Task<Void, Error>(priority: priority) {
do {
let result = try await body()
continuation.yield(result)
} catch {
continuation.yield(with: .failure(error))
}
}

let timeoutTask = Task(priority: priority) {
try await Task.sleep(for: duration)
continuation.yield(with: .failure(TimeoutError(handle: handle)))
bodyTask.cancel()
let (stream, continuation) = AsyncStream<WithTimeoutResult<Result<T, any Error>>>.makeStream()
let bodyTask = Task(priority: priority) {
do {
let value = try await body()
continuation.yield(.result(.success(value)))
return value
} catch {
continuation.yield(.result(.failure(error)))
throw error
}
mutableTasks = [bodyTask, timeoutTask]
}

let tasks = mutableTasks

defer {
// Be extra careful and ensure that we don't leave `bodyTask` or `timeoutTask` running when `withTimeout` finishes,
// eg. if `withTaskPriorityChangedHandler` adds some behavior that never executes `body` if the task gets cancelled.
for task in tasks {
task.cancel()
}
let timeoutTask = Task(priority: priority) {
do { try await Task.sleep(for: timeout) } catch { return }
continuation.yield(.timedOut)
}

// `bodyTask` is intentionally not cancelled here: it must keep running so the late-result
// dispatcher can deliver its value. Cancellation happens at the specific sites that own that
// decision.
defer { timeoutTask.cancel() }

return try await withTaskPriorityChangedHandler(initialPriority: priority) {
for try await value in stream {
return value
}
// The only reason for the loop above to terminate is if the Task got cancelled or if the stream finishes
// (which it never does).
if Task.isCancelled {
// Throwing a `CancellationError` will make us return from `withTimeout`. We will cancel the `bodyTask` from the
// `defer` method above.
throw CancellationError()
} else {
preconditionFailure("Continuation never finishes")
}
} taskPriorityChanged: {
for task in tasks {
Task(priority: Task.currentPriority) {
_ = try? await task.value
for await value in stream {
switch value {
case .result(let r):
return try .result(r.get())
case .timedOut:
if let resultReceivedAfterTimeout {
// Late-result dispatch: await body and deliver via callback.
Task { try? await resultReceivedAfterTimeout(bodyTask.value) }
} else {
bodyTask.cancel()
}
return .timedOut
}
}
// The for-await exits without a return only if the consuming task is cancelled.
guard Task.isCancelled else { preconditionFailure("Continuation never finishes") }

bodyTask.cancel()
throw CancellationError()
} taskPriorityChanged: {
// Spawning fresh tasks that await `bodyTask` and `timeoutTask` forces the runtime to
// escalate their priorities via the await chain so `body`'s `Task.currentPriority`
// reflects the elevated value.
let newPriority = Task.currentPriority
Task(priority: newPriority) { _ = await bodyTask.result }
Task(priority: newPriority) { _ = await timeoutTask.value }
}
}

/// Executes `body`. If it doesn't finish after `duration`, throws a `TimeoutError` and cancels `body`.
///
/// `TimeoutError` is thrown immediately; the function does not wait for `body` to honor the cancellation.
///
/// If a `handle` is passed in and this `withTimeout` call times out, the thrown `TimeoutError` contains this handle.
/// This way a caller can identify whether this call to `withTimeout` timed out or if a nested call timed out.
@_spi(SourceKitLSP) @inlinable
public func withTimeout<T: Sendable>(
_ duration: Duration,
handle: TimeoutHandle? = nil,
_ body: @escaping @Sendable () async throws -> T
) async throws -> T {
switch try await withTimeoutResult(duration, body: body) {
case .result(let value): return value
case .timedOut: throw TimeoutError(handle: handle)
}
}

/// Executes `body`. If it doesn't finish after `duration`, return `nil` and continue running body. When `body` returns
/// a value after the timeout, `resultReceivedAfterTimeout` is called.
/// a value or throws an error after the timeout, `resultReceivedAfterTimeout` is called with the outcome.
///
/// - Important: `body` will not be cancelled when the timeout is received. Use the other overload of `withTimeout` if
/// `body` should be cancelled after `timeout`.
package func withTimeout<T: Sendable>(
@_spi(SourceKitLSP) @inlinable
public func withTimeout<T: Sendable>(
_ timeout: Duration,
body: @escaping @Sendable () async throws -> T,
resultReceivedAfterTimeout: @escaping @Sendable (_ result: T) async -> Void
) async throws -> T? {
let didHitTimeout = ThreadSafeBox<Bool>(initialValue: false)

let stream = AsyncThrowingStream<T?, Error> { continuation in
Task {
try await Task.sleep(for: timeout)
didHitTimeout.withLock { $0 = true }
continuation.yield(nil)
}

Task {
do {
let result = try await body()
if didHitTimeout.value {
await resultReceivedAfterTimeout(result)
}
continuation.yield(result)
Copy link
Copy Markdown
Member Author

@rintaro rintaro May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there was actually a correctness issue here.
After didHitTimeout.value check passes but before continuation.yield(result), the timeout task might have set didHitTimeout and continuation.yield(nil). So the result is timeout but resultReceivedAfterTimeout never get called.

} catch {
continuation.yield(with: .failure(error))
}
}
}

for try await value in stream {
return value
}
// The only reason for the loop above to terminate is if the Task got cancelled or if the continuation finishes
// (which it never does).
if Task.isCancelled {
throw CancellationError()
} else {
preconditionFailure("Continuation never finishes")
) async rethrows -> T? {
switch try await withTimeoutResult(timeout, body: body, resultReceivedAfterTimeout: resultReceivedAfterTimeout) {
case .result(let value): return value
case .timedOut: return nil
}
}

/// Same as `withTimeout` above but allows `body` to return an optional value.
package func withTimeout<T: Sendable>(
@_spi(SourceKitLSP) @inlinable
public func withTimeout<T: Sendable>(
_ timeout: Duration,
body: @escaping @Sendable () async throws -> T?,
resultReceivedAfterTimeout: @escaping @Sendable (_ result: T?) async -> Void
) async throws -> T? {
let result: T?? = try await withTimeout(timeout, body: body, resultReceivedAfterTimeout: resultReceivedAfterTimeout)
switch result {
case .none: return nil
case .some(.none): return nil
case .some(.some(let value)): return value
) async rethrows -> T? {
switch try await withTimeoutResult(timeout, body: body, resultReceivedAfterTimeout: resultReceivedAfterTimeout) {
case .result(let value): return value
case .timedOut: return nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,50 @@
//
//===----------------------------------------------------------------------===//

import Synchronization

/// Runs `operation`. If the task's priority changes while the operation is running, calls `taskPriorityChanged`.
///
/// Since Swift Concurrency doesn't support direct observation of a task's priority, this polls the task's priority at
/// `pollingInterval`.
/// The function assumes that the original priority of the task is `initialPriority`. If the task priority changed
/// compared to `initialPriority`, the `taskPriorityChanged` will be called.
// Workaround formatter issue: https://github.com/swiftlang/swift-format/issues/1081
// swift-format-ignore
@_spi(SourceKitLSP) public func withTaskPriorityChangedHandler<T: Sendable>(
/// On platforms with the runtime-provided priority escalation hook (SwiftStdlib 6.2+), this delegates
/// to `withTaskPriorityEscalationHandler` and reacts immediately. Otherwise it polls
/// `Task.currentPriority` every `pollingInterval` and assumes the original priority is `initialPriority`;
/// the `taskPriorityChanged` callback fires when the polled priority differs.
@available(macOS, deprecated: 26.0, message: "Use withTaskPriorityEscalationHandler")
@available(iOS, deprecated: 26.0, message: "Use withTaskPriorityEscalationHandler")
@available(macCatalyst, deprecated: 26.0, message: "Use withTaskPriorityEscalationHandler")
@_spi(SourceKitLSP) @inlinable public func withTaskPriorityChangedHandler<T: Sendable>(
initialPriority: TaskPriority = Task.currentPriority,
pollingInterval: Duration = .seconds(0.1),
@_inheritActorContext operation: nonisolated(nonsending) @escaping @Sendable () async throws -> T,
taskPriorityChanged: @escaping @Sendable () -> Void
) async throws -> T {
let lastPriority = ThreadSafeBox(initialValue: initialPriority)
let result: T? = try await withThrowingTaskGroup(of: Optional<T>.self) { taskGroup in
) async rethrows -> T {
if #available(macOS 26, iOS 26, macCatalyst 26, *) {
return try await withTaskPriorityEscalationHandler(
operation: operation,
onPriorityEscalated: { _, _ in taskPriorityChanged() }
)
} else {
return try await withTaskPriorityChangedHandlerLegacy(
initialPriority: initialPriority,
pollingInterval: pollingInterval,
operation: operation,
taskPriorityChanged: taskPriorityChanged
)
}
}

/// Polling-based fallback for ``withTaskPriorityChangedHandler`` on platforms without
/// `withTaskPriorityEscalationHandler`. Exposed under `@_spi(Testing)` so tests can
/// exercise this path even on platforms where the inlinable wrapper would dispatch to
/// the stdlib hook.
@_spi(Testing) public func withTaskPriorityChangedHandlerLegacy<T: Sendable>(
initialPriority: TaskPriority,
pollingInterval: Duration,
@_inheritActorContext operation: nonisolated(nonsending) @escaping @Sendable () async throws -> T,
taskPriorityChanged: @escaping @Sendable () -> Void
) async rethrows -> T {
let lastPriority = RefBox(Atomic<TaskPriority.RawValue>(initialPriority.rawValue))
return try await withThrowingTaskGroup(of: Optional<T>.self) { taskGroup in
defer {
// We leave this closure when either we have received a result or we registered cancellation. In either case, we
// want to make sure that we don't leave the body task or the priority watching task running.
Expand All @@ -36,15 +64,8 @@
if Task.isCancelled {
break
}
let newPriority = Task.currentPriority
let didChange = lastPriority.withLock { lastPriority in
if newPriority != lastPriority {
lastPriority = newPriority
return true
}
return false
}
if didChange {
let newPriority = Task.currentPriority.rawValue
if newPriority != lastPriority.value.exchange(newPriority, ordering: .relaxed) {
taskPriorityChanged()
}
do {
Expand All @@ -58,16 +79,12 @@
taskGroup.addTask {
try await operation()
}
// The first task that watches the priority never finishes unless it is cancelled, so we are effectively await the
// `operation` task here.
// We do need to await the observation task as well so that priority escalation also affects the observation task.
// The watcher loops forever until cancelled, so iterating the group effectively awaits
// `operation`. The watcher is structured into the same task group so it inherits the
// parent's priority and is automatically escalated alongside `operation`.
for try await case let value? in taskGroup {
return value
}
return nil
}
guard let result else {
throw CancellationError()
preconditionFailure("Task group exits only via operation's value or throw")
}
return result
}
56 changes: 55 additions & 1 deletion Tests/ToolsProtocolsSwiftExtensionsTests/AsyncUtilsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//

@_spi(SourceKitLSP) import SKLogging
@_spi(SourceKitLSP) import ToolsProtocolsSwiftExtensions
@_spi(SourceKitLSP) @_spi(Testing) import ToolsProtocolsSwiftExtensions
import ToolsProtocolsTestSupport
import XCTest

Expand Down Expand Up @@ -73,4 +73,58 @@ final class AsyncUtilsTests: XCTestCase {
try await task.value
}.value
}

func testWithTaskPriorityChangedHandlerLegacyReturnsOptionalNilFromOperation() async throws {
// When the operation's `T` is itself an `Optional`, verify `nil` return
// value is propagated as the operation's result.
let result: String? = try await withTaskPriorityChangedHandlerLegacy(
initialPriority: Task.currentPriority,
pollingInterval: .seconds(0.1),
operation: {
let value: String? = nil
return value
},
taskPriorityChanged: {}
)
XCTAssertNil(result)
}

func testWithTaskPriorityChangedHandlerLegacyDetectsPriorityEscalation() async throws {
let started = self.expectation(description: "Operation started")
let callbackFired = ThreadSafeBox(initialValue: false)
let task = Task(priority: .background) {
try await withTaskPriorityChangedHandlerLegacy(
initialPriority: .background,
pollingInterval: .seconds(0.05),
operation: {
started.fulfill()
try await repeatUntilExpectedResult(sleepInterval: .seconds(0.1)) {
return callbackFired.value
}
},
taskPriorityChanged: {
callbackFired.withLock { $0 = true }
}
)
}
try await fulfillmentOfOrThrow(started)
try await Task(priority: .high) {
try await task.value
}.value
XCTAssertTrue(callbackFired.value)
}

func testWithTaskPriorityChangedHandlerLegacyRethrowsError() async throws {
struct TestError: Error {}
await assertThrowsError(
try await withTaskPriorityChangedHandlerLegacy(
initialPriority: Task.currentPriority,
pollingInterval: .seconds(0.1),
operation: { throw TestError() },
taskPriorityChanged: {}
)
) { error in
XCTAssert(error is TestError, "Received unexpected error \(error)")
}
}
}
Loading