Skip to content

Instantly share code, notes, and snippets.

@ChristopherHaws
Last active August 6, 2023 15:31
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ChristopherHaws/b1c54b95838f1513bfb74fa1c8e408f3 to your computer and use it in GitHub Desktop.
Save ChristopherHaws/b1c54b95838f1513bfb74fa1c8e408f3 to your computer and use it in GitHub Desktop.
EFCore Azure AccessToken
using System.Collections.Generic;
namespace System.Threading.Tasks
{
public static class AsyncUtilities
{
    /// <summary>
    /// Executes an asynchronous operation that produces no value synchronously on the calling thread.
    /// </summary>
    /// <param name="task">Factory that starts the asynchronous operation.</param>
    /// <exception cref="AggregateException">
    /// Thrown when the operation faults; the original exception is available as the inner exception.
    /// </exception>
    /// <remarks>
    /// Continuations that capture the current <see cref="SynchronizationContext"/> are pumped on the
    /// calling thread by a private message loop until the operation completes. The caller's original
    /// synchronization context is restored whether the operation succeeds or fails.
    /// </remarks>
    public static void RunSync(Func<Task> task)
    {
        RunOnExclusiveContext(task);
    }

    /// <summary>
    /// Executes an asynchronous operation that produces a value of type <typeparamref name="T"/>
    /// synchronously on the calling thread.
    /// </summary>
    /// <typeparam name="T">Return type of the operation.</typeparam>
    /// <param name="task">Factory that starts the asynchronous operation.</param>
    /// <returns>The value produced by the operation.</returns>
    /// <exception cref="AggregateException">
    /// Thrown when the operation faults; the original exception is available as the inner exception.
    /// </exception>
    public static T RunSync<T>(Func<Task<T>> task)
    {
        T ret = default;
        RunOnExclusiveContext(async () => { ret = await task(); });
        return ret;
    }

    // Installs a fresh ExclusiveSynchronizationContext, runs `body` on it, and pumps the message
    // loop until the body completes. Always restores the previous context and disposes the loop.
    private static void RunOnExclusiveContext(Func<Task> body)
    {
        var oldContext = SynchronizationContext.Current;
        using (var sync = new ExclusiveSynchronizationContext())
        {
            SynchronizationContext.SetSynchronizationContext(sync);
            try
            {
                sync.Post(async _ =>
                {
                    try
                    {
                        await body();
                    }
                    catch (Exception e)
                    {
                        // Recorded so BeginMessageLoop can surface it to the caller.
                        sync.InnerException = e;
                        throw;
                    }
                    finally
                    {
                        sync.EndMessageLoop();
                    }
                }, null);
                sync.BeginMessageLoop();
            }
            finally
            {
                // Restore the caller's context even when the loop throws the AggregateException.
                SynchronizationContext.SetSynchronizationContext(oldContext);
            }
        }
    }

    private sealed class ExclusiveSynchronizationContext : SynchronizationContext, IDisposable
    {
        private readonly AutoResetEvent workItemsWaiting = new AutoResetEvent(false);
        private readonly Queue<Tuple<SendOrPostCallback, Object>> items = new Queue<Tuple<SendOrPostCallback, Object>>();
        private bool done;
        private bool disposed;

        /// <summary>Exception raised by the operation being run, if any.</summary>
        public Exception InnerException { get; set; }

        public void Dispose()
        {
            // Taken under the queue lock so a concurrent Post cannot Set() a disposed event.
            lock (this.items)
            {
                if (this.disposed)
                {
                    return;
                }

                this.disposed = true;
                this.workItemsWaiting.Dispose();
            }
        }

        public override void Send(SendOrPostCallback d, Object state)
        {
            throw new NotSupportedException("We cannot send to our same thread");
        }

        public override void Post(SendOrPostCallback d, Object state)
        {
            lock (this.items)
            {
                // Once the loop has been torn down nothing will ever drain the queue, so late posts
                // (e.g. stray continuations) are dropped instead of throwing ObjectDisposedException.
                if (this.disposed)
                {
                    return;
                }

                this.items.Enqueue(Tuple.Create(d, state));
                this.workItemsWaiting.Set();
            }
        }

        /// <summary>Queues a work item that stops the message loop once it is processed.</summary>
        public void EndMessageLoop()
        {
            this.Post(_ => this.done = true, null);
        }

        /// <summary>
        /// Processes queued work items on the current thread until EndMessageLoop's item runs.
        /// Throws an <see cref="AggregateException"/> as soon as a work item records an InnerException.
        /// </summary>
        public void BeginMessageLoop()
        {
            while (!this.done)
            {
                Tuple<SendOrPostCallback, object> task = null;
                lock (this.items)
                {
                    if (this.items.Count > 0)
                    {
                        task = this.items.Dequeue();
                    }
                }

                if (task != null)
                {
                    task.Item1(task.Item2);
                    if (this.InnerException != null) // the method threw an exception
                    {
                        throw new AggregateException("AsyncHelpers.Run method threw an exception.", this.InnerException);
                    }
                }
                else
                {
                    this.workItemsWaiting.WaitOne();
                }
            }
        }

        // Copies share the same loop so captured continuations are posted back to it.
        public override SynchronizationContext CreateCopy()
        {
            return this;
        }
    }
}
}
using System.Data.Common;
using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.Azure.Services.AppAuthentication;
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal;
using Microsoft.EntityFrameworkCore.Storage;
namespace Microsoft.EntityFrameworkCore
{
public static class AzureSqlServerConnectionExtensions
{
    /// <summary>
    /// Replaces EF Core's <see cref="ISqlServerConnection"/> service with
    /// <see cref="AzureSqlServerConnection"/>. That connection attaches an access token,
    /// obtained from AzureServiceTokenProvider, to every SqlConnection it creates.
    /// </summary>
    /// <param name="options">The options builder being configured.</param>
    public static void UseAzureAccessToken(this DbContextOptionsBuilder options) =>
        options.ReplaceService<ISqlServerConnection, AzureSqlServerConnection>();
}
public class AzureSqlServerConnection : SqlServerConnection
{
    // Compensate for slow SQL Server database creation
    private const int DefaultMasterConnectionCommandTimeout = 60;

    // Resource identifier requested from AzureServiceTokenProvider for Azure SQL tokens.
    private const string AzureSqlResource = "https://database.windows.net/";

    // One provider shared by every connection instance.
    private static readonly AzureServiceTokenProvider TokenProvider = new AzureServiceTokenProvider();

    public AzureSqlServerConnection(RelationalConnectionDependencies dependencies)
        : base(dependencies)
    {
    }

    /// <summary>
    /// Creates a SqlConnection for the configured connection string and attaches an access token
    /// for <c>https://database.windows.net/</c>, fetched synchronously via AsyncUtilities.RunSync.
    /// </summary>
    protected override DbConnection CreateDbConnection()
    {
        var connection = new SqlConnection(this.ConnectionString);

        // AzureServiceTokenProvider handles caching the token and refreshing it before it expires
        connection.AccessToken = AsyncUtilities.RunSync(() => TokenProvider.GetAccessTokenAsync(AzureSqlResource));

        return connection;
    }

    /// <summary>
    /// Builds a connection to the <c>master</c> database on the same server, again as an
    /// AzureSqlServerConnection so that it also uses an access token.
    /// </summary>
    public override ISqlServerConnection CreateMasterConnection()
    {
        var masterBuilder = new SqlConnectionStringBuilder(this.ConnectionString);
        masterBuilder.InitialCatalog = "master";
        masterBuilder.Remove("AttachDBFilename");

        var masterOptions = new DbContextOptionsBuilder()
            .UseSqlServer(
                masterBuilder.ConnectionString,
                sqlOptions => sqlOptions.CommandTimeout(this.CommandTimeout ?? DefaultMasterConnectionCommandTimeout))
            .Options;

        var masterDependencies = this.Dependencies.With(masterOptions);
        return new AzureSqlServerConnection(masterDependencies);
    }
}
}
public class Startup
{
    private readonly IConfiguration configuration;
    private readonly IHostingEnvironment env;

    /// <summary>Captures the configuration and hosting environment supplied by the host.</summary>
    public Startup(IConfiguration configuration, IHostingEnvironment env)
    {
        this.configuration = configuration;
        this.env = env;
    }

    /// <summary>
    /// Called by the runtime to register services. It registers a pooled ApplicationContext that
    /// uses the "DefaultConnection" connection string. Outside development, it also switches to the
    /// access-token-based SQL Server connection.
    /// </summary>
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddDbContextPool<ApplicationContext>(options =>
        {
            options.UseSqlServer(this.configuration.GetConnectionString("DefaultConnection"));

            bool isDevelopment = this.env.IsDevelopment();
            if (isDevelopment)
            {
                return;
            }

            options.UseAzureAccessToken();
        });

        // Removed unrelated code...
    }

    /// <summary>Called by the runtime to configure the HTTP request pipeline.</summary>
    public void Configure(IApplicationBuilder app)
    {
        // Removed unrelated code...
    }
}
@sebader
Copy link

sebader commented Jul 30, 2019

Thanks a lot for this!
One thing, though: you are using AsyncUtilities here without saying where you got it from. I found an implementation online and it works. Was that the one you are using?

@ChristopherHaws
Copy link
Author

@sebader I can't remember where I got it. I think I got it from a Microsoft repo at some point, but I'm not sure. I updated the gist with the version I am using. Glad it works for you!

@manne
Copy link

manne commented Dec 10, 2019

For the sake of completeness 😊: the NuGet package Microsoft.Azure.Services.AppAuthentication is needed.

@OskarKlintrot
Copy link

OskarKlintrot commented Jan 22, 2020

Why is it needed to override CreateMasterConnection()?

@ChristopherHaws
Copy link
Author

@OskarKlintrot Because it returns AzureSqlServerConnection instead of SqlServerConnection.

@OskarKlintrot
Copy link

I missed that one, thanks for the clarification!

@ChristopherHaws
Copy link
Author

FYI, to anyone interested, I moved TokenProvider to be a static readonly field so that the caching of tokens works properly.

@OskarKlintrot
Copy link

Did it not work properly before? The token is already cached in a static field.

@ChristopherHaws
Copy link
Author

@OskarKlintrot I was not aware of that, thanks for the info. I suppose all that my update does then is remove a small allocation. ;)

@OskarKlintrot
Copy link

If I remember correctly, that saves ~28 µs on my machine; I used BenchmarkDotNet to measure how long it took to create a new instance :) I ended up using an extension (IDbContextOptionsExtension) so that I could use EF's DI instead and mock it for unit-testing purposes. It's probably a lot slower, though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment