Last active
August 8, 2017 17:38
-
-
Save MatthewVines/f718a74d411aaa97046a50ef13909da6 to your computer and use it in GitHub Desktop.
NPoco Interceptor to inject tenant checks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Data.Common; | |
using System.Linq; | |
using System.Text; | |
using Microsoft.SqlServer.TransactSql.ScriptDom; | |
using NPoco; | |
/// <summary>
/// NPoco executing-interceptor that rewrites the command text so that every tenant table
/// referenced directly after a FROM or JOIN keyword is additionally filtered by
/// <c>system_id</c>.
/// </summary>
/// <remarks>
/// The rewrite works on the token stream only (no syntax tree). It uses the FIRST
/// WHERE / GROUP / ORDER token found anywhere in the text, so statements with sub-queries,
/// UNIONs or several statements in one batch are not handled reliably.
/// </remarks>
public class SystemCheckQueryInterceptor : IExecutingInterceptor
{
    // Keywords that are directly followed by a table source.
    private static readonly TSqlTokenType[] TableSourceTokenTypes =
    {
        TSqlTokenType.From,
        TSqlTokenType.Join
    };

    // Token types that can name a table or an alias.
    private static readonly TSqlTokenType[] IdentifierTokenTypes =
    {
        TSqlTokenType.Identifier,
        TSqlTokenType.QuotedIdentifier
    };

    private readonly int _systemId;

    // Lower-cased, bracketed names (e.g. "[dbo].[lookup]") of tables that get no tenant check.
    private readonly List<string> _nonTenantTables;

    /// <summary>
    /// Creates an interceptor that restricts queries to <paramref name="systemId"/>.
    /// </summary>
    /// <param name="systemId">Value every tenant table's <c>system_id</c> column is compared to.</param>
    /// <param name="nonTenantTables">
    /// Tables that must NOT receive a tenant check. Names may be schema/database qualified and
    /// may or may not be bracketed; they are normalised to lower-cased, bracketed form.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="nonTenantTables"/> is null.</exception>
    public SystemCheckQueryInterceptor(int systemId, List<string> nonTenantTables)
    {
        if (nonTenantTables == null)
        {
            throw new ArgumentNullException(nameof(nonTenantTables));
        }

        _systemId = systemId;
        _nonTenantTables = nonTenantTables.Select(StandardizeTableName).ToList();
    }

    /// <summary>Replaces the command text with its tenant-checked version before execution.</summary>
    public void OnExecutingCommand(IDatabase database, DbCommand cmd)
    {
        cmd.CommandText = AddTenantChecksToStatement(cmd.CommandText);
    }

    /// <summary>Nothing to do once the command has run.</summary>
    public void OnExecutedCommand(IDatabase database, DbCommand cmd)
    {
    }

    /// <summary>
    /// Returns <paramref name="sql"/> with a <c>system_id</c> predicate added for every tenant
    /// table found after a FROM/JOIN.
    /// </summary>
    /// <param name="sql">The statement to rewrite.</param>
    /// <returns>
    /// The rewritten statement, or <paramref name="sql"/> unchanged when it is null/blank, when
    /// the tokenizer reports errors, or when no tenant table was found.
    /// </returns>
    internal string AddTenantChecksToStatement(string sql)
    {
        if (string.IsNullOrWhiteSpace(sql))
        {
            return sql;
        }

        List<string> parseErrors;
        List<TSqlParserToken> queryTokens = TokenizeSql(sql, out parseErrors);
        if (parseErrors != null)
        {
            // If we had parse errors, there is probably something wrong with the statement.
            // The best errors will come back from the db itself, so just pass the statement on as is.
            return sql;
        }

        Dictionary<string, string> tableAliases = GetTableNamesFromQueryString(queryTokens);
        string tenantClause = BuildTenantClause(tableAliases);
        if (tenantClause.Length == 0)
        {
            // Nothing to filter on; appending an empty clause would leave a dangling WHERE/AND.
            return sql;
        }

        return AppendClause(sql, queryTokens, tenantClause, Concatenator.And);
    }

    /// <summary>
    /// Normalises a possibly qualified table name: each dot-separated segment is lower-cased
    /// and wrapped in brackets unless it already starts/ends with one.
    /// </summary>
    private static string StandardizeTableName(string table)
    {
        string[] segments = table.Split('.');
        var sb = new StringBuilder();
        for (int i = 0; i < segments.Length; i++)
        {
            // Separate by position, not by value: comparing segment text to the last segment
            // dropped the dot whenever two segments were equal (e.g. "dbo.dbo").
            if (i > 0)
            {
                sb.Append('.');
            }

            string segment = segments[i];
            if (!segment.StartsWith("[", StringComparison.Ordinal))
            {
                sb.Append('[');
            }

            sb.Append(segment.ToLowerInvariant());
            if (!segment.EndsWith("]", StringComparison.Ordinal))
            {
                sb.Append(']');
            }
        }

        return sb.ToString();
    }

    /// <summary>
    /// Tokenizes <paramref name="sql"/> with the T-SQL 120 parser (quoted identifiers enabled).
    /// </summary>
    /// <param name="sql">Text to tokenize.</param>
    /// <param name="parserErrors">Formatted error descriptions, or null when there were none.</param>
    private List<TSqlParserToken> TokenizeSql(string sql, out List<string> parserErrors)
    {
        using (System.IO.TextReader reader = new System.IO.StringReader(sql))
        {
            var parser = new TSql120Parser(true);
            IList<ParseError> errors;
            IList<TSqlParserToken> queryTokens = parser.GetTokenStream(reader, out errors);
            parserErrors = errors.Any()
                ? errors.Select(e => $"Error: {e.Number}; Line: {e.Line}; Column: {e.Column}; Offset: {e.Offset}; Message: {e.Message};").ToList()
                : null;
            return queryTokens.ToList();
        }
    }

    /// <summary>
    /// Collects the table named directly after each FROM/JOIN token.
    /// </summary>
    /// <returns>
    /// Map of alias to normalised table name. When a table has no alias the key is the
    /// normalised table name itself. A repeated alias overwrites the earlier entry.
    /// </returns>
    private Dictionary<string, string> GetTableNamesFromQueryString(List<TSqlParserToken> queryTokens)
    {
        var output = new Dictionary<string, string>();
        var sb = new StringBuilder();

        for (var i = 0; i < queryTokens.Count; i++)
        {
            if (!TableSourceTokenTypes.Contains(queryTokens[i].TokenType))
            {
                continue;
            }

            // Only the first significant token after FROM/JOIN is considered; anything that
            // is not an identifier (e.g. a parenthesised sub-query) is skipped.
            var j = SkipWhiteSpace(queryTokens, i + 1);
            if (j >= queryTokens.Count || !IdentifierTokenTypes.Contains(queryTokens[j].TokenType))
            {
                continue;
            }

            sb.Clear();
            GetQuotedIdentifier(queryTokens[j], sb);

            // Consume the remaining parts of a multi-part name, including empty parts ("db..table").
            while (j + 2 < queryTokens.Count
                   && queryTokens[j + 1].TokenType == TSqlTokenType.Dot
                   && (queryTokens[j + 2].TokenType == TSqlTokenType.Dot
                       || IdentifierTokenTypes.Contains(queryTokens[j + 2].TokenType)))
            {
                sb.Append(queryTokens[j + 1].Text.ToLowerInvariant());
                if (queryTokens[j + 2].TokenType == TSqlTokenType.Dot)
                {
                    if (queryTokens[j - 1].TokenType == TSqlTokenType.Dot)
                    {
                        GetQuotedIdentifier(queryTokens[j + 1], sb);
                    }

                    j++;
                }
                else
                {
                    GetQuotedIdentifier(queryTokens[j + 2], sb);
                    j += 2;
                }
            }

            string tableName = sb.ToString();
            string alias = tableName;

            // An alias is an identifier that follows the table name, optionally introduced by AS.
            if (IdentifierTokenTypes.Contains(queryTokens[j].TokenType))
            {
                int k = SkipWhiteSpace(queryTokens, j + 1);
                if (k < queryTokens.Count && queryTokens[k].TokenType == TSqlTokenType.As)
                {
                    k = SkipWhiteSpace(queryTokens, k + 1);
                }

                // k > j + 1 means at least one whitespace/AS token separated name and alias.
                if (k > j + 1 && k < queryTokens.Count && IdentifierTokenTypes.Contains(queryTokens[k].TokenType))
                {
                    alias = queryTokens[k].Text;
                }
            }

            // Indexer instead of Add: the same table/alias may legitimately appear twice.
            output[alias] = tableName;
        }

        return output;
    }

    /// <summary>Returns the index of the first non-whitespace token at or after <paramref name="index"/>.</summary>
    private static int SkipWhiteSpace(List<TSqlParserToken> tokens, int index)
    {
        while (index < tokens.Count && tokens[index].TokenType == TSqlTokenType.WhiteSpace)
        {
            index++;
        }

        return index;
    }

    /// <summary>
    /// Appends the lower-cased text of <paramref name="token"/> to <paramref name="sb"/>,
    /// wrapping plain identifiers in brackets.
    /// </summary>
    /// <exception cref="ArgumentException">The token is not a dot, identifier or quoted identifier.</exception>
    private void GetQuotedIdentifier(TSqlParserToken token, StringBuilder sb)
    {
        switch (token.TokenType)
        {
            case TSqlTokenType.Identifier:
                sb.Append('[').Append(token.Text.ToLowerInvariant()).Append(']');
                break;
            case TSqlTokenType.QuotedIdentifier:
            case TSqlTokenType.Dot:
                sb.Append(token.Text.ToLowerInvariant());
                break;
            default:
                throw new ArgumentException("Error: expected TokenType of token should be TSqlTokenType.Dot, TSqlTokenType.Identifier, or TSqlTokenType.QuotedIdentifier");
        }
    }

    /// <summary>
    /// A table is a tenant table unless its unqualified name equals the unqualified name of
    /// one of the configured non-tenant tables (schema/database parts are ignored).
    /// </summary>
    private bool IsTenantTable(string tableName)
    {
        // Exact comparison of the final segment: a substring search over the whole configured
        // name also matched its schema/database segments.
        string unqualifiedName = LastSegment(tableName);
        return !_nonTenantTables.Any(t => string.Equals(LastSegment(t), unqualifiedName, StringComparison.Ordinal));
    }

    /// <summary>Returns the text after the last dot of a (possibly qualified) name.</summary>
    private static string LastSegment(string qualifiedName)
    {
        return qualifiedName.Split('.').Last();
    }

    /// <summary>
    /// Builds <c>(a.system_id = N AND b.system_id = N ...)</c> for every tenant table in
    /// <paramref name="tableAliases"/>; returns an empty string when there is none.
    /// </summary>
    private string BuildTenantClause(Dictionary<string, string> tableAliases)
    {
        string systemId = _systemId.ToString(System.Globalization.CultureInfo.InvariantCulture);
        List<string> checks = tableAliases
            .Where(kvp => IsTenantTable(kvp.Value))
            .Select(kvp => $"{kvp.Key}.system_id = {systemId}")
            .ToList();

        return checks.Count == 0
            ? string.Empty
            : "(" + string.Join(" AND ", checks) + ")";
    }

    private enum Concatenator
    {
        And,
        Or
    }

    /// <summary>
    /// Inserts <paramref name="newClause"/> into <paramref name="sql"/>: before the first GROUP
    /// token if any, else before the first ORDER token, else before a trailing semicolon, else
    /// at the end. It is introduced by WHERE when the statement has no WHERE token, otherwise
    /// by <paramref name="concatenator"/>.
    /// </summary>
    private string AppendClause(string sql, List<TSqlParserToken> queryTokens, string newClause, Concatenator concatenator)
    {
        int indexOfWhere = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Where);
        int indexOfGroup = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Group);
        int indexOfOrder = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Order);

        int insertAt;
        string suffix;
        if (indexOfGroup >= 0)
        {
            insertAt = queryTokens[indexOfGroup].Offset;
            suffix = " ";
        }
        else if (indexOfOrder >= 0)
        {
            insertAt = queryTokens[indexOfOrder].Offset;
            suffix = " ";
        }
        else
        {
            // Keep the clause inside the statement when it ends with a terminator.
            int indexOfSemicolon = FindTrailingSemicolon(queryTokens);
            insertAt = indexOfSemicolon >= 0 ? queryTokens[indexOfSemicolon].Offset : sql.Length;
            suffix = string.Empty;
        }

        if (indexOfWhere < 0)
        {
            return sql.Insert(insertAt, $" WHERE {newClause}{suffix}");
        }

        TSqlParserToken whereToken = queryTokens[indexOfWhere];
        string glue = concatenator.ToString().ToUpperInvariant();

        // AND binds tighter than OR, so "WHERE a OR b AND (tenant)" would let rows matching
        // "a" escape the tenant check. Parenthesise the existing predicate in that case.
        bool existingPredicateHasOr = queryTokens.Any(t =>
            t.TokenType == TSqlTokenType.Or
            && t.Offset > whereToken.Offset
            && t.Offset < insertAt);

        if (!existingPredicateHasOr)
        {
            return sql.Insert(insertAt, $" {glue} {newClause}{suffix}");
        }

        // Insert at the later offset first so the earlier offset stays valid.
        int predicateStart = whereToken.Offset + whereToken.Text.Length;
        return sql
            .Insert(insertAt, $") {glue} {newClause}{suffix}")
            .Insert(predicateStart, " (");
    }

    /// <summary>
    /// Returns the index of a semicolon that is the last significant token (only whitespace
    /// or end-of-file tokens after it), or -1 when the statement does not end with one.
    /// </summary>
    private static int FindTrailingSemicolon(List<TSqlParserToken> tokens)
    {
        for (int i = tokens.Count - 1; i >= 0; i--)
        {
            TSqlTokenType type = tokens[i].TokenType;
            if (type == TSqlTokenType.WhiteSpace || type == TSqlTokenType.EndOfFile)
            {
                continue;
            }

            return type == TSqlTokenType.Semicolon ? i : -1;
        }

        return -1;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using NUnit.Framework; | |
using xray_svc; | |
/// <summary>
/// Tests for <c>SystemCheckQueryInterceptor.AddTenantChecksToStatement</c>. Each test feeds a
/// statement through the interceptor and compares the result with the expected SQL, ignoring
/// all whitespace.
/// </summary>
[TestFixture]
internal class SystemCheckQueryInterceptorTests
{
    private int _systemId;
    private List<string> _nonTenantTables;

    [SetUp]
    public void Setup()
    {
        _systemId = 450;
        _nonTenantTables = new List<string>();
    }

    // Used to make the assertions more resilient. Strips EVERY whitespace character: removing
    // only ' ' and Environment.NewLine left '\n' / '\r' / tabs behind whenever the source
    // file's line endings differed from the platform's, which broke the multi-line tests.
    private static string RemoveWhitespace(string s)
    {
        // A null separator makes Split treat all white-space characters as delimiters.
        return string.Concat(s.Split((char[])null, StringSplitOptions.RemoveEmptyEntries));
    }

    // Runs sql through an interceptor built from the current fixture state and compares the
    // outcome with expected, whitespace-insensitively.
    private void AssertRewrittenTo(string sql, string expected)
    {
        var interceptor = new SystemCheckQueryInterceptor(_systemId, _nonTenantTables);

        string actual = interceptor.AddTenantChecksToStatement(sql);

        Assert.That(RemoveWhitespace(actual), Is.EqualTo(RemoveWhitespace(expected)));
    }

    [Test]
    public void QueryWithoutWhereClause_AddWhereClauseWithTenantClause()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 * FROM tableA",
            @"SELECT TOP 10 * FROM tableA WHERE ([tablea].system_id = 450)");
    }

    [Test]
    public void QueryWithWhereClause_AddTenantClauseWithAnd()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 * FROM tableA WHERE x = y",
            @"SELECT TOP 10 * FROM tableA WHERE x = y AND ([tablea].system_id = 450)");
    }

    [Test]
    public void QueryWithJoins_AddTenantClauseForEachTable()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM tableA
              JOIN tableB on tableA.x = tableB.x
              JOIN tableC on tableA.x = tableC.x
              WHERE tableA.x = 'y'",
            @"SELECT TOP 10 *
              FROM tableA
              JOIN tableB on tableA.x = tableB.x
              JOIN tableC on tableA.x = tableC.x
              WHERE tableA.x = 'y'
              AND ([tablea].system_id = 450
              AND [tableb].system_id = 450
              AND [tablec].system_id = 450)");
    }

    [Test]
    public void FullyQualifiedTables_WorkAsExpected()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM [db].[sch].[tableA]
              JOIN xxx.tableB on tableA.x = xxx.tableB.x
              JOIN tableC on db.sch.tableA.x = tableC.x
              WHERE [db].[sch].[tableA].x = 'y'",
            @"SELECT TOP 10 *
              FROM [db].[sch].[tableA]
              JOIN xxx.tableB on tableA.x = xxx.tableB.x
              JOIN tableC on db.sch.tableA.x = tableC.x
              WHERE [db].[sch].[tableA].x = 'y'
              AND ([db].[sch].[tablea].system_id = 450
              AND [xxx].[tableb].system_id = 450
              AND [tablec].system_id = 450)");
    }

    [Test]
    public void QueryWithAliases_AliasesAreUsedRatherThanTableNames()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM [db].[sch].[tableA] A
              JOIN xxx.tableB on A.x = xxx.tableB.x
              JOIN tableA A2 on A.x = A2.x
              WHERE A.x = 'y' AND A2.x = 'yy'",
            @"SELECT TOP 10 *
              FROM [db].[sch].[tableA] A
              JOIN xxx.tableB on A.x = xxx.tableB.x
              JOIN tableA A2 on A.x = A2.x
              WHERE A.x = 'y' AND A2.x = 'yy'
              AND (A.system_id = 450
              AND [xxx].[tableb].system_id = 450
              AND A2.system_id = 450)");
    }

    [Test]
    public void QueryWithNonTenantTables_OnlyAddClausesForTenantTables()
    {
        _nonTenantTables.Add("tableB");
        _nonTenantTables.Add("tableD");

        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM tableA
              JOIN tableB on tableA.x = tableB.x
              JOIN tableC on tableB.x = tableC.x
              JOIN tableD on tableC.x = tableD.x
              WHERE tableA.x = 'y'",
            @"SELECT TOP 10 *
              FROM tableA
              JOIN tableB on tableA.x = tableB.x
              JOIN tableC on tableB.x = tableC.x
              JOIN tableD on tableC.x = tableD.x
              WHERE tableA.x = 'y'
              AND ([tablea].system_id = 450
              AND [tablec].system_id = 450)");
    }

    [Test]
    public void QueryWithSchemaDefinedNonTenantTables_WorksAsExpected()
    {
        _nonTenantTables.Add("xxx.tableB");
        _nonTenantTables.Add("db.xxx.tableD");

        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM tableA
              JOIN xxx.tableB on tableA.x = xxx.tableB.x
              JOIN tableC on tableB.x = tableC.x
              JOIN db.xxx.tableD on tableC.x = db.xxx.tableD.x
              WHERE tableA.x = 'y'",
            @"SELECT TOP 10 *
              FROM tableA
              JOIN xxx.tableB on tableA.x = xxx.tableB.x
              JOIN tableC on tableB.x = tableC.x
              JOIN db.xxx.tableD on tableC.x = db.xxx.tableD.x
              WHERE tableA.x = 'y'
              AND ([tablea].system_id = 450
              AND [tablec].system_id = 450)");
    }

    [Test]
    public void QueryWithoutWhereClauseWithGroupByClause_WorksAsExpected()
    {
        AssertRewrittenTo(
            @"SELECT product_id, AVG(sale_price)
              FROM Sales
              GROUP BY product_id",
            @"SELECT product_id, AVG(sale_price)
              FROM Sales
              WHERE ([sales].system_id = 450)
              GROUP BY product_id");
    }

    [Test]
    public void QueryWithoutWhereClauseWithOrderByClause_WorksAsExpected()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM tableA
              ORDER BY id DESC",
            @"SELECT TOP 10 *
              FROM tableA
              WHERE ([tablea].system_id = 450)
              ORDER BY id DESC");
    }

    [Test]
    public void QueryWithWhereAndOrderByClause_WorksAsExpected()
    {
        AssertRewrittenTo(
            @"SELECT TOP 10 *
              FROM tableA a
              WHERE a.id > 1000
              ORDER BY id DESC",
            @"SELECT TOP 10 *
              FROM tableA a
              WHERE a.id > 1000
              AND (a.system_id = 450)
              ORDER BY id DESC");
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment