Skip to content

Instantly share code, notes, and snippets.

@alshain
Last active August 21, 2018 05:40
Show Gist options
  • Save alshain/f080a821db6a4fe974732706d0710fd6 to your computer and use it in GitHub Desktop.
Save alshain/f080a821db6a4fe974732706d0710fd6 to your computer and use it in GitHub Desktop.
IntelliJ LivePlugin: inject SQL, with a custom table context, into string literals
import com.intellij.ExtensionPoints
import com.intellij.lang.injection.MultiHostInjector
import com.intellij.lang.injection.MultiHostRegistrar
import com.intellij.openapi.extensions.ExtensionPoint
import com.intellij.openapi.extensions.ExtensionPointName
import com.intellij.openapi.extensions.Extensions
import com.intellij.openapi.util.TextRange
import com.intellij.psi.ElementManipulators
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiClassObjectAccessExpression
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiExpression
import com.intellij.psi.PsiField
import com.intellij.psi.PsiJavaCodeReferenceElement
import com.intellij.psi.PsiLanguageInjectionHost
import com.intellij.psi.PsiLiteral
import com.intellij.psi.PsiLiteralExpression
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodCallExpression
import com.intellij.psi.PsiMethodReferenceUtil
import com.intellij.psi.PsiType
import com.intellij.psi.impl.JavaConstantExpressionEvaluator
import com.intellij.psi.impl.source.tree.java.PsiLiteralExpressionImpl
import com.intellij.psi.impl.source.tree.java.PsiMethodCallExpressionImpl
import com.intellij.psi.util.PsiMethodUtil
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.PsiUtil
import com.intellij.sql.psi.SqlLanguage
import com.intellij.util.text.TextRanges
import groovy.transform.Canonical
import org.codehaus.groovy.ast.expr.MethodCallExpression
import org.jetbrains.annotations.NotNull
import static liveplugin.PluginUtil.*
// Calls into which SQL is injected. Each entry is
//   Pattern(<method-name regexes>, <declared parameter types as written in source>, <how to inject>)
// Names are matched with String.matches (whole-name regex). Parameter types are compared
// textually against the resolved method's declaration, so they must match it exactly,
// including generics such as "Class<T>".
// List each overload only once: duplicates are redundant at best and can inject the same
// string literals twice.
def patterns = [
    new Pattern(["getAllWhere"], ["Class<T>", "String", "SystemInfo", "Connection"], withWhere(1, fromClassArgument(0))),
    new Pattern(["getExactlyOneWhere", "tryGetExactlyOneWhere", "unsafeGetFirstOrDummy"], ["Class<T>", "String", "SystemInfo", "Connection"], withWhere(1, fromClassArgument(0))),
    new Pattern(["getAllWhere"], ["Class<T>", "String", "String", "SystemInfo", "Connection"], withWhereAndOrderBy(1, 2, fromClassArgument(0))),
    new Pattern(["readData"], ["String"], withWhere(0, fromReceiver()))
]
/**
 * Multi-host injector that injects SQL into string-literal arguments of method calls
 * matching one of the configured {@link Pattern}s.
 *
 * A call matches a pattern when the resolved method's name matches one of the pattern's
 * name regexes AND its declared parameter types (as source text) equal the pattern's
 * parameter types. Only the first matching pattern is applied, so a call is never
 * injected twice.
 */
@Canonical
class MyCustomMultiHostInjector implements MultiHostInjector {
    /** Candidate call shapes, tried in order; the first match wins. */
    List<Pattern> patterns

    @Override
    void getLanguagesToInject(@NotNull MultiHostRegistrar registrar, @NotNull PsiElement context) {
        if (!(context instanceof PsiMethodCallExpression)) {
            return
        }
        PsiMethodCallExpression callExpression = context
        def resolvedMethod = callExpression.resolveMethod()
        // resolveMethod() yields null when the call cannot be resolved (observed e.g. for calls
        // into classes that may not be indexed / on the classpath); nothing to match then.
        if (resolvedMethod == null) {
            return
        }
        def name = resolvedMethod.name
        if (name == null) {
            return
        }

        String actualSignature = null
        def matching = patterns.find { pattern ->
            if (!pattern.methodNamePatterns.any { name.matches(it) }) {
                return false
            }
            // Rendered lazily: this injector runs for every call expression, and most calls
            // are rejected by name alone.
            if (actualSignature == null) {
                actualSignature = resolvedMethod.parameterList.parameters*.typeElement*.text.toString()
            }
            return actualSignature == pattern.parameterTypes.toString()
        }
        // Applying more than one pattern would start a second injection over the same literals.
        if (matching != null) {
            matching.applyRegistrar.apply(matching, registrar, resolvedMethod, callExpression)
        }
    }

    @Override
    List<? extends Class<? extends PsiElement>> elementsToInjectIn() {
        return Collections.singletonList(PsiMethodCallExpression.class)
    }
}
/**
 * One injection rule: which method calls to match and how to inject SQL into them.
 * NOTE: field order matters. @Canonical generates a positional constructor, used as
 * new Pattern(methodNamePatterns, parameterTypes, applyRegistrar).
 */
@Canonical
class Pattern {
// Regexes matched with String.matches (whole name) against the resolved method's name.
List<String> methodNamePatterns
// Declared parameter types as source text (e.g. "Class<T>"); compared via toString()
// against the resolved method's parameter type elements.
List<String> parameterTypes
// Invoked with the matching call to register the SQL injection places.
ApplyRegistrar applyRegistrar
}
/**
 * Determines the SQL table a matched call operates on; implemented by closures
 * (see fromClassArgument / fromReceiver).
 * Returning null means "unknown table", and callers then skip the injection.
 */
interface TableProvider {
String getTable(Pattern pattern, MultiHostRegistrar registrar, PsiMethod psiMethod, PsiMethodCallExpression methodCallExpression)
}
/**
 * Table provider that reads the table from a class-literal argument of the call,
 * e.g. {@code getAllWhere(User.class, ...)} with {@code parameterIndex = 0}.
 *
 * @param parameterIndex index of the call argument holding the class literal
 * @return a provider yielding the class's TABLE constant, or null when the argument is
 *         missing, is not a class literal, or the class declares no usable TABLE
 */
TableProvider fromClassArgument(int parameterIndex) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        def arguments = methodCallExpression.argumentList.expressions
        // The call may still be incomplete while the user is typing.
        if (parameterIndex >= arguments.length) {
            return null
        }
        return getTableName(arguments[parameterIndex])
    }
}
/**
 * Returns the static type of the explicit receiver (qualifier) of a method call,
 * e.g. the type of {@code dao} in {@code dao.readData(...)}.
 *
 * @param methodCallExpression the call to inspect
 * @return the qualifier expression's type, or null when the call has no explicit
 *         qualifier or the qualifier's type cannot be determined
 */
PsiType receiverType(PsiMethodCallExpression methodCallExpression) {
    def qualifier = methodCallExpression.methodExpression.qualifierExpression
    // Return the expression's type, not the expression itself: a PsiExpression is not a
    // PsiType and would fail the declared return type's cast at runtime.
    return qualifier?.type
}
/**
 * Table provider for receiver-style calls such as {@code dao.readData("...")} or
 * {@code Dao.readData("...")}.
 *
 * Tries the class of the call's explicit qualifier first. If that is absent or yields no
 * TABLE, it falls back to the class that declares the called method.
 *
 * @return a provider yielding the table name, or null when neither class declares a
 *         usable TABLE constant
 */
TableProvider fromReceiver() {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        String table = null

        def qualifier = methodCallExpression.methodExpression.qualifierExpression
        PsiClass receiverClass = null
        if (qualifier != null) {
            def referenced = qualifier instanceof PsiJavaCodeReferenceElement ? qualifier.resolve() : null
            if (referenced instanceof PsiClass) {
                // Static-style call on a class name, e.g. Dao.readData(...)
                receiverClass = referenced
            } else {
                receiverClass = PsiUtil.resolveClassInType(qualifier.type)
            }
        }
        if (receiverClass != null) {
            table = getTableName(receiverClass)
        }

        if (table == null) {
            PsiClass declaringClass = PsiTreeUtil.getParentOfType(psiMethod, PsiClass.class)
            // Guard explicitly: passing a null into the overloaded getTableName would be
            // ambiguous for Groovy's dynamic dispatch.
            if (declaringClass != null) {
                table = getTableName(declaringClass)
            }
        }
        return table
    }
}
/**
 * Builds an injection rule for calls taking a WHERE-clause string.
 *
 * The literal is injected as SQL behind the synthetic prefix
 * {@code SELECT * FROM <table> WHERE }, so the SQL support can resolve columns of the table.
 *
 * @param whereParamIndex index of the call argument holding the WHERE clause
 * @param tableProvider   supplies the table name; nothing is injected when it returns null
 */
ApplyRegistrar withWhere(int whereParamIndex, TableProvider tableProvider) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        String table = tableProvider.getTable(pattern, registrar, psiMethod, methodCallExpression)
        if (table == null) return
        def arguments = methodCallExpression.argumentList.expressions
        // The call may still be incomplete while the user is typing.
        if (whereParamIndex >= arguments.length) return
        List<InjectionPlace> wheres = getPlaces(arguments[whereParamIndex])
        if (wheres.isEmpty()) return

        String select = "SELECT * FROM " + table + " WHERE "
        boolean isFirst = true
        registrar.startInjecting(SqlLanguage.INSTANCE)
        for (InjectionPlace where : wheres) {
            // Only the first host carries the synthetic SELECT prefix.
            String prefix = mergePrefixWithFirst(isFirst, select, where)
            registrar.addPlace(prefix, where.suffix, where.place, where.textRange)
            isFirst = false
        }
        registrar.doneInjecting()
    }
}
/**
 * Builds an injection rule for calls taking both a WHERE-clause and an ORDER BY string.
 *
 * Both literals are injected into one SQL fragment shaped like
 * {@code SELECT * FROM <table> WHERE <where> ORDER BY <orderBy>}. When there is no WHERE
 * literal, the ORDER BY literal starts the statement itself.
 *
 * @param whereParamIndex   index of the call argument holding the WHERE clause
 * @param orderByParamIndex index of the call argument holding the ORDER BY clause
 * @param tableProvider     supplies the table name; nothing is injected when it returns null
 */
ApplyRegistrar withWhereAndOrderBy(int whereParamIndex, int orderByParamIndex, TableProvider tableProvider) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        // The table provider already returns null when it cannot determine the table (e.g. the
        // class argument is not a class literal). No extra precondition or notification is
        // needed; popping notifications from an injector would fire on every reparse.
        String table = tableProvider.getTable(pattern, registrar, psiMethod, methodCallExpression)
        if (table == null) return

        def arguments = methodCallExpression.argumentList.expressions
        // Arguments missing from an incomplete call contribute no places.
        List<InjectionPlace> wheres = whereParamIndex < arguments.length ? getPlaces(arguments[whereParamIndex]) : []
        List<InjectionPlace> orderBys = orderByParamIndex < arguments.length ? getPlaces(arguments[orderByParamIndex]) : []
        if (wheres.isEmpty() && orderBys.isEmpty()) return

        String selectFrom = "SELECT * FROM " + table
        registrar.startInjecting(SqlLanguage.INSTANCE)

        boolean isFirst = true
        for (InjectionPlace where : wheres) {
            String prefix = mergePrefixWithFirst(isFirst, selectFrom + " WHERE ", where)
            registrar.addPlace(prefix, where.suffix, where.place, where.textRange)
            isFirst = false
        }

        // Without a WHERE host, the first ORDER BY host must carry the SELECT as well.
        String orderByPrefix = wheres.isEmpty() ? selectFrom + " ORDER BY " : " ORDER BY "
        isFirst = true
        for (InjectionPlace orderBy : orderBys) {
            String prefix = mergePrefixWithFirst(isFirst, orderByPrefix, orderBy)
            registrar.addPlace(prefix, orderBy.suffix, orderBy.place, orderBy.textRange)
            isFirst = false
        }

        registrar.doneInjecting()
    }
}
/**
 * Registers the SQL injection places for one matched call; implemented by closures
 * (see withWhere / withWhereAndOrderBy).
 * Called by the injector with the matching pattern, the registrar, the resolved method
 * and the call expression itself.
 */
interface ApplyRegistrar {
void apply(Pattern pattern, MultiHostRegistrar registrar, PsiMethod psiMethod, PsiMethodCallExpression methodCallExpression)
}
/**
 * One host fragment of an injection: a literal plus the range inside it and optional
 * text to prepend/append.
 * NOTE: field order matters. @Canonical generates the positional constructor used as
 * new InjectionPlace(place, textRange, prefix, suffix).
 */
@Canonical
class InjectionPlace {
// The literal expression passed to registrar.addPlace as the injection host.
PsiLiteralExpression place
// Range inside the literal's text that receives the injected language.
TextRange textRange
// Extra text before/after this fragment; either may be null.
String prefix
String suffix
}
/**
 * Looks up the table name declared by a class.
 *
 * By project convention, the class (or one of its superclasses) declares a field named
 * {@code TABLE} whose initializer is the table name.
 *
 * @param psiClass the class to inspect; may be null
 * @return the String value of the first field named TABLE (own or inherited), or null when
 *         there is no such field or its initializer is not a constant String expression
 */
String getTableName(PsiClass psiClass) {
    if (psiClass == null) return null
    PsiField tableField = psiClass.allFields.find { it.name == "TABLE" }
    def initializer = tableField?.initializer
    if (initializer == null) return null
    // Evaluate any compile-time constant, not only a plain literal, so initializers such as
    // PREFIX + "users" or OtherClass.TABLE resolve as well; non-constants evaluate to null.
    def value = JavaConstantExpressionEvaluator.computeConstantExpression(initializer, true)
    return value instanceof String ? value : null
}
/**
 * Determines the table name from a class-literal expression such as {@code User.class}.
 *
 * By project convention, the referenced class (or a superclass) declares a TABLE constant
 * holding the table name.
 *
 * @param expression a call argument; may be null
 * @return the table name, or null when the expression is not a class literal of a
 *         resolvable class or that class declares no usable TABLE
 */
String getTableName(PsiExpression expression) {
    if (!(expression instanceof PsiClassObjectAccessExpression)) {
        return null
    }
    PsiClassObjectAccessExpression classLiteral = expression
    // Null for primitive class literals such as int.class, which have no reference element.
    def reference = classLiteral.operand.innermostComponentReferenceElement
    def resolved = reference?.resolve()
    if (resolved instanceof PsiClass) {
        return getTableName((PsiClass) resolved)
    }
    return null
}
/**
 * Returns the injection places contributed by a single call argument.
 *
 * Only string literals are injectable. Any other expression (variables, concatenations,
 * number or null literals) yields no places.
 *
 * @param expression the argument expression; may be null
 * @return a mutable list with at most one place, covering the literal's content without
 *         its quotes
 */
List<InjectionPlace> getPlaces(PsiExpression expression) {
    List<InjectionPlace> result = []
    if (expression instanceof PsiLiteralExpression
            && expression instanceof PsiLanguageInjectionHost
            && expression.value instanceof String) {
        // Ask the platform for the value range instead of assuming exactly one quote character
        // on each side, which is wrong for text blocks (triple quotes).
        TextRange contentRange = ElementManipulators.getValueTextRange(expression)
        result.add(new InjectionPlace(expression, contentRange, null, null))
    }
    return result
}
/**
 * Computes the prefix to register for one injection host.
 *
 * For the first host of an injection, the shared {@code prefix} is prepended to the
 * place's own prefix (either side may be null). For every later host, only the place's
 * own prefix is used.
 *
 * @return the combined prefix, or null when there is nothing to prepend
 */
String mergePrefixWithFirst(boolean isFirst, String prefix, InjectionPlace injectionPlace) {
    String ownPrefix = injectionPlace.prefix
    if (!isFirst || prefix == null) {
        return ownPrefix
    }
    return ownPrefix == null ? prefix : prefix + ownPrefix
}
/**
 * Computes the suffix to register for one injection host (not referenced in this file).
 *
 * For the first host, the shared {@code suffix} is followed by the place's own suffix
 * (either side may be null). For later hosts, only the place's own suffix is used.
 *
 * @return the combined suffix, or null when there is nothing to append
 */
String mergeSuffixWithFirst(boolean isFirst, String suffix, InjectionPlace injectionPlace) {
    String ownSuffix = injectionPlace.suffix
    if (isFirst && suffix != null) {
        return ownSuffix != null ? suffix + ownSuffix : suffix
    }
    return ownSuffix
}
// Register the injector at runtime on the current project's MultiHostInjector extension point.
// An injector left over from a previous run of this script is removed first. It is found by
// class name rather than by type (presumably because each script run loads a new class,
// so an instanceof check would not match the old instance).
def pointName = MultiHostInjector.MULTIHOST_INJECTOR_EP_NAME
def point = Extensions.getArea(currentProjectInFrame()).getExtensionPoint(pointName)
def staleInjector = point.extensions.find { registered ->
    registered.class.name.contains(MyCustomMultiHostInjector.class.simpleName)
}
if (staleInjector != null) {
    point.unregisterExtension(staleInjector)
    show("Clean up done")
}
point.registerExtension(new MyCustomMultiHostInjector(patterns))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment