Skip to content

Instantly share code, notes, and snippets.

@d4rken
Last active November 19, 2018 18:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d4rken/7cb60a6a7b315d25766ff15884f9b99b to your computer and use it in GitHub Desktop.
Save d4rken/7cb60a6a7b315d25766ff15884f9b99b to your computer and use it in GitHub Desktop.
package stack.saas.backend.webserver.graphql
import com.expedia.graphql.schema.generator.WiringContext
import com.expedia.graphql.schema.hooks.SchemaGeneratorHooks
import graphql.schema.GraphQLType
import org.bson.types.ObjectId
import stack.saas.backend.common.graphql.coercing.InstantCoercing
import stack.saas.backend.common.graphql.coercing.ObjectIdCoercing
import stack.saas.backend.common.graphql.coercing.UUIDCoercing
import stack.saas.backend.webserver.graphql.components.auth.permission.Permission
import stack.saas.backend.webserver.graphql.components.auth.permission.PermissionCoercing
import java.time.Instant
import java.util.*
import javax.inject.Inject
import kotlin.reflect.KClass
import kotlin.reflect.KType
/**
 * graphql-kotlin [SchemaGeneratorHooks] that supplies the project's custom scalar types and routes every
 * generated type through [DirectiveWiringHelper] so schema directives get wired.
 */
class CustomSchemaGeneratorHooks @Inject constructor(customWiringFactory: CustomWiringFactory) : SchemaGeneratorHooks {

    private val directiveWiringHelper = DirectiveWiringHelper(customWiringFactory)

    /**
     * Returns the custom scalar registered for the class behind [type] ([Instant], [UUID], [ObjectId] or
     * [Permission]); returns `null` for any other classifier, including ones that are not a [KClass].
     */
    override fun willGenerateGraphQLType(type: KType): GraphQLType? {
        val kClass = type.classifier as? KClass<*> ?: return null
        return when (kClass) {
            Instant::class -> InstantCoercing.type
            UUID::class -> UUIDCoercing.type
            ObjectId::class -> ObjectIdCoercing.type
            Permission::class -> PermissionCoercing.type
            else -> null
        }
    }

    /** Hands [generatedType] to the directive wiring helper and returns whatever element it produces. */
    override fun onRewireGraphQLType(type: KType, generatedType: GraphQLType, context: WiringContext): GraphQLType =
        directiveWiringHelper.onWire(type, generatedType, context)
}
package stack.saas.backend.webserver.graphql
import graphql.schema.idl.SchemaDirectiveWiring
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
import graphql.schema.idl.WiringFactory
import stack.saas.backend.common.graphql.DirectiveWiring
import javax.inject.Inject
/**
 * [WiringFactory] that resolves schema directives against the injected set of [DirectiveWiring]s.
 *
 * A directive is claimed only when at least one wiring reports itself responsible for it. Directives nobody
 * handles are therefore declined, and the caller falls back to its own lookup instead of failing here.
 *
 * @property wirings every [DirectiveWiring] available for directive resolution.
 */
class CustomWiringFactory @Inject constructor(val wirings: Set<@JvmSuppressWildcards DirectiveWiring>) : WiringFactory {

    /**
     * Returns `true` only if [environment] is non-null and some wiring in [wirings] is responsible for it.
     *
     * Previously this always returned `true`, so any directive without a registered wiring made
     * [getSchemaDirectiveWiring] throw and aborted schema generation.
     */
    override fun providesSchemaDirectiveWiring(environment: SchemaDirectiveWiringEnvironment<*>?): Boolean {
        return environment != null && wirings.any { it.isResponsible(environment) }
    }

    /**
     * Returns the single wiring responsible for [environment].
     *
     * @throws NoSuchElementException if no wiring is responsible (same type the former `single()` threw).
     * @throws IllegalArgumentException if more than one wiring is responsible (same type the former `single()` threw).
     */
    override fun getSchemaDirectiveWiring(environment: SchemaDirectiveWiringEnvironment<*>): SchemaDirectiveWiring {
        val responsible = wirings.filter { it.isResponsible(environment) }
        return when (responsible.size) {
            1 -> responsible.single()
            0 -> throw NoSuchElementException(
                "No DirectiveWiring is responsible for directive '${environment.directive.name}'"
            )
            else -> throw IllegalArgumentException(
                "Multiple DirectiveWirings are responsible for directive '${environment.directive.name}': " +
                    responsible.joinToString()
            )
        }
    }
}
package stack.saas.backend.webserver.graphql.components.auth.permission
import graphql.schema.DataFetcher
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
import stack.saas.backend.common.Unauthorized
import stack.saas.backend.common.graphql.DirectiveWiring
import stack.saas.backend.common.graphql.DirectiveWiring.Companion.getDirectiveName
import stack.saas.backend.common.logger
import stack.saas.backend.webserver.graphql.components.auth.ExecutionContext
import javax.inject.Inject
/**
 * [DirectiveWiring] for the [HasUser] directive.
 *
 * It wraps a field's data fetcher with a guard that only delegates to the original fetcher when the request
 * carries a user who satisfies the directive's [Permission].
 */
class HasPermissionDirectiveWiring @Inject constructor() : DirectiveWiring {

    private val log = logger(this::class)

    /** Claims only directives whose name equals the directive name derived from [HasUser]. */
    override fun isResponsible(environment: SchemaDirectiveWiringEnvironment<*>): Boolean =
        environment.directive.name == getDirectiveName(HasUser::class)

    /**
     * Installs the permission guard on the field.
     *
     * The guard delegates to the field's original fetcher when the context user is non-null and either the
     * required permission is [Permission.NONE] or the user's permissions contain it. Otherwise it logs at debug
     * level and throws [Unauthorized] when the field is fetched.
     */
    override fun onField(wiringEnv: SchemaDirectiveWiringEnvironment<GraphQLFieldDefinition>): GraphQLFieldDefinition {
        val permissionArgument = wiringEnv.directive.getArgument(HasUser::permission.name)
        val requiredPermission = permissionArgument.value as Permission
        val field = wiringEnv.element
        val delegate: DataFetcher<Any> = field.dataFetcher

        val guarded = DataFetcher<Any> { dataEnv ->
            val user = dataEnv.getContext<ExecutionContext>().getUser()
            val permitted = user != null &&
                (requiredPermission == Permission.NONE || user.permissions.contains(requiredPermission))
            if (!permitted) {
                log.debug("Unauthorized access by $user to $field")
                throw Unauthorized()
            }
            delegate.get(dataEnv)
        }

        return field.transform { builder -> builder.dataFetcher(guarded) }
    }
}
package stack.saas.backend.webserver.graphql.components.auth.permission
import com.expedia.graphql.annotations.GraphQLDirective
/**
 * Schema directive marking a resource as requiring a signed-in user.
 *
 * @property permission the permission the user must hold; the default [Permission.NONE] requires only that a
 *   user is signed in.
 */
@GraphQLDirective(description = "This resource requires a signed in user with specific permission")
annotation class HasUser(
    val permission: Permission = Permission.NONE
)
package stack.saas.backend.webserver.graphql
import com.expedia.graphql.schema.generator.WiringContext
import graphql.Assert.assertNotNull
import graphql.schema.*
import graphql.schema.idl.SchemaDirectiveWiring
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
import graphql.schema.idl.SchemaDirectiveWiringEnvironmentImpl
import graphql.schema.idl.WiringFactory
import kotlin.reflect.KType
/**
 * Applies [SchemaDirectiveWiring]s to schema elements produced by graphql-kotlin.
 *
 * For each directive on an element, a wiring is looked up and invoked, and its result feeds into the next
 * directive.
 *
 * Based on
 * https://github.com/graphql-java/graphql-java/blob/master/src/main/java/graphql/schema/idl/SchemaGeneratorDirectiveHelper.java
 *
 * @property wiringFactory consulted first for every directive; [manualWiring] is the fallback.
 */
class DirectiveWiringHelper(val wiringFactory: WiringFactory) {

    /**
     * Fallback wirings keyed by directive name, used when [wiringFactory] does not provide one.
     * Nothing in this class populates it, so the fallback lookup currently always yields `null`.
     */
    private val manualWiring: Map<String, SchemaDirectiveWiring> = emptyMap()

    /**
     * Dispatches [generatedType] to the matching `on*` handler so that each of its directives gets wired.
     *
     * @param type the Kotlin type [generatedType] was generated from; not used by this helper.
     * @param generatedType the schema element to rewire.
     * @param context the graphql-kotlin wiring context; it is not forwarded into the wiring environment.
     * @return the element returned by the last applied wiring. Returns [generatedType] unchanged when none of its
     *   directives has a wiring, or when it is not one of the directive-container types handled below.
     */
    fun onWire(type: KType, generatedType: GraphQLType, context: WiringContext): GraphQLType {
        return when (generatedType) {
            is GraphQLObjectType -> onObject(generatedType, context)
            is GraphQLFieldDefinition -> onField(generatedType, context)
            is GraphQLInputObjectField -> onInputObjectField(generatedType, context)
            is GraphQLInterfaceType -> onInterface(generatedType, context)
            is GraphQLUnionType -> onUnion(generatedType, context)
            is GraphQLScalarType -> onScalar(generatedType, context)
            is GraphQLEnumType -> onEnum(generatedType, context)
            is GraphQLEnumValueDefinition -> onEnumValue(generatedType, context)
            is GraphQLArgument -> onArgument(generatedType, context)
            is GraphQLInputObjectType -> onInputObjectType(generatedType, context)
            else -> generatedType
        }
    }

    /**
     * Builds the environment handed to a wiring for one ([element], [directive]) pair.
     *
     * NOTE(review): [context] is ignored. The node parent tree, type definition registry and context map are
     * passed as `null`, so wirings must not rely on them.
     */
    private fun <T : GraphQLDirectiveContainer> createWiringEnvironment(element: T, directive: GraphQLDirective, context: WiringContext): SchemaDirectiveWiringEnvironment<T> {
        return SchemaDirectiveWiringEnvironmentImpl(element, directive, null, null, null)
    }

    private fun onObject(inputElement: GraphQLObjectType, context: WiringContext): GraphQLObjectType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onObject(env) }

    private fun onField(inputElement: GraphQLFieldDefinition, context: WiringContext): GraphQLFieldDefinition =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onField(env) }

    fun onInterface(inputElement: GraphQLInterfaceType, context: WiringContext): GraphQLInterfaceType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onInterface(env) }

    fun onUnion(inputElement: GraphQLUnionType, context: WiringContext): GraphQLUnionType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onUnion(env) }

    fun onScalar(inputElement: GraphQLScalarType, context: WiringContext): GraphQLScalarType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onScalar(env) }

    fun onEnum(inputElement: GraphQLEnumType, context: WiringContext): GraphQLEnumType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onEnum(env) }

    fun onEnumValue(inputElement: GraphQLEnumValueDefinition, context: WiringContext): GraphQLEnumValueDefinition =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onEnumValue(env) }

    fun onArgument(inputElement: GraphQLArgument, context: WiringContext): GraphQLArgument =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onArgument(env) }

    fun onInputObjectType(inputElement: GraphQLInputObjectType, context: WiringContext): GraphQLInputObjectType =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onInputObjectType(env) }

    fun onInputObjectField(inputElement: GraphQLInputObjectField, context: WiringContext): GraphQLInputObjectField =
        wireDirectives(inputElement, inputElement.directives, context) { wiring, env -> wiring.onInputObjectField(env) }

    /**
     * Applies the wiring for each entry of [directives], in order, threading the result of one wiring into the
     * environment built for the next.
     *
     * Directives without a wiring are skipped.
     *
     * @param invoker calls the element-kind-specific `on*` method of the wiring.
     * @return the element produced by the last applied wiring, or [element] if none applied.
     */
    private fun <T : GraphQLDirectiveContainer> wireDirectives(
        element: T,
        directives: List<GraphQLDirective>,
        context: WiringContext,
        invoker: (SchemaDirectiveWiring, SchemaDirectiveWiringEnvironment<T>) -> T
    ): T {
        var outputObject = element
        for (directive in directives) {
            val env = createWiringEnvironment(outputObject, directive, context)
            val directiveWiring = discoverWiringProvider(directive.name, env) ?: continue
            val newElement = invoker(directiveWiring, env)
            // Java wirings can still return null at runtime despite the Kotlin signature; fail loudly if so.
            assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.name + "'")
            outputObject = newElement
        }
        return outputObject
    }

    /**
     * Finds the wiring for one directive occurrence.
     *
     * [wiringFactory] is asked first; if it declines, [manualWiring] is consulted by [directiveName].
     */
    private fun <T : GraphQLDirectiveContainer> discoverWiringProvider(directiveName: String, env: SchemaDirectiveWiringEnvironment<T>): SchemaDirectiveWiring? {
        return if (wiringFactory.providesSchemaDirectiveWiring(env)) {
            assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), "You MUST provide a non null SchemaDirectiveWiring")
        } else {
            manualWiring[directiveName]
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment