diff --git a/src/TaskExtensionsIfFulfilled.cs b/src/TaskExtensionsIfFulfilled.cs index 850551b..fd9e81c 100644 --- a/src/TaskExtensionsIfFulfilled.cs +++ b/src/TaskExtensionsIfFulfilled.cs @@ -1,11 +1,8 @@ using System; -using System.Threading; using System.Threading.Tasks; namespace RLC.TaskChaining; -using static TaskStatics; - public static partial class TaskExtensions { /// @@ -50,5 +47,19 @@ public static Task IfFulfilled(this Task task, Func> func /// The function to execute if the task is fulfilled. /// The task. public static Task IfFulfilled(this Task task, Func func) - => task.IfFulfilled(value => Task.FromResult(value).Then(func).Then(_ => value, _ => value)); + => task.ContinueWith(continuationTask => + { + if (continuationTask.IsFaulted || continuationTask.IsCanceled) + { + return continuationTask; + } + else + { + return continuationTask.Then(async value => + { + await func(value); + return value; + }); + } + }).Unwrap(); } diff --git a/tests/unit/TaskChainingIfFulfilledTests.cs b/tests/unit/TaskChainingIfFulfilledTests.cs index f620c94..ddd5f69 100644 --- a/tests/unit/TaskChainingIfFulfilledTests.cs +++ b/tests/unit/TaskChainingIfFulfilledTests.cs @@ -202,5 +202,38 @@ public async void ItShouldContinueAsyncTasks() Assert.Equal(expectedValue, actualValue); } + + [Fact] + public async Task ItShouldAwaitAsyncSideEffects() + { + int coin = 5; + int testValue = 12345; + int expectedSideEffectValue = 12345; + string expectedValue = "12345"; + int? sideEffectIntValue = null; + string? sideEffectStringValue = null; + Func func = value => + { + return Task.FromResult(value) + .Delay(TimeSpan.FromSeconds(1)) + .Then(int.Parse) + .Then(v => + { + sideEffectIntValue = v; + return v; + }); + }; + + string actualValue = await Task.FromResult(testValue) + .Then(val => val.ToString()) + .IfFulfilled(value => coin < 6 + ? func(value) + : Task.FromResult(value) + ); + + Assert.Equal(expectedValue, actualValue); + Assert.Equal(expectedSideEffectValue, sideEffectIntValue); + Assert.Null(sideEffectStringValue); + } } }