Last active
August 21, 2018 05:40
-
-
Save alshain/f080a821db6a4fe974732706d0710fd6 to your computer and use it in GitHub Desktop.
IntelliJ LivePlugin inject SQL with custom context in string literals
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.intellij.ExtensionPoints | |
import com.intellij.lang.injection.MultiHostInjector | |
import com.intellij.lang.injection.MultiHostRegistrar | |
import com.intellij.openapi.extensions.ExtensionPoint | |
import com.intellij.openapi.extensions.ExtensionPointName | |
import com.intellij.openapi.extensions.Extensions | |
import com.intellij.openapi.util.TextRange | |
import com.intellij.psi.PsiClass | |
import com.intellij.psi.PsiClassObjectAccessExpression | |
import com.intellij.psi.PsiElement | |
import com.intellij.psi.PsiExpression | |
import com.intellij.psi.PsiField | |
import com.intellij.psi.PsiJavaCodeReferenceElement | |
import com.intellij.psi.PsiLanguageInjectionHost | |
import com.intellij.psi.PsiLiteral | |
import com.intellij.psi.PsiLiteralExpression | |
import com.intellij.psi.PsiMethod | |
import com.intellij.psi.PsiMethodCallExpression | |
import com.intellij.psi.PsiMethodReferenceUtil | |
import com.intellij.psi.PsiType | |
import com.intellij.psi.impl.JavaConstantExpressionEvaluator | |
import com.intellij.psi.impl.source.tree.java.PsiLiteralExpressionImpl | |
import com.intellij.psi.impl.source.tree.java.PsiMethodCallExpressionImpl | |
import com.intellij.psi.util.PsiMethodUtil | |
import com.intellij.psi.util.PsiTreeUtil | |
import com.intellij.psi.util.PsiUtil | |
import com.intellij.sql.psi.SqlLanguage | |
import com.intellij.util.text.TextRanges | |
import groovy.transform.Canonical | |
import org.codehaus.groovy.ast.expr.MethodCallExpression | |
import org.jetbrains.annotations.NotNull | |
import static liveplugin.PluginUtil.* | |
// Injection rules. Each rule names the methods to match (regular expressions), the exact
// declared parameter type texts of the target overload, and the strategy that turns the
// call's string arguments into SQL fragments.
def patterns = [
        new Pattern(["getAllWhere"], ["Class<T>", "String", "SystemInfo", "Connection"], withWhere(1, fromClassArgument(0))),
        new Pattern(["getExactlyOneWhere", "tryGetExactlyOneWhere", "unsafeGetFirstOrDummy"], ["Class<T>", "String", "SystemInfo", "Connection"], withWhere(1, fromClassArgument(0))),
        // This five-parameter overload used to be listed twice, which applied the same
        // injection twice to every matching call; a single entry is sufficient.
        new Pattern(["getAllWhere"], ["Class<T>", "String", "String", "SystemInfo", "Connection"], withWhereAndOrderBy(1, 2, fromClassArgument(0))),
        new Pattern(["readData"], ["String"], withWhere(0, fromReceiver()))
]
/**
 * Injects SQL into arguments of method calls that match one of the configured
 * {@link Pattern}s.
 *
 * A pattern matches when the resolved method's name matches one of the pattern's
 * regular expressions and the textual list of its declared parameter types equals
 * the pattern's expected list. Only the first matching pattern is applied, so a
 * single call is never injected more than once.
 */
@Canonical
class MyCustomMultiHostInjector implements MultiHostInjector {
    /** Rules checked in order; the first one that matches is applied. */
    List<Pattern> patterns

    /**
     * Called by the platform for every element of the types returned by
     * {@link #elementsToInjectIn()}.
     *
     * @param registrar receives the injection places
     * @param context   candidate element; anything but a method call is ignored
     */
    @Override
    void getLanguagesToInject(@NotNull MultiHostRegistrar registrar, @NotNull PsiElement context) {
        if (!(context instanceof PsiMethodCallExpression)) {
            return
        }
        PsiMethodCallExpression methodExpression = context
        def resolvedMethod = methodExpression.resolveMethod()
        if (resolvedMethod == null) {
            // Not sure when this happens, but we better guard against it.
            // Maybe this happens when we don't have these classes indexed/on the classpath?
            // Observed for: MethodSource.from("className", "methodNamePatterns") in JUnit5NavigationTest.java
            return
        }
        def name = resolvedMethod.name
        // TODO: Doesn't mean we're directly inside a method call, could be nested anywhere...
        if (name == null) {
            return
        }
        // Declared parameter types exactly as written in the method's source. This does not
        // depend on the pattern, so it is computed once instead of once per name regex.
        def actualSignature = resolvedMethod.parameterList.parameters*.typeElement*.text.toString()
        def matching = patterns?.find { pattern ->
            pattern.methodNamePatterns.any { methodNamePattern -> name.matches(methodNamePattern) } &&
                    actualSignature == pattern.parameterTypes.toString()
        }
        if (matching != null) {
            matching.applyRegistrar.apply(matching, registrar, resolvedMethod, methodExpression)
        }
    }

    /** @return the element types this injector wants to be consulted for: method calls only */
    @Override
    List<? extends Class<? extends PsiElement>> elementsToInjectIn() {
        return Collections.singletonList(PsiMethodCallExpression.class)
    }
}
/**
 * One injection rule: which methods it targets and how the injection is performed.
 */
@Canonical
class Pattern {
    /** Regular expressions tried against the resolved method's name. */
    List<String> methodNamePatterns

    /** Expected declared parameter types, as source text, in declaration order. */
    List<String> parameterTypes

    /** Strategy invoked for a call that matches this rule. */
    ApplyRegistrar applyRegistrar
}
/**
 * Strategy that determines the SQL table name for a matched method call.
 */
interface TableProvider {
    /**
     * @param pattern              the rule that matched
     * @param registrar            the registrar of the current injection request
     * @param psiMethod            the method the call resolves to
     * @param methodCallExpression the call being inspected
     * @return the table name, or null when none can be determined
     */
    String getTable(Pattern pattern,
                    MultiHostRegistrar registrar,
                    PsiMethod psiMethod,
                    PsiMethodCallExpression methodCallExpression)
}
/**
 * Creates a provider that derives the table name from the call argument at the given
 * position (expected to be a class literal such as {@code Foo.class}).
 *
 * @param parameterIndex zero-based position of the class argument in the call
 * @return a provider yielding the table name, or null when the argument is missing or
 *         no table name can be derived from it
 */
TableProvider fromClassArgument(int parameterIndex) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        def arguments = methodCallExpression.argumentList.expressions
        // A call that is still being typed can have fewer arguments than the resolved
        // method has parameters; indexing past the end of the array would throw.
        if (parameterIndex >= arguments.length) {
            return null
        }
        return getTableName(arguments[parameterIndex])
    }
}
/**
 * Returns the static type of the expression a method is called on.
 *
 * @param methodCallExpression the call to inspect
 * @return the qualifier expression's type, or null when the call has no qualifier
 *         (or the qualifier's type is unknown)
 */
PsiType receiverType(PsiMethodCallExpression methodCallExpression) {
    // The qualifier itself is a PsiExpression, not a PsiType; returning it directly made
    // the declared return type's cast fail. Its `type` is what callers are promised.
    return methodCallExpression.methodExpression.qualifierExpression?.type
}
/**
 * Creates a provider that derives the table name from the class that declares the
 * resolved method.
 *
 * NOTE(review): despite the name, the lookup starts at the class enclosing the method
 * declaration, not at the static type of the call's receiver — confirm this is intended.
 *
 * @return a provider yielding the table name, or null when none is found
 */
TableProvider fromReceiver() {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        // Two leftover statements used to precede this lookup: a call to
        // PsiMethodReferenceUtil.getQualifierType() without a usable argument, which failed
        // at runtime, and an expression statement with no effect. Both were removed.
        def psiClass = PsiTreeUtil.getParentOfType(psiMethod, PsiClass.class)
        return getTableName(psiClass)
    }
}
/**
 * Creates a strategy that treats one string argument as the WHERE clause of
 * {@code SELECT * FROM <table> WHERE ...}.
 *
 * Nothing is injected when the table name is unknown, the argument is missing, or the
 * argument yields no injection places.
 *
 * @param whereParamIndex zero-based position of the WHERE argument in the call
 * @param tableProvider   supplies the table name for the call
 * @return the injection strategy
 */
ApplyRegistrar withWhere(int whereParamIndex, TableProvider tableProvider) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        String table = tableProvider.getTable(pattern, registrar, psiMethod, methodCallExpression)
        if (table == null) return

        def arguments = methodCallExpression.argumentList.expressions
        // Incomplete calls may have fewer arguments than parameters; avoid indexing past the end.
        if (whereParamIndex >= arguments.length) return

        List<InjectionPlace> wheres = getPlaces(arguments[whereParamIndex])
        if (wheres.isEmpty()) return

        def select = "SELECT * FROM " + table + " WHERE "
        // Only the first place carries the SELECT prefix; later places continue the same statement.
        boolean isFirst = true
        registrar.startInjecting(SqlLanguage.INSTANCE)
        for (InjectionPlace where : wheres) {
            String prefix = mergePrefixWithFirst(isFirst, select, where)
            registrar.addPlace(prefix, where.suffix, where.place, where.textRange)
            isFirst = false
        }
        registrar.doneInjecting()
    }
}
/**
 * Creates a strategy that treats two string arguments as the WHERE and ORDER BY clauses
 * of a single {@code SELECT * FROM <table>} statement.
 *
 * Whether the call is eligible is decided solely by the table provider (a null table
 * means "do not inject"). Ineligible calls are skipped silently rather than reported
 * through a popup, since this strategy is applied for every matching call the platform
 * offers for injection.
 *
 * @param whereParamIndex   zero-based position of the WHERE argument in the call
 * @param orderByParamIndex zero-based position of the ORDER BY argument in the call
 * @param tableProvider     supplies the table name for the call
 * @return the injection strategy
 */
ApplyRegistrar withWhereAndOrderBy(int whereParamIndex, int orderByParamIndex, TableProvider tableProvider) {
    return { pattern, registrar, psiMethod, methodCallExpression ->
        def arguments = methodCallExpression.argumentList.expressions
        // Incomplete calls may have fewer arguments than parameters; avoid indexing past the end.
        if (whereParamIndex >= arguments.length || orderByParamIndex >= arguments.length) return

        String table = tableProvider.getTable(pattern, registrar, psiMethod, methodCallExpression)
        if (table == null) return

        List<InjectionPlace> wheres = getPlaces(arguments[whereParamIndex])
        List<InjectionPlace> orderBys = getPlaces(arguments[orderByParamIndex])
        if (wheres.isEmpty() && orderBys.isEmpty()) return

        def select = "SELECT * FROM " + table + " WHERE "
        boolean isFirst = true
        registrar.startInjecting(SqlLanguage.INSTANCE)
        for (InjectionPlace where : wheres) {
            String prefix = mergePrefixWithFirst(isFirst, select, where)
            registrar.addPlace(prefix, where.suffix, where.place, where.textRange)
            isFirst = false
        }

        // If no WHERE place was added, the ORDER BY part has to open the statement itself.
        String orderByPrefix = isFirst ? "SELECT * FROM " + table + " ORDER BY " : " ORDER BY "
        isFirst = true
        for (InjectionPlace orderBy : orderBys) {
            String prefix = mergePrefixWithFirst(isFirst, orderByPrefix, orderBy)
            registrar.addPlace(prefix, orderBy.suffix, orderBy.place, orderBy.textRange)
            isFirst = false
        }
        registrar.doneInjecting()
    }
}
/**
 * Strategy that registers the injection places for a matched method call.
 */
interface ApplyRegistrar {
    /**
     * @param pattern              the rule that matched
     * @param registrar            receives the injection places
     * @param psiMethod            the method the call resolves to
     * @param methodCallExpression the call whose arguments are injected into
     */
    void apply(Pattern pattern,
               MultiHostRegistrar registrar,
               PsiMethod psiMethod,
               PsiMethodCallExpression methodCallExpression)
}
/**
 * A single spot to inject into: the host literal, the range inside it, and optional
 * text placed before and after the injected fragment.
 */
@Canonical
class InjectionPlace {
    /** Literal expression that hosts the injected fragment. */
    PsiLiteralExpression place

    /** Range within the host's text that holds the fragment. */
    TextRange textRange

    /** Text placed before the fragment; may be null. */
    String prefix

    /** Text placed after the fragment; may be null. */
    String suffix
}
/**
 * Reads the table name from a field called {@code TABLE} on the given class; the
 * search covers all fields of the class, as reported by {@code allFields}.
 *
 * @param psiClass class to inspect; may be null
 * @return the field's constant string value, or null when the class is null, has no
 *         such field, the field is not initialised with a literal, or the value is
 *         not a string
 */
String getTableName(PsiClass psiClass) {
    if (psiClass == null) {
        return null
    }
    def tableField = psiClass.allFields.find { field -> field.name == "TABLE" }
    def initializer = tableField?.initializer
    if (!(initializer instanceof PsiLiteralExpression)) {
        return null
    }
    def value = JavaConstantExpressionEvaluator.computeConstantExpression(initializer, true)
    return value instanceof String ? value : null
}
/**
 * Determines the table name for a class literal such as {@code Foo.class}.
 *
 * By convention, our project has a static {@code TABLE = "tableName"} field on the
 * class or one of its super classes.
 *
 * @param expression the argument expression; anything but a class literal yields null
 * @return the table name, or null when the expression is not a class literal, the
 *         literal does not resolve to a class, or that class has no usable TABLE field
 */
String getTableName(PsiExpression expression) {
    if (!(expression instanceof PsiClassObjectAccessExpression)) {
        return null
    }
    PsiClassObjectAccessExpression objectAccessExpression = expression
    // The operand may have no class reference element (presumably e.g. for int.class),
    // so the reference is navigated null-safely instead of being dereferenced directly.
    def resolved = objectAccessExpression.operand.innermostComponentReferenceElement?.resolve()
    if (resolved instanceof PsiClass) {
        PsiClass psiClass = resolved
        return getTableName(psiClass)
    }
    return null
}
/**
 * Collects the injection places offered by an argument expression.
 *
 * Only a double-quoted string literal is accepted; the place spans the text between
 * the opening and closing quote. Other literals (null, numbers, characters, booleans)
 * are skipped because stripping one character from each end of their text would give
 * a meaningless or invalid range.
 *
 * NOTE(review): a text block delimited by three quotes also passes this check and is
 * still treated as if it had one-character delimiters — confirm whether that matters.
 *
 * @param expression the argument expression; may be null
 * @return a list with one place for a string literal, otherwise an empty list
 */
List<InjectionPlace> getPlaces(PsiExpression expression) {
    List<InjectionPlace> result = []
    if (expression instanceof PsiLiteralExpressionImpl) {
        String text = expression.text
        boolean quoted = text != null && text.length() >= 2 && text.startsWith('"') && text.endsWith('"')
        if (quoted) {
            result.add(new InjectionPlace(expression, new TextRange(1, text.length() - 1), null, null))
        }
    }
    return result
}
/**
 * Chooses the prefix for an injection place.
 *
 * For the first place the given prefix is put in front of the place's own prefix
 * (whichever of the two is null is ignored); every later place keeps just its own.
 *
 * @param isFirst        whether this is the first place of the injection
 * @param prefix         text to prepend to the first place; may be null
 * @param injectionPlace the place whose own prefix is merged
 * @return the resulting prefix; null when nothing applies
 */
String mergePrefixWithFirst(boolean isFirst, String prefix, InjectionPlace injectionPlace) {
    String ownPrefix = injectionPlace.prefix
    if (!isFirst || prefix == null) {
        return ownPrefix
    }
    return ownPrefix == null ? prefix : prefix + ownPrefix
}
/**
 * Chooses the suffix for an injection place.
 *
 * For the first place the given suffix is put in front of the place's own suffix
 * (whichever of the two is null is ignored); every later place keeps just its own.
 *
 * @param isFirst        whether this is the first place of the injection
 * @param suffix         text to combine with the first place's suffix; may be null
 * @param injectionPlace the place whose own suffix is merged
 * @return the resulting suffix; null when nothing applies
 */
String mergeSuffixWithFirst(boolean isFirst, String suffix, InjectionPlace injectionPlace) {
    String ownSuffix = injectionPlace.suffix
    if (!isFirst || suffix == null) {
        return ownSuffix
    }
    return ownSuffix == null ? suffix : suffix + ownSuffix
}
// Dynamically register the MultiHostInjector at runtime.
def pointName = MultiHostInjector.MULTIHOST_INJECTOR_EP_NAME
def point = Extensions.getArea(currentProjectInFrame()).getExtensionPoint(pointName)

// Clean up previous versions by comparing the class name of every registered injector
// against our class's simple name. All matches are removed, not just the first one, so
// leftovers from several earlier runs cannot pile up. The matches are collected into a
// separate list first, so unregistering does not interfere with the iteration.
def staleInjectors = (point.extensions as List).findAll { injector ->
    injector.class.name.contains(MyCustomMultiHostInjector.class.simpleName)
}
staleInjectors.each { injector -> point.unregisterExtension(injector) }
if (!staleInjectors.isEmpty()) {
    show("Clean up done")
}

point.registerExtension(new MyCustomMultiHostInjector(patterns))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment