Skip to content

Instantly share code, notes, and snippets.

@tobiasge
Created March 14, 2023 09:35
Show Gist options
  • Save tobiasge/dd827d8a766894e0a40e7594841c6479 to your computer and use it in GitHub Desktop.
Save tobiasge/dd827d8a766894e0a40e7594841c6479 to your computer and use it in GitHub Desktop.
package q.keycloak.models;
import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.jpa.entities.GroupEntity;
import javax.persistence.EntityManager;
import javax.persistence.TypedQuery;
import java.util.Comparator;
import java.util.Objects;
import java.util.stream.Stream;
import static org.keycloak.models.jpa.PaginationUtils.paginateQuery;
import static org.keycloak.utils.StreamsUtil.closing;
/**
 * Thin JPA-backed accessor that returns raw {@link GroupEntity} instances
 * instead of model adapters.
 *
 * <p>All stream-returning methods first query for group ids and then resolve each id
 * through {@link #getGroupEntityById(RealmModel, String)}; ids that can no longer be
 * resolved (or that resolve to a group of another realm) are dropped from the stream.
 */
public class QJpaRealmProvider {
    protected static final Logger logger = Logger.getLogger(QJpaRealmProvider.class);

    private final KeycloakSession session;
    protected EntityManager em;

    // Not referenced inside this class; kept because it is visible to the rest of the package.
    Comparator<GroupEntity> COMPARE_BY_NAME = Comparator.comparing(GroupEntity::getName);

    /**
     * @param session session used to look up the current realm in
     *                {@link #getTopLevelGroupsEntityStream()}
     * @param em      entity manager used for all lookups and named queries
     */
    public QJpaRealmProvider(KeycloakSession session, EntityManager em) {
        this.session = session;
        this.em = em;
    }

    /**
     * Loads a single group entity by primary key.
     *
     * @param realm realm the group is expected to belong to
     * @param id    primary key of the group
     * @return the entity, or {@code null} when no entity with that id exists or the
     *         entity's realm differs from {@code realm.getId()}
     */
    public GroupEntity getGroupEntityById(RealmModel realm, String id) {
        GroupEntity groupEntity = em.find(GroupEntity.class, id);
        if (groupEntity == null) {
            return null;
        }
        if (!groupEntity.getRealm().equals(realm.getId())) {
            return null;
        }
        return groupEntity;
    }

    /**
     * Streams the groups returned by the {@code getGroupIdsByRealm} named query.
     *
     * @param realm realm whose id is bound to the query's {@code realm} parameter
     * @return stream of resolved entities; never contains {@code null} elements
     */
    public Stream<GroupEntity> getGroupsEntityStream(RealmModel realm) {
        return closing(em.createNamedQuery("getGroupIdsByRealm", String.class)
                .setParameter("realm", realm.getId())
                .getResultStream()
                .map(id -> getGroupEntityById(realm, id))
                // getGroupEntityById returns null for ids it cannot resolve; skip those,
                // exactly like the other stream methods of this class do.
                .filter(Objects::nonNull));
    }

    /**
     * Streams the groups returned by the {@code getGroupIdsByParent} named query.
     *
     * @param realm  realm every returned group must belong to
     * @param parent group whose id is bound to the query's {@code parent} parameter
     * @return stream of resolved entities; never contains {@code null} elements
     */
    public Stream<GroupEntity> getGroupsEntityStream(RealmModel realm, GroupEntity parent) {
        TypedQuery<String> query = em.createNamedQuery("getGroupIdsByParent", String.class);
        query.setParameter("parent", parent.getId());
        return closing(query.getResultStream()
                .map(id -> getGroupEntityById(realm, id))
                .filter(Objects::nonNull));
    }

    /**
     * Streams all top-level groups of the realm found in the session context,
     * without pagination.
     *
     * @return stream of resolved entities; never contains {@code null} elements
     */
    public Stream<GroupEntity> getTopLevelGroupsEntityStream() {
        return getTopLevelGroupsEntityStream(this.session.getContext().getRealm(), null, null);
    }

    /**
     * Streams the groups returned by the {@code getTopLevelGroupIds} named query.
     *
     * @param realm realm whose id is bound to the query's {@code realm} parameter
     * @param first passed through to {@code paginateQuery} as the first-result argument
     * @param max   passed through to {@code paginateQuery} as the max-results argument
     * @return stream of resolved entities; never contains {@code null} elements
     */
    public Stream<GroupEntity> getTopLevelGroupsEntityStream(RealmModel realm, Integer first, Integer max) {
        TypedQuery<String> groupsQuery = em.createNamedQuery("getTopLevelGroupIds", String.class)
                .setParameter("realm", realm.getId())
                .setParameter("parent", GroupEntity.TOP_PARENT_ID);
        return closing(paginateQuery(groupsQuery, first, max).getResultStream()
                .map(id -> getGroupEntityById(realm, id))
                // In concurrent tests, the group might be deleted in another thread, therefore, skip those null values.
                .filter(Objects::nonNull));
    }
}
package q.keycloak.models;
import org.keycloak.models.GroupModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.jpa.GroupAdapter;
import org.keycloak.models.jpa.entities.GroupEntity;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.Objects;
/**
 * Group-path lookup helpers that walk {@link GroupEntity} instances obtained from
 * {@link QJpaRealmProvider} and only build a {@link GroupModel} for the final match.
 */
public class QKeycloakModelUtils {

    /**
     * Merges consecutive path segments into one when together they spell a group name
     * that itself contains the path separator.
     *
     * <p>Example: segments {@code [a, b, c]}, index {@code 0} and group name {@code a/b}
     * yield {@code [a/b, c]}.
     *
     * @param segments  path split on the separator
     * @param index     position in {@code segments} at which the group name must start
     * @param groupName name of the candidate group
     * @return a new, shorter array with the merged name at {@code index}, or
     *         {@code segments} itself when nothing can be merged
     */
    private static String[] formatPathSegments(String[] segments, int index, String groupName) {
        String[] nameSegments = groupName.split(KeycloakModelUtils.GROUP_PATH_SEPARATOR);
        // Only the segments from `index` onwards are candidates for merging, so the name
        // must fit into that remainder. Comparing against the full array length would let
        // segments[index + i] run past the end of the array.
        if (nameSegments.length <= 1 || segments.length - index < nameSegments.length) {
            return segments;
        }
        for (int i = 0; i < nameSegments.length; i++) {
            if (!nameSegments[i].equals(segments[index + i])) {
                return segments;
            }
        }
        int numMergedIndexes = nameSegments.length - 1;
        String[] newPath = new String[segments.length - numMergedIndexes];
        // Unchanged prefix, merged name, then the tail shifted left by the merged count.
        System.arraycopy(segments, 0, newPath, 0, index);
        newPath[index] = groupName;
        System.arraycopy(segments, index + nameSegments.length, newPath, index + 1, newPath.length - index - 1);
        return newPath;
    }

    /**
     * Checks whether {@code group} matches the path at {@code index} and, if more
     * segments remain, continues the search among the group's children.
     *
     * @return the entity addressed by the complete path, or {@code null} when this
     *         group (or its subtree) does not match
     */
    private static GroupEntity matchGroup(KeycloakSession session, QJpaRealmProvider qJpaRealmProvider, String[] segments, int index, GroupEntity group) {
        String groupName = group.getName();
        String[] pathSegments = formatPathSegments(segments, index, groupName);
        if (!groupName.equals(pathSegments[index])) {
            return null;
        }
        if (pathSegments.length == index + 1) {
            // Last segment consumed: this is the group the path points at.
            return group;
        }
        return findSubGroup(session, qJpaRealmProvider, pathSegments, index + 1, group);
    }

    /**
     * Searches the direct children of {@code parent} for the group addressed by
     * {@code segments} starting at {@code index}.
     *
     * @return the first matching entity, or {@code null} when none matches
     */
    private static GroupEntity findSubGroup(KeycloakSession session, QJpaRealmProvider qJpaRealmProvider, String[] segments, int index, GroupEntity parent) {
        return qJpaRealmProvider.getGroupsEntityStream(session.getContext().getRealm(), parent)
                .map(group -> matchGroup(session, qJpaRealmProvider, segments, index, group))
                .filter(Objects::nonNull)
                .findFirst()
                .orElse(null);
    }

    /**
     * Resolves a group by its path within the realm of the session context.
     *
     * <p>One leading and one trailing separator are ignored. Group names that contain
     * the separator themselves are supported.
     *
     * @param session current session; supplies the realm and the provider
     * @param path    group path such as {@code /parent/child}; may be {@code null}
     * @return the matching group wrapped in a {@link GroupAdapter}, or {@code null} when
     *         {@code path} is {@code null}, has no segments, or no group matches
     */
    public static GroupModel findGroupByPath(KeycloakSession session, String path) {
        if (path == null) {
            return null;
        }
        if (path.startsWith(KeycloakModelUtils.GROUP_PATH_SEPARATOR)) {
            path = path.substring(1);
        }
        if (path.endsWith(KeycloakModelUtils.GROUP_PATH_SEPARATOR)) {
            path = path.substring(0, path.length() - 1);
        }
        String[] split = path.split(KeycloakModelUtils.GROUP_PATH_SEPARATOR);
        if (split.length == 0) {
            return null;
        }

        var qJpaRealmProvider = new QJpaRealmProviderFactory().create(session);
        GroupEntity res = qJpaRealmProvider.getTopLevelGroupsEntityStream()
                .map(group -> matchGroup(session, qJpaRealmProvider, split, 0, group))
                .filter(Objects::nonNull)
                .findFirst()
                .orElse(null);
        if (res == null) {
            return null;
        }
        return new GroupAdapter(session.getContext().getRealm(), qJpaRealmProvider.em, res);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment