diff --git a/DNS/Client/RequestResolver/ParallelRequestResolver.cs b/DNS/Client/RequestResolver/ParallelRequestResolver.cs new file mode 100644 index 0000000..c6aa418 --- /dev/null +++ b/DNS/Client/RequestResolver/ParallelRequestResolver.cs @@ -0,0 +1,70 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using DNS.Protocol; +using DNS.Protocol.Utils; +using System.Collections.Generic; +using System.Linq; + +namespace DNS.Client.RequestResolver +{ + /// + /// Resolve requests using multiple IRequestResolvers, taking the first result. + /// + public class ParallelRequestResolver : IRequestResolver + { + private List resolvers; + /// + /// Create a new instance of ParallelRequestResolver + /// + /// + /// Thrown when innerResolvers does not contain at least 1 resolver. + public ParallelRequestResolver(IEnumerable innerResolvers) + { + resolvers = innerResolvers.ToList(); + if (resolvers.Count == 0) throw new ArgumentException("No inner DNS resolvers were provided!", nameof(innerResolvers)); + } + + public async Task Resolve(IRequest request, CancellationToken cancellationToken = default) + { + CancellationTokenSource requestCompletedCancellationSource = new CancellationTokenSource(); + var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(requestCompletedCancellationSource.Token, cancellationToken); + List exceptions = new List(); + var tasks = resolvers.Select(i => i.Resolve(request, linkedSource.Token)).ToList(); + bool done = false; + IResponse response = null; + while (response == null) + { + if (tasks.Count == 0) + break; + var completedTask = await Task.WhenAny(tasks).ConfigureAwait(false); + try + { + // We could check the task manually, but this way will handle edge cases. + response = await completedTask.ConfigureAwait(false); + } + catch (Exception ex) + { + exceptions.Add(ex); + } + tasks.Remove(completedTask); + } + + if (tasks.Any()) + { + tasks = tasks.Select(i => i.SwallowExceptions(null)).ToList(); + } + + requestCompletedCancellationSource.Cancel(); + if (response == null) + { + throw new AggregateException(exceptions); + } + + // Should response be wrapped with something that exposes exceptions? + // IE: public class ResponseWithExceptions : IResponse + // new ResponseWithExceptions(response, exceptions) + return response; + } + } +} diff --git a/DNS/Protocol/Utils/TaskExtensions.cs b/DNS/Protocol/Utils/TaskExtensions.cs index ade9066..22e69e4 100644 --- a/DNS/Protocol/Utils/TaskExtensions.cs +++ b/DNS/Protocol/Utils/TaskExtensions.cs @@ -2,26 +2,44 @@ using System.Threading; using System.Threading.Tasks; -namespace DNS.Protocol.Utils { - public static class TaskExtensions { - public static async Task WithCancellation(this Task task, CancellationToken token) { +namespace DNS.Protocol.Utils +{ + public static class TaskExtensions + { + public static async Task WithCancellation(this Task task, CancellationToken token) + { TaskCompletionSource tcs = new TaskCompletionSource(); - CancellationTokenRegistration registration = token.Register(src => { - ((TaskCompletionSource) src).TrySetResult(true); + CancellationTokenRegistration registration = token.Register(src => + { + ((TaskCompletionSource)src).TrySetResult(true); }, tcs); - using (registration) { - if (await Task.WhenAny(task, tcs.Task) != task) { + using (registration) + { + if (await Task.WhenAny(task, tcs.Task) != task) + { throw new OperationCanceledException(token); } } return await task; } - - public static async Task WithCancellationTimeout(this Task task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken)) { + public static async Task SwallowExceptions(this Task task, T defaultValue = default) + { + try + { + return await task; + } + catch + { + return defaultValue; + } + } + public static async Task WithCancellationTimeout(this Task task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken)) + { using (CancellationTokenSource timeoutSource = new CancellationTokenSource(timeout)) - using (CancellationTokenSource linkSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, cancellationToken)) { + using (CancellationTokenSource linkSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, cancellationToken)) + { return await task.WithCancellation(linkSource.Token); } }