Skip to content

Instantly share code, notes, and snippets.

@liemle3893
Created February 7, 2018 06:44
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save liemle3893/b32e2a568f6522e1faa6f79e76a4443e to your computer and use it in GitHub Desktop.
Save liemle3893/b32e2a568f6522e1faa6f79e76a4443e to your computer and use it in GitHub Desktop.
A production-ready solution to the User-Role-Privilege problem.

Privilege.java

    // Privilege.java
    import lombok.*;

    import javax.persistence.*;
    import java.io.Serializable;
    import java.util.Set;
    
    
    /**
     * A single named permission that can be granted to one or more {@link Role}s.
     *
     * <p>This is the inverse (non-owning) side of the role/privilege many-to-many association: the
     * {@code cms_roles_privileges} join table is owned by {@code Role.privileges}.
     */
    @Setter
    @Getter
    @EqualsAndHashCode(exclude = {"roles"})
    @ToString(exclude = {"roles"})
    @NoArgsConstructor
    @Builder
    @AllArgsConstructor
    //
    @Entity
    @Table(name = "cms_privilege")
    public class Privilege implements Serializable {
        static final long serialVersionUID = -12393123123L;

        @Id
        @GeneratedValue(strategy = GenerationType.AUTO)
        @Column(name = "id")
        private Long id;

        private String name;

        /**
         * Roles that grant this privilege (mapped by {@code Role.privileges}).
         *
         * <p>{@code CascadeType.REMOVE} is deliberately NOT cascaded here. On a many-to-many it would
         * delete every Role that grants this privilege, even though those roles are shared with
         * other privileges and referenced by users. To delete a privilege, first remove it from each
         * owning {@code Role.privileges} set, because this side does not own the join table.
         */
        @ManyToMany(mappedBy = "privileges", cascade = {CascadeType.MERGE, CascadeType.PERSIST, CascadeType.DETACH})
        private Set<Role> roles;
    }

Role.java

    // Role.java
    import lombok.*;
    import org.hibernate.annotations.Fetch;
    import org.hibernate.annotations.FetchMode;
    
    import javax.persistence.*;
    import javax.validation.constraints.NotNull;
    import java.io.Serializable;
    import java.util.Set;
    
    /**
     * A named role that can be assigned to {@link User}s.
     *
     * <p>Each role bundles a set of {@link Privilege}s and carries a {@code home_page} value
     * (presumably the landing page for members of the role; confirm with the UI code).
     */
    @Entity
    @Table(name = "cms_role")
    @Getter
    @Setter
    @Builder
    @NoArgsConstructor
    @AllArgsConstructor
    @ToString(exclude = {"privileges", "users"})
    @EqualsAndHashCode(exclude = {"privileges", "users"})
    public class Role implements Serializable {
        static final long serialVersionUID = 1283712837L;

        /** Generated surrogate key. */
        @Id
        @Column(name = "id")
        @GeneratedValue(strategy = GenerationType.AUTO)
        private Long id;

        /** Role name; this is what the user role checks compare against. */
        @NotNull
        @Column(name = "name")
        private String name;

        /** Value of the {@code home_page} column. */
        @Column(name = "home_page")
        private String homePage;

        /** Users holding this role; inverse side, the join table is owned by {@code User.roles}. */
        @ManyToMany(mappedBy = "roles")
        private Set<User> users;

        /**
         * Privileges granted by this role. This side owns {@code cms_roles_privileges}, and the
         * collection is loaded eagerly with a join fetch.
         */
        @Fetch(FetchMode.JOIN)
        @ManyToMany(fetch = FetchType.EAGER)
        @JoinTable(
                name = "cms_roles_privileges",
                joinColumns = {@JoinColumn(name = "role_id", referencedColumnName = "id")},
                inverseJoinColumns = {@JoinColumn(name = "privilege_id", referencedColumnName = "id")})
        private Set<Privilege> privileges;
    }

User.java

    import com.google.common.collect.Sets;
    import lombok.*;
    
    import javax.persistence.*;
    import javax.validation.constraints.NotNull;
    import java.io.Serializable;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Set;
    import java.util.stream.Collectors;
    
    /**
     * A CMS user account: credentials, profile fields and the {@link Role}s granted to it.
     *
     * <p>The {@code has*Roles} helpers match against role names. The {@code has*Privileges} helpers
     * match against the names of privileges reachable through any held role. All of them are
     * null-safe: a user or role whose association set was never populated (for example one created
     * through the Lombok builder without setting it) is treated as holding nothing.
     *
     * <p>{@code password} is excluded from the generated {@code toString()} so it cannot leak into
     * log output when a user, or an object that embeds one, is logged.
     */
    @Setter
    @Getter
    @EqualsAndHashCode(exclude = {"roles"})
    @ToString(exclude = {"roles", "password"})
    @NoArgsConstructor
    @Builder
    @AllArgsConstructor
    //
    @Entity
    @Table(name = "cms_user", uniqueConstraints = {
            @UniqueConstraint(columnNames = "username", name = "unq_cms_user_username")
    })
    public class User implements Serializable {
        static final long serialVersionUID = -234789L;
        @Id
        @Column(name = "id")
        @GeneratedValue(strategy = GenerationType.AUTO)
        private Long id;

        @Column(name = "username", unique = true, nullable = false)
        @NotNull
        private String username;

        @Column(name = "password", nullable = false)
        @NotNull
        private String password;

        @Column(name = "status", nullable = false)
        private Integer status;
        @Column(name = "first_name")
        private String firstName;
        @Column(name = "last_name")
        private String lastName;

        @Column(name = "description")
        private String description;

        @Column(name = "user_type")
        private Integer userType;

        /** Roles held by this user; this side owns the {@code cms_users_roles} join table. */
        @ManyToMany(fetch = FetchType.EAGER, cascade = {CascadeType.MERGE, CascadeType.PERSIST})
        @JoinTable(name = "cms_users_roles",
                joinColumns = @JoinColumn(name = "user_id", referencedColumnName = "id"),
                inverseJoinColumns = @JoinColumn(name = "role_id", referencedColumnName = "id")
        )
        private Set<Role> roles;

        @Column(name = "phone_number")
        private String phone;

        /**
         * Returns the first and last name joined by a single space, omitting any part that is
         * {@code null} or blank, so a missing part is never rendered as the text "null".
         *
         * @return the display name, or an empty string if neither name is set
         */
        @Transient
        public String displayName() {
            return Arrays.asList(this.getFirstName(), this.getLastName())
                    .stream()
                    .filter(part -> part != null && !part.trim().isEmpty())
                    .collect(Collectors.joining(" "));
        }


        /**
         * @param roles role names to look for
         * @return {@code true} if the user holds at least one of {@code roles}
         */
        @Transient
        public boolean hasAnyRoles(String... roles) {
            return hasAnyRoles(Arrays.asList(roles));
        }

        /**
         * @param roles role names to look for
         * @return {@code true} if the user holds at least one of {@code roles};
         *         {@code false} for an empty list
         */
        @Transient
        public boolean hasAnyRoles(List<String> roles) {
            Set<String> held = roleNames();
            return roles.stream().anyMatch(held::contains);
        }

        /**
         * @param roles role names that are all required
         * @return {@code true} if the user holds every one of {@code roles}
         */
        @Transient
        public boolean hasRoles(String... roles) {
            return hasRoles(Arrays.asList(roles));
        }

        /**
         * @param roles role names that are all required
         * @return {@code true} if the user holds every one of {@code roles};
         *         {@code true} for an empty list
         */
        @Transient
        public boolean hasRoles(List<String> roles) {
            return roleNames().containsAll(roles);
        }

        /**
         * @param privileges privilege names to look for
         * @return {@code true} if any held role grants at least one of {@code privileges}
         */
        @Transient
        public boolean hasAnyPrivileges(String... privileges) {
            return hasAnyPrivileges(Arrays.asList(privileges));
        }

        /**
         * @param privileges privilege names to look for
         * @return {@code true} if any held role grants at least one of {@code privileges};
         *         {@code false} for an empty list
         */
        @Transient
        public boolean hasAnyPrivileges(List<String> privileges) {
            Set<String> granted = privilegeNames();
            return privileges.stream().anyMatch(granted::contains);
        }

        /**
         * @param privileges privilege names that are all required
         * @return {@code true} if every one of {@code privileges} is granted by some held role
         */
        @Transient
        public boolean hasPrivileges(String... privileges) {
            return hasPrivileges(Arrays.asList(privileges));
        }

        /**
         * @param privileges privilege names that are all required
         * @return {@code true} if every one of {@code privileges} is granted by some held role;
         *         {@code true} for an empty list
         */
        @Transient
        public boolean hasPrivileges(List<String> privileges) {
            return privilegeNames().containsAll(privileges);
        }

        /** Names of all roles held by this user; empty when {@code roles} is unset. */
        private Set<String> roleNames() {
            if (this.getRoles() == null) {
                return Sets.newHashSet();
            }
            return this.getRoles()
                    .stream()
                    .map(Role::getName)
                    .collect(Collectors.toSet());
        }

        /** Names of all privileges granted through any held role; roles without privileges are skipped. */
        private Set<String> privilegeNames() {
            if (this.getRoles() == null) {
                return Sets.newHashSet();
            }
            return this.getRoles()
                    .stream()
                    .filter(role -> role.getPrivileges() != null)
                    .flatMap(role -> role.getPrivileges().stream())
                    .map(Privilege::getName)
                    .collect(Collectors.toSet());
        }
    }

The entities above work in any setup, not only Spring. If you're using some form of AOP (Spring AOP, for example), you can enforce the checks declaratively like this:

Aspect.java

    /**
     * Around-advice backing the {@code @RequiredPrivilege}, {@code @RequiredRole} and
     * {@code @TrackingAction} method annotations.
     *
     * <p>The access checks delegate to {@code LoggedInUsers.check*}. Those calls presumably throw
     * when the check fails (not visible here; verify), since the advice proceeds unconditionally
     * once they return.
     */
    @org.aspectj.lang.annotation.Aspect
    @Component
    @Slf4j
    public class Aspect {


        @Autowired
        UserActionRepository userActionRepository;

        /**
         * Checks the current user's privileges before invoking a {@code @RequiredPrivilege} method.
         *
         * <p>The names come from {@code privileges()} when non-empty, otherwise from {@code value()}.
         * With no names the method runs unchecked. {@code anyMatch} selects
         * {@code checkAnyPrivileges} (true) or {@code checkPrivileges} (false). After a successful
         * check, the current user is exposed to any {@code Model} argument as {@code "_user"}.
         */
        @Around("@annotation(RequiredPrivilege)")
        public Object preCheckPrivilege(ProceedingJoinPoint joinPoint) throws Throwable {
            RequiredPrivilege requiredPrivilege = findAnnotation(joinPoint, RequiredPrivilege.class);
            String[] privileges = firstNonEmpty(requiredPrivilege.privileges(), requiredPrivilege.value());
            if (privileges.length == 0) {
                return joinPoint.proceed();
            }
            if (requiredPrivilege.anyMatch()) {
                LoggedInUsers.checkAnyPrivileges(privileges);
            } else {
                LoggedInUsers.checkPrivileges(privileges);
            }
            for (Object arg : joinPoint.getArgs()) {
                if (arg instanceof Model) {
                    ((Model) arg).addAttribute("_user", LoggedInUsers.currentUser());
                }
            }
            return joinPoint.proceed();
        }

        /**
         * Checks the current user's roles before invoking a {@code @RequiredRole} method.
         *
         * <p>The names come from {@code roles()} when non-empty, otherwise from {@code value()}.
         * With no names the method runs unchecked. {@code anyMatch} selects {@code checkAnyRoles}
         * (true) or {@code checkRoles} (false).
         */
        @Around("@annotation(RequiredRole)")
        public Object preCheckRole(ProceedingJoinPoint joinPoint) throws Throwable {
            RequiredRole requiredRoles = findAnnotation(joinPoint, RequiredRole.class);
            String[] roles = firstNonEmpty(requiredRoles.roles(), requiredRoles.value());
            if (roles.length == 0) {
                return joinPoint.proceed();
            }
            if (requiredRoles.anyMatch()) {
                LoggedInUsers.checkAnyRoles(roles);
            } else {
                LoggedInUsers.checkRoles(roles);
            }
            return joinPoint.proceed();
        }

        /**
         * Records a {@code UserAction} for a {@code @TrackingAction} method, then invokes it.
         *
         * <p>The description is either {@code descriptionFormat} expanded by {@code StrSubstitutor}
         * (variables: {@code username}, plus {@code 0}, {@code 1}, ... for the arguments by
         * position), or {@code ClassName::method([args])} when no format is given. The action name
         * defaults to {@code ClassName#method}. Persisting the record is best-effort: an
         * {@code Exception} from the repository is logged and the call still proceeds.
         */
        @Around("@annotation(TrackingAction)")
        public Object trackUserAction(ProceedingJoinPoint joinPoint) throws Throwable {
            Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
            TrackingAction trackingAction = findAnnotation(joinPoint, TrackingAction.class);
            String clazzName = method.getDeclaringClass().getSimpleName();
            String methodName = method.getName();
            Object[] args = joinPoint.getArgs();

            String desc;
            String descriptionFormat = trackingAction.descriptionFormat();
            if (!Strings.isNullOrEmpty(descriptionFormat)) {
                Map<String, Object> values = new HashMap<>();
                values.put("username", LoggedInUsers.currentUser().getUsername());
                for (int i = 0; i < args.length; i++) {
                    values.put(String.valueOf(i), args[i]);
                }
                desc = new StrSubstitutor(values).replace(descriptionFormat);
            } else {
                desc = clazzName + "::" + methodName + "(" + Arrays.toString(args) + ")";
            }

            String actionName = trackingAction.actionName();
            if (Strings.isNullOrEmpty(actionName)) {
                actionName = clazzName + "#" + methodName;
            }

            UserAction userAction = UserAction.builder()
                    .timestamp(new Timestamp(System.currentTimeMillis()))
                    .user(LoggedInUsers.currentUser())
                    .type(trackingAction.actionType().name())
                    .actionName(actionName)
                    .description(desc)
                    .build();
            try {
                userActionRepository.save(userAction);
            } catch (Exception ex) {
                // Auditing is best-effort: a failed insert must not block the business call.
                log.error("Cannot insert userAction: {}", userAction, ex);
            }
            // proceed() is intentionally outside the try: the previous `return` inside `finally`
            // silently discarded any Error/Throwable thrown while saving.
            return joinPoint.proceed();
        }

        /**
         * Returns the {@code type} annotation of the intercepted method.
         *
         * <p>{@code MethodSignature#getMethod()} returns the method as declared on the type the call
         * went through. Behind a JDK interface proxy that is the interface method, which does not
         * carry annotations placed on the implementation. In that case this falls back to the
         * same-signature public method of the target class.
         *
         * @throws IllegalStateException if neither method carries the annotation
         */
        private static <A extends java.lang.annotation.Annotation> A findAnnotation(
                ProceedingJoinPoint joinPoint, Class<A> type) throws NoSuchMethodException {
            Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
            A annotation = method.getAnnotation(type);
            if (annotation == null && joinPoint.getTarget() != null) {
                Method targetMethod = joinPoint.getTarget().getClass()
                        .getMethod(method.getName(), method.getParameterTypes());
                annotation = targetMethod.getAnnotation(type);
            }
            if (annotation == null) {
                throw new IllegalStateException(
                        "@" + type.getSimpleName() + " not found on " + method);
            }
            return annotation;
        }

        /** Returns {@code preferred} unless it is empty, in which case {@code fallback}. */
        private static String[] firstNonEmpty(String[] preferred, String[] fallback) {
            return preferred.length == 0 ? fallback : preferred;
        }


    }

RequiredPrivilege.java

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Restricts the annotated method to callers holding the listed privileges.
 *
 * <p>The privilege names may be given as {@link #value()} or, equivalently, as
 * {@link #privileges()}. The aspect uses {@code privileges()} when it is non-empty and otherwise
 * {@code value()}. If neither lists any names, no check is performed.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RequiredPrivilege {
    /**
     * Required privilege names.
     *
     * <p>Defaults to empty so that {@code @RequiredPrivilege(privileges = ...)} compiles on its own;
     * without a default, {@code value} would be mandatory and the alias unusable.
     */
    String[] value() default {};

    /** Alternative to {@link #value()}; takes precedence when non-empty. */
    String[] privileges() default {};

    /**
     * Selects which check the aspect runs: {@code LoggedInUsers.checkAnyPrivileges} when
     * {@code true} (the default), {@code LoggedInUsers.checkPrivileges} when {@code false}.
     */
    boolean anyMatch() default true;
}

RequiredRole.java

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Restricts the annotated method to callers holding the listed roles.
 *
 * <p>The role names may be given as {@link #value()} or, equivalently, as {@link #roles()}. The
 * aspect uses {@code roles()} when it is non-empty and otherwise {@code value()}. If neither lists
 * any names, no check is performed.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RequiredRole {
    /**
     * Required role names.
     *
     * <p>Defaults to empty so that {@code @RequiredRole(roles = ...)} compiles on its own; without
     * a default, {@code value} would be mandatory and the alias unusable.
     */
    String[] value() default {};

    /** Alternative to {@link #value()}; takes precedence when non-empty. */
    String[] roles() default {};

    /**
     * Selects which check the aspect runs: {@code LoggedInUsers.checkAnyRoles} when {@code true}
     * (the default), {@code LoggedInUsers.checkRoles} when {@code false}.
     */
    boolean anyMatch() default true;
}

TrackingAction.java

import org.aspectj.lang.ProceedingJoinPoint;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a method whose invocations are recorded as a {@code UserAction} by the tracking aspect.
 * The record is persisted through {@code UserActionRepository} before the method runs.
 *
 * @see vn.vng.wpl.livestream.admin.cms.aop.Aspect#trackUserAction(ProceedingJoinPoint)
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface TrackingAction {
    /**
     * Kind of action being performed; its {@code name()} is stored as the record's type.
     *
     * @return the action type (required)
     */
    UserActionType actionType();

    /**
     * Name stored for the action.
     *
     * @return the action name, or {@code ""} to fall back to {@code ClassName#methodName}
     */
    String actionName() default "";

    /**
     * Template for the stored description, expanded with
     * {@link org.apache.commons.lang3.text.StrSubstitutor}. Available variables are
     * {@code username} (the current user's username) and {@code 0}, {@code 1}, ... (the method
     * arguments by position).
     *
     * @return the description template, or {@code ""} to fall back to
     *         {@code ClassName::methodName([arguments])}
     */
    String descriptionFormat() default "";
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment