diff --git a/DNS/Client/RequestResolver/ParallelRequestResolver.cs b/DNS/Client/RequestResolver/ParallelRequestResolver.cs
new file mode 100644
index 0000000..c6aa418
--- /dev/null
+++ b/DNS/Client/RequestResolver/ParallelRequestResolver.cs
@@ -0,0 +1,70 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using DNS.Protocol;
+using DNS.Protocol.Utils;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace DNS.Client.RequestResolver
+{
+ ///
+ /// Resolve requests using multiple IRequestResolvers, taking the first result.
+ ///
+ public class ParallelRequestResolver : IRequestResolver
+ {
+ private List resolvers;
+ ///
+ /// Create a new instance of ParallelRequestResolver
+ ///
+ ///
+ /// Thrown when innerResolvers does not contain at least 1 resolver.
+ public ParallelRequestResolver(IEnumerable innerResolvers)
+ {
+ resolvers = innerResolvers.ToList();
+ if (resolvers.Count == 0) throw new ArgumentException("No inner DNS resolvers were provided!", nameof(innerResolvers));
+ }
+
+ public async Task Resolve(IRequest request, CancellationToken cancellationToken = default)
+ {
+ CancellationTokenSource requestCompletedCancellationSource = new CancellationTokenSource();
+ var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(requestCompletedCancellationSource.Token, cancellationToken);
+ List exceptions = new List();
+ var tasks = resolvers.Select(i => i.Resolve(request, linkedSource.Token)).ToList();
+ bool done = false;
+ IResponse response = null;
+ while (response == null)
+ {
+ if (tasks.Count == 0)
+ break;
+ var completedTask = await Task.WhenAny(tasks).ConfigureAwait(false);
+ try
+ {
+ // We could check the task manually, but this way will handle edge cases.
+ response = await completedTask.ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ exceptions.Add(ex);
+ }
+ tasks.Remove(completedTask);
+ }
+
+ if (tasks.Any())
+ {
+ tasks = tasks.Select(i => i.SwallowExceptions(null)).ToList();
+ }
+
+ requestCompletedCancellationSource.Cancel();
+ if (response == null)
+ {
+ throw new AggregateException(exceptions);
+ }
+
+ // Should response be wrapped with something that exposes exceptions?
+ // IE: public class ResponseWithExceptions : IResponse
+ // new ResponseWithExceptions(response, exceptions)
+ return response;
+ }
+ }
+}
diff --git a/DNS/Protocol/Utils/TaskExtensions.cs b/DNS/Protocol/Utils/TaskExtensions.cs
index ade9066..22e69e4 100644
--- a/DNS/Protocol/Utils/TaskExtensions.cs
+++ b/DNS/Protocol/Utils/TaskExtensions.cs
@@ -2,26 +2,44 @@
using System.Threading;
using System.Threading.Tasks;
-namespace DNS.Protocol.Utils {
- public static class TaskExtensions {
- public static async Task WithCancellation(this Task task, CancellationToken token) {
+namespace DNS.Protocol.Utils
+{
+ public static class TaskExtensions
+ {
+ public static async Task WithCancellation(this Task task, CancellationToken token)
+ {
TaskCompletionSource tcs = new TaskCompletionSource();
- CancellationTokenRegistration registration = token.Register(src => {
- ((TaskCompletionSource) src).TrySetResult(true);
+ CancellationTokenRegistration registration = token.Register(src =>
+ {
+ ((TaskCompletionSource)src).TrySetResult(true);
}, tcs);
- using (registration) {
- if (await Task.WhenAny(task, tcs.Task) != task) {
+ using (registration)
+ {
+ if (await Task.WhenAny(task, tcs.Task) != task)
+ {
throw new OperationCanceledException(token);
}
}
return await task;
}
-
- public static async Task WithCancellationTimeout(this Task task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken)) {
+ public static async Task SwallowExceptions(this Task task, T defaultValue = default)
+ {
+ try
+ {
+ return await task;
+ }
+ catch
+ {
+ return defaultValue;
+ }
+ }
+ public static async Task WithCancellationTimeout(this Task task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken))
+ {
using (CancellationTokenSource timeoutSource = new CancellationTokenSource(timeout))
- using (CancellationTokenSource linkSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, cancellationToken)) {
+ using (CancellationTokenSource linkSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, cancellationToken))
+ {
return await task.WithCancellation(linkSource.Token);
}
}