Skip to content

Instantly share code, notes, and snippets.

@svenrog
Created January 31, 2022 09:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save svenrog/6fa3dd8f18e91e28ad4900cf2c5c5e16 to your computer and use it in GitHub Desktop.
Save svenrog/6fa3dd8f18e91e28ad4900cf2c5c5e16 to your computer and use it in GitHub Desktop.
Recreated version of Optimizely internal SecurityEntityProvider
using EPiServer.Cms.UI.AspNetIdentity;
using EPiServer.Notification;
using EPiServer.Security;
using EPiServer.ServiceLocation;
using EPiServer.Shell.Security;
using Microsoft.AspNetCore.Identity;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System;
using System.Security.Claims;
using Microsoft.EntityFrameworkCore;
namespace YourProject.Infrastructure.Security
{
    /// <summary>
    /// Security entity provider backed by ASP.NET Core Identity.
    /// Roles are resolved through <see cref="ApplicationRoleProvider{TUser}"/> and <see cref="RoleManager{TRole}"/>.
    /// Users are resolved through <see cref="ApplicationUserManager{TUser}"/>.
    /// The same user store also answers notification user lookups (<see cref="IQueryableNotificationUsers"/>).
    /// </summary>
    /// <typeparam name="TUser">The Identity user type used by the site.</typeparam>
    public class AspNetIdentitySecurityEntityProvider<TUser> : SecurityEntityProvider, IQueryableNotificationUsers
        where TUser : IdentityUser, IUIUser, new()
    {
        private readonly ServiceAccessor<ApplicationRoleProvider<TUser>> _roleProvider;
        private readonly ServiceAccessor<ApplicationUserManager<TUser>> _userManager;
        private readonly ServiceAccessor<RoleManager<IdentityRole>> _roleManager;

        /// <summary>
        /// Creates the provider.
        /// The services are held as accessors and resolved (<c>Invoke()</c>) on every call, not once here.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any of the accessors is <c>null</c>.</exception>
        public AspNetIdentitySecurityEntityProvider(
            ServiceAccessor<ApplicationRoleProvider<TUser>> roleProvider,
            ServiceAccessor<ApplicationUserManager<TUser>> userManager,
            ServiceAccessor<RoleManager<IdentityRole>> roleManager)
        {
            // Fail at construction instead of with a NullReferenceException on first use.
            _roleProvider = roleProvider ?? throw new ArgumentNullException(nameof(roleProvider));
            _userManager = userManager ?? throw new ArgumentNullException(nameof(userManager));
            _roleManager = roleManager ?? throw new ArgumentNullException(nameof(roleManager));
        }

        /// <summary>Returns the role names the role provider reports for <paramref name="userName"/>.</summary>
        public override async Task<IEnumerable<string>> GetRolesForUserAsync(string userName)
        {
            return await _roleProvider.Invoke().GetRolesForUserAsync(userName).ToListAsync();
        }

        /// <summary>
        /// Returns one page of the user names in <paramref name="roleName"/>.
        /// The names can optionally be narrowed to those containing <paramref name="userName"/>.
        /// </summary>
        /// <param name="roleName">Role whose members are listed.</param>
        /// <param name="userName">
        /// Substring filter, matched ordinally and case-insensitively; <c>null</c> or empty matches every member.
        /// </param>
        /// <param name="startIndex">Number of matching names to skip.</param>
        /// <param name="maxRows">Maximum number of names to return.</param>
        /// <returns>The requested page and the total number of matching names across all pages.</returns>
        public override async Task<(IEnumerable<string> users, int totalCount)> FindUsersInRoleAsync(string roleName, string userName, int startIndex, int maxRows)
        {
            var usersInRole = await _roleProvider.Invoke()
                .GetUsersInRoleAsync(roleName)
                .ToListAsync();

            // Filter before counting, so totalCount describes the set the page is cut from
            // rather than every member of the role.
            // The explicit ordinal comparison avoids the culture-dependent matching of the
            // parameterless IndexOf.
            // Ignoring case mirrors Identity's own case-insensitive (normalized) user names.
            IEnumerable<string> filtered = usersInRole;
            if (!string.IsNullOrEmpty(userName))
            {
                filtered = usersInRole.Where(u => u != null && u.Contains(userName, StringComparison.OrdinalIgnoreCase));
            }

            var matchingUsers = filtered.ToList();

            // Skip/Take already return an empty or short page past the end, so no manual bounds checks are needed.
            var page = matchingUsers.Skip(startIndex).Take(maxRows).ToList();
            return (page, matchingUsers.Count);
        }

        /// <summary>
        /// Returns every entity of <paramref name="claimType"/> whose value contains <paramref name="partOfValue"/>.
        /// </summary>
        public override async Task<IEnumerable<SecurityEntity>> SearchAsync(string partOfValue, string claimType)
        {
            var (entities, _) = await SearchAsync(partOfValue, claimType, 0, int.MaxValue);
            return entities;
        }

        /// <summary>
        /// Searches roles (for <see cref="ClaimTypes.Role"/>) or users (for any other claim type)
        /// and returns one page of matches.
        /// </summary>
        /// <param name="partOfValue">Substring to look for; <c>null</c> or empty matches everything.</param>
        /// <param name="claimType">
        /// The claim type selects which field is matched:
        /// <see cref="ClaimTypes.Role"/> matches role names;
        /// <see cref="ClaimTypes.Email"/> matches user e-mail addresses;
        /// <see cref="ClaimTypes.Name"/> and <see cref="ClaimTypes.NameIdentifier"/> match user names.
        /// Matching on the database side follows the store's collation.
        /// </param>
        /// <param name="startIndex">Number of matches to skip.</param>
        /// <param name="maxRows">Maximum number of matches to return.</param>
        /// <returns>The page of matches and the total number of matches across all pages.</returns>
        public override async Task<(IEnumerable<SecurityEntity> entities, int totalCount)> SearchAsync(string partOfValue, string claimType, int startIndex, int maxRows)
        {
            var filterByValue = !string.IsNullOrEmpty(partOfValue);

            if (claimType == ClaimTypes.Role)
            {
                var roles = _roleManager.Invoke().Roles;
                if (filterByValue)
                {
                    roles = roles.Where(r => r.Name != null && r.Name.Contains(partOfValue));
                }

                // Order before paging so consecutive pages neither overlap nor skip rows.
                var (roleEntities, roleCount) = await PageAsync(
                    roles.OrderBy(r => r.Name), startIndex, maxRows, r => new SecurityEntity(r.Name));
                return (roleEntities, roleCount);
            }

            var users = _userManager.Invoke().Users;
            if (users == null)
            {
                return (new List<SecurityEntity>(0), 0);
            }

            if (filterByValue)
            {
                if (claimType == ClaimTypes.Email)
                {
                    users = users.Where(u => u.Email != null && u.Email.Contains(partOfValue));
                }
                else if (claimType == ClaimTypes.Name || claimType == ClaimTypes.NameIdentifier)
                {
                    users = users.Where(u => u.UserName != null && u.UserName.Contains(partOfValue));
                }
                // NOTE(review): any other claim type falls through unfiltered and therefore matches
                // every user. This is kept from the original; confirm it is intended.
            }

            var (userEntities, userCount) = await PageAsync(
                users.OrderBy(u => u.UserName), startIndex, maxRows, u => new SecurityEntity(u.UserName));
            return (userEntities, userCount);
        }

        /// <summary>
        /// Finds notification users whose user name contains <paramref name="partOfUser"/>.
        /// </summary>
        /// <param name="partOfUser">Substring to look for; <c>null</c> or empty matches every user.</param>
        /// <param name="pageIndex">Zero-based page number.</param>
        /// <param name="pageSize">Number of users per page.</param>
        /// <returns>The requested page and the total number of matching users across all pages.</returns>
        public virtual async Task<PagedNotificationUserResult> FindAsync(string partOfUser, int pageIndex, int pageSize)
        {
            var users = _userManager.Invoke().Users;
            if (users == null)
            {
                return new PagedNotificationUserResult(Enumerable.Empty<NotificationUser>(), 0);
            }

            // Treat null like "" (match everyone) instead of sending a null search term into the query.
            if (!string.IsNullOrEmpty(partOfUser))
            {
                users = users.Where(x => x.UserName != null && x.UserName.Contains(partOfUser));
            }

            var (matches, totalCount) = await PageAsync(
                users.OrderBy(x => x.UserName), pageIndex * pageSize, pageSize, x => new NotificationUser(x.UserName));
            return new PagedNotificationUserResult(matches, totalCount);
        }

        /// <summary>
        /// Materializes one page of <paramref name="query"/> through <paramref name="selector"/>.
        /// It also determines how many rows the whole query matches.
        /// </summary>
        /// <returns>The projected page and the total number of rows matched by <paramref name="query"/>.</returns>
        private static async Task<(List<U> items, int totalCount)> PageAsync<T, U>(IQueryable<T> query, int startIndex, int maxRows, Func<T, U> selector)
        {
            var items = await Map(query, startIndex, maxRows, selector);

            // A page that starts at the first row and is not full already holds every match, so the
            // extra COUNT round-trip can be skipped. This always holds for the unpaged SearchAsync overload.
            var totalCount = startIndex <= 0 && items.Count < maxRows
                ? items.Count
                : await query.CountAsync();

            return (items, totalCount);
        }

        /// <summary>
        /// Applies <paramref name="startIndex"/>/<paramref name="maxRows"/> paging to <paramref name="collection"/>.
        /// It then projects each remaining row with <paramref name="selector"/>.
        /// </summary>
        private static Task<List<U>> Map<T, U>(IQueryable<T> collection, int startIndex, int maxRows, Func<T, U> selector)
        {
            var filteredCollection = collection;

            // Only add Skip/Take when they actually restrict the result, keeping the query minimal.
            if (startIndex > 0)
            {
                filteredCollection = filteredCollection.Skip(startIndex);
            }
            if (maxRows < int.MaxValue)
            {
                filteredCollection = filteredCollection.Take(maxRows);
            }

            // selector is an opaque delegate. The query provider presumably evaluates it client-side,
            // as the final projection, after fetching the rows.
            return filteredCollection.Select(x => selector(x))
                .ToListAsync();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment