Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public async Task<List<TenantResource>> GetTenants(CancellationToken cancellatio
options.Environment = CloudConfiguration.ArmEnvironment;
var client = new ArmClient(await GetCredential(cancellationToken), default, options);

await foreach (var tenant in client.GetTenants())
await foreach (var tenant in client.GetTenants().WithCancellation(cancellationToken))
{
results.Add(tenant);
}
Expand Down
18 changes: 18 additions & 0 deletions servers/Azure.Mcp.Server/docs/new-command.md
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,28 @@ public interface IMyService
**Service Implementation Requirements:**
- Pass the `CancellationToken` parameter to all async method calls
- Use `cancellationToken: cancellationToken` when calling Azure SDK methods
- Use `.WithCancellation(cancellationToken)` when iterating over async enumerables with `await foreach`
- Always include `CancellationToken cancellationToken` as the final parameter (only use a default value if and only if other parameters have default values)
- Force callers to explicitly provide a CancellationToken
- Never pass `CancellationToken.None` or `default` as a value to a `CancellationToken` method parameter

**Example - Async Enumerable Pattern:**
```csharp
// ✅ Correct: Use .WithCancellation() for async enumerables
var subscription = _armClient.GetSubscriptionResource(new($"/subscriptions/{_subscriptionId}"));
await foreach (var resourceGroup in subscription.GetResourceGroups().WithCancellation(cancellationToken))
{
return resourceGroup.Data.Name;
}

// ❌ Wrong: Missing .WithCancellation()
var subscription = _armClient.GetSubscriptionResource(new($"/subscriptions/{_subscriptionId}"));
await foreach (var resourceGroup in subscription.GetResourceGroups())
{
return resourceGroup.Data.Name;
}
```

**Unit Testing Requirements:**
- **Mock setup**: Use `Arg.Any<CancellationToken>()` for CancellationToken parameters in mock setups
- **Product code invocation**: Use `TestContext.Current.CancellationToken` when invoking product code from unit tests
Expand Down
9 changes: 5 additions & 4 deletions tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ private async Task<CosmosDBAccountResource> GetCosmosAccountAsync(
string subscription,
string accountName,
string? tenant = null,
RetryPolicyOptions? retryPolicy = null)
RetryPolicyOptions? retryPolicy = null,
CancellationToken cancellationToken = default)
{
ValidateRequiredParameters((nameof(subscription), subscription), (nameof(accountName), accountName));

var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);

await foreach (var account in subscriptionResource.GetCosmosDBAccountsAsync())
await foreach (var account in subscriptionResource.GetCosmosDBAccountsAsync(cancellationToken))
{
if (account.Data.Name == accountName)
{
Expand Down Expand Up @@ -74,7 +75,7 @@ private async Task<CosmosClient> CreateCosmosClientWithAuth(
switch (authMethod)
{
case AuthMethod.Key:
var cosmosAccount = await GetCosmosAccountAsync(subscription, accountName, tenant);
var cosmosAccount = await GetCosmosAccountAsync(subscription, accountName, tenant, cancellationToken: cancellationToken);
var keys = await cosmosAccount.GetKeysAsync(cancellationToken);
cosmosClient = new CosmosClient(
string.Format(CosmosBaseUri, accountName),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ public async Task<List<Namespace>> GetNamespacesAsync(
throw new InvalidOperationException($"Resource group '{resourceGroup}' not found");
}

await foreach (var namespaceResource in resourceGroupResource.Value.GetEventHubsNamespaces())
await foreach (var namespaceResource in resourceGroupResource.Value.GetEventHubsNamespaces().WithCancellation(cancellationToken))
{
namespaces.Add(ConvertToNamespace(namespaceResource.Data, resourceGroup));
}
}
else
{
// Get namespaces from all resource groups in subscription
await foreach (var rg in subscriptionResource.GetResourceGroups())
await foreach (var rg in subscriptionResource.GetResourceGroups().WithCancellation(cancellationToken))
{
await foreach (var namespaceResource in rg.GetEventHubsNamespaces())
await foreach (var namespaceResource in rg.GetEventHubsNamespaces().WithCancellation(cancellationToken))
{
namespaces.Add(ConvertToNamespace(namespaceResource.Data, rg.Data.Name));
}
Expand Down Expand Up @@ -313,7 +313,7 @@ public async Task<List<EventHub>> GetEventHubsAsync(

var eventHubList = new List<EventHub>();

await foreach (var eventHub in namespaceResource.Value.GetEventHubs())
await foreach (var eventHub in namespaceResource.Value.GetEventHubs().WithCancellation(cancellationToken))
{
eventHubList.Add(ConvertToEventHub(eventHub.Data, resourceGroup));
}
Expand Down Expand Up @@ -692,7 +692,7 @@ public async Task<List<ConsumerGroup>> GetConsumerGroupsAsync(

var consumerGroups = new List<ConsumerGroup>();

await foreach (var consumerGroup in eventHubResource.Value.GetEventHubsConsumerGroups())
await foreach (var consumerGroup in eventHubResource.Value.GetEventHubsConsumerGroups().WithCancellation(cancellationToken))
{
consumerGroups.Add(ConvertToConsumerGroup(consumerGroup.Data, resourceGroup, namespaceName, eventHubName));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Task<List<FileShareInfo>> ListFileSharesAsync(
}

var collection = resourceGroupResource.GetFileShares();
await foreach (var fileShareResource in collection)
await foreach (var fileShareResource in collection.WithCancellation(cancellationToken))
{
fileShares.Add(FileShareInfo.FromResource(fileShareResource));
}
Expand Down Expand Up @@ -488,7 +488,7 @@ public async Task<FileShareSnapshotInfo> GetSnapshotAsync(
var fileShareResource = await resourceGroupResource.Value.GetFileShares().GetAsync(fileShareName, cancellationToken);
var snapshotCollection = fileShareResource.Value.GetFileShareSnapshots();

await foreach (var snapshotResource in snapshotCollection)
await foreach (var snapshotResource in snapshotCollection.WithCancellation(cancellationToken))
{
if (snapshotResource.Data.Name.Equals(snapshotId, StringComparison.OrdinalIgnoreCase) ||
snapshotResource.Data.Id.ToString().Contains(snapshotId, StringComparison.OrdinalIgnoreCase))
Expand Down Expand Up @@ -539,7 +539,7 @@ public async Task<List<FileShareSnapshotInfo>> ListSnapshotsAsync(
var snapshotCollection = fileShareResource.Value.GetFileShareSnapshots();

var snapshots = new List<FileShareSnapshotInfo>();
await foreach (var snapshotResource in snapshotCollection)
await foreach (var snapshotResource in snapshotCollection.WithCancellation(cancellationToken))
{
snapshots.Add(FileShareSnapshotInfo.FromResource(snapshotResource));
}
Expand Down
15 changes: 8 additions & 7 deletions tools/Azure.Mcp.Tools.Foundry/src/Services/FoundryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ public async Task<ThreadListResult> ListThreads(
List<ThreadItem> threads = [];
try
{
await foreach (var thread in threadsIterator)
await foreach (var thread in threadsIterator.WithCancellation(cancellationToken))
{
threads.Add(new()
{
Expand Down Expand Up @@ -1172,7 +1172,7 @@ public async Task<ThreadGetMessagesResult> GetMessages(
{
List<PersistentThreadMessage> messages = [];
var messagesIterator = agentsClient.Messages.GetMessagesAsync(threadId, cancellationToken: cancellationToken);
await foreach (var message in messagesIterator)
await foreach (var message in messagesIterator.WithCancellation(cancellationToken))
{
messages.Add(message);
}
Expand Down Expand Up @@ -1499,7 +1499,7 @@ public async Task<List<AiResourceInformation>> ListAiResourcesAsync(
// List all AI resources in the subscription
await foreach (var account in subscriptionResource.GetCognitiveServicesAccountsAsync(cancellationToken: cancellationToken))
{
var resourceInfo = await BuildResourceInformation(account, subscriptionResource.Data.DisplayName);
var resourceInfo = await BuildResourceInformation(account, subscriptionResource.Data.DisplayName, cancellationToken);
resources.Add(resourceInfo);
}
}
Expand All @@ -1510,7 +1510,7 @@ public async Task<List<AiResourceInformation>> ListAiResourcesAsync(
{
if (account.Data.Id.ResourceGroupName?.Equals(resourceGroup, StringComparison.OrdinalIgnoreCase) == true)
{
var resourceInfo = await BuildResourceInformation(account, subscriptionResource.Data.DisplayName);
var resourceInfo = await BuildResourceInformation(account, subscriptionResource.Data.DisplayName, cancellationToken);
resources.Add(resourceInfo);
}
}
Expand Down Expand Up @@ -1554,7 +1554,7 @@ public async Task<AiResourceInformation> GetAiResourceAsync(
throw new Exception($"AI resource '{resourceName}' not found in resource group '{resourceGroup}'");
}

return await BuildResourceInformation(account.Value, subscriptionResource.Data.DisplayName);
return await BuildResourceInformation(account.Value, subscriptionResource.Data.DisplayName, cancellationToken);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -1596,7 +1596,8 @@ public AgentsGetSdkCodeSampleResult GetSdkCodeSample(string programmingLanguage)

private async Task<AiResourceInformation> BuildResourceInformation(
CognitiveServicesAccountResource account,
string subscriptionName)
string subscriptionName,
CancellationToken cancellationToken)
{
var resourceInfo = new AiResourceInformation
{
Expand All @@ -1613,7 +1614,7 @@ private async Task<AiResourceInformation> BuildResourceInformation(
// Get deployments for this resource
try
{
await foreach (var deployment in account.GetCognitiveServicesAccountDeployments())
await foreach (var deployment in account.GetCognitiveServicesAccountDeployments().WithCancellation(cancellationToken))
{
var deploymentInfo = new DeploymentInformation
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public sealed class FunctionAppService(
{
if (string.IsNullOrEmpty(resourceGroup))
{
await RetrieveAndAddFunctionApp(subscriptionResource.GetWebSitesAsync(cancellationToken), functionApps);
await RetrieveAndAddFunctionApp(subscriptionResource.GetWebSitesAsync(cancellationToken), functionApps, cancellationToken);
}
else
{
Expand All @@ -58,7 +58,7 @@ public sealed class FunctionAppService(
throw new Exception($"Resource group '{resourceGroup}' not found in subscription '{subscription}'");
}

await RetrieveAndAddFunctionApp(resourceGroupResource.Value.GetWebSites().GetAllAsync(cancellationToken: cancellationToken), functionApps);
await RetrieveAndAddFunctionApp(resourceGroupResource.Value.GetWebSites().GetAllAsync(cancellationToken: cancellationToken), functionApps, cancellationToken);
}

await _cacheService.SetAsync(CacheGroup, cacheKey, functionApps, s_cacheDuration, cancellationToken);
Expand Down Expand Up @@ -105,9 +105,9 @@ public sealed class FunctionAppService(
return functionApps;
}

private static async Task RetrieveAndAddFunctionApp(AsyncPageable<WebSiteResource> sites, List<FunctionAppInfo> functionApps)
private static async Task RetrieveAndAddFunctionApp(AsyncPageable<WebSiteResource> sites, List<FunctionAppInfo> functionApps, CancellationToken cancellationToken)
{
await foreach (var site in sites)
await foreach (var site in sites.WithCancellation(cancellationToken))
{
TryAddFunctionApp(site, functionApps);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public async Task<List<TestRun>> GetLoadTestRunsFromTestIdAsync(
}

var testRuns = new List<TestRun>();
await foreach (var binaryData in loadTestRunResponse)
await foreach (var binaryData in loadTestRunResponse.WithCancellation(cancellationToken))
{
var testRun = JsonSerializer.Deserialize(binaryData.ToString(), LoadTestJsonContext.Default.TestRun);
if (testRun != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public override async Task<List<string>> GetAvailableRegionsAsync(string resourc
{
var quotas = subscription.GetModelsAsync(region);

await foreach (CognitiveServicesModel modelElement in quotas)
await foreach (CognitiveServicesModel modelElement in quotas.WithCancellation(cancellationToken))
{
var nameMatch = string.IsNullOrEmpty(_modelName) ||
(modelElement.Model?.Name == _modelName);
Expand Down Expand Up @@ -153,7 +153,7 @@ public override async Task<List<string>> GetAvailableRegionsAsync(string resourc
try
{
AsyncPageable<PostgreSqlFlexibleServerCapabilityProperties> result = subscription.ExecuteLocationBasedCapabilitiesAsync(region);
await foreach (var capability in result)
await foreach (var capability in result.WithCancellation(cancellationToken))
{
if (capability.SupportedServerEditions?.Any() == true)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesAsync(location, cancellationToken: cancellationToken);
var result = new List<UsageInfo>();

await foreach (ServiceAccountUsage item in usages)
await foreach (ServiceAccountUsage item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.LocalizedValue ?? item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (ComputeUsage item in usages)
await foreach (ComputeUsage item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.LocalizedValue ?? item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (var item in usages)
await foreach (var item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesWithLocationAsync(location, cancellationToken);

var result = new List<UsageInfo>();
await foreach (ContainerInstanceUsage item in usages)
await foreach (ContainerInstanceUsage item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.LocalizedValue ?? item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetHDInsightUsagesAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (HDInsightUsage item in usages)
await foreach (HDInsightUsage item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.LocalizedValue ?? item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetMachineLearningUsagesAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (var item in usages)
await foreach (var item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (var item in usages)
await foreach (var item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesBySubscriptionAsync(location, cancellationToken: cancellationToken);
var result = new List<UsageInfo>();

await foreach (QuotaUsageResult item in usages)
await foreach (QuotaUsageResult item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.Value ?? string.Empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public override async Task<List<UsageInfo>> GetUsageForLocationAsync(string loca
var usages = subscription.GetUsagesByLocationAsync(location, cancellationToken);
var result = new List<UsageInfo>();

await foreach (var item in usages)
await foreach (var item in usages.WithCancellation(cancellationToken))
{
result.Add(new UsageInfo(
Name: item.Name?.Value ?? string.Empty,
Expand Down
4 changes: 2 additions & 2 deletions tools/Azure.Mcp.Tools.Redis/src/Services/RedisService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ private async Task<IEnumerable<Resource>> ListAcrResourcesAsync(SubscriptionReso
try
{
var accessPolicyAssignmentCollection = acrResource.GetRedisCacheAccessPolicyAssignments();
await foreach (var accessPolicyAssignmentResource in accessPolicyAssignmentCollection)
await foreach (var accessPolicyAssignmentResource in accessPolicyAssignmentCollection.WithCancellation(cancellationToken))
{
if (string.IsNullOrWhiteSpace(accessPolicyAssignmentResource?.Id.ToString())
|| string.IsNullOrWhiteSpace(accessPolicyAssignmentResource.Data.Name))
Expand Down Expand Up @@ -278,7 +278,7 @@ private async Task<IEnumerable<Resource>> ListAmrResourcesAsync(SubscriptionReso
try
{
var databaseCollection = amrResource.GetRedisEnterpriseDatabases();
await foreach (var databaseResource in databaseCollection)
await foreach (var databaseResource in databaseCollection.WithCancellation(cancellationToken))
{
if (string.IsNullOrWhiteSpace(databaseResource?.Id.ToString())
|| string.IsNullOrWhiteSpace(databaseResource.Data.Name))
Expand Down
Loading
Loading