Skip to content

Instantly share code, notes, and snippets.

@MatthewVines
Last active August 8, 2017 17:38
Show Gist options
  • Save MatthewVines/f718a74d411aaa97046a50ef13909da6 to your computer and use it in GitHub Desktop.
Save MatthewVines/f718a74d411aaa97046a50ef13909da6 to your computer and use it in GitHub Desktop.
NPoco interceptor that injects tenant (system_id) checks into SQL statements before they are executed.
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;
using NPoco;
/// <summary>
/// NPoco executing interceptor that rewrites the command text so that every tenant table
/// named directly after a FROM or JOIN keyword is filtered by <c>system_id</c>.
/// </summary>
/// <remarks>
/// The rewrite works on the lexer token stream only; no syntax tree is built. Limits that
/// follow directly from this implementation:
/// <list type="bullet">
/// <item>Only the first WHERE, GROUP and ORDER tokens of the statement are considered.</item>
/// <item>Only tables that directly follow FROM or JOIN are detected (comma-separated table
/// lists and derived tables are not).</item>
/// <item>Statements that produce tokenizer errors are passed through unchanged.</item>
/// <item>Table names are lower-cased when used as a qualifier in the generated filter.</item>
/// </list>
/// </remarks>
public class SystemCheckQueryInterceptor : IExecutingInterceptor
{
    // Keywords after which a table reference is expected.
    private static readonly TSqlTokenType[] TableSourceTokenTypes =
    {
        TSqlTokenType.From,
        TSqlTokenType.Join
    };

    // Token types that can form one part of a (multi-part) table name or an alias.
    private static readonly TSqlTokenType[] IdentifierTokenTypes =
    {
        TSqlTokenType.Identifier,
        TSqlTokenType.QuotedIdentifier
    };

    private readonly int _systemId;
    private readonly List<string> _nonTenantTables;

    /// <summary>
    /// Creates an interceptor that filters on the given system id.
    /// </summary>
    /// <param name="systemId">Value every tenant table's <c>system_id</c> column must equal.</param>
    /// <param name="nonTenantTables">
    /// Tables that must not receive a filter. Names may be multi-part and may or may not be
    /// bracketed; they are normalized to lower-case bracketed form.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="nonTenantTables"/> is null.</exception>
    public SystemCheckQueryInterceptor(int systemId, List<string> nonTenantTables)
    {
        if (nonTenantTables == null)
        {
            throw new ArgumentNullException(nameof(nonTenantTables));
        }

        _systemId = systemId;
        _nonTenantTables = nonTenantTables.Select(NormalizeTableName).ToList();
    }

    /// <summary>Rewrites the command text in place just before NPoco executes it.</summary>
    public void OnExecutingCommand(IDatabase database, DbCommand cmd)
    {
        cmd.CommandText = AddTenantChecksToStatement(cmd.CommandText);
    }

    /// <summary>Nothing to do once the command has run.</summary>
    public void OnExecutedCommand(IDatabase database, DbCommand cmd)
    {
    }

    /// <summary>
    /// Returns <paramref name="sql"/> with a <c>system_id</c> filter added for every tenant
    /// table found after FROM/JOIN.
    /// </summary>
    /// <param name="sql">Statement to rewrite.</param>
    /// <returns>
    /// The rewritten statement, or the input unchanged when it is null/blank, cannot be
    /// tokenized, or references no tenant table.
    /// </returns>
    internal string AddTenantChecksToStatement(string sql)
    {
        if (string.IsNullOrWhiteSpace(sql))
        {
            return sql;
        }

        List<string> parseErrors;
        List<TSqlParserToken> queryTokens = TokenizeSql(sql, out parseErrors);
        if (parseErrors != null)
        {
            // If we had parse errors, there is probably something wrong with the statement.
            // The best errors will come back from the db itself, so just pass the statement on as is.
            return sql;
        }

        Dictionary<string, string> tableAliases = GetTableNamesFromQueryString(queryTokens);
        string tenantClause = BuildTenantClause(tableAliases);
        if (tenantClause.Length == 0)
        {
            // Nothing to filter on (e.g. INSERT ... VALUES, or only non-tenant tables).
            // Appending an empty clause would leave a dangling WHERE/AND and break the statement.
            return sql;
        }

        return AppendClause(sql, queryTokens, tenantClause, Concatenator.And);
    }

    /// <summary>
    /// Brackets and lower-cases every dot-separated part of a configured table name,
    /// e.g. <c>dbo.TableA</c> becomes <c>[dbo].[tablea]</c>.
    /// </summary>
    private static string NormalizeTableName(string table)
    {
        IEnumerable<string> parts = table.Split('.').Select(part =>
        {
            string prefix = part.StartsWith("[", StringComparison.Ordinal) ? string.Empty : "[";
            string suffix = part.EndsWith("]", StringComparison.Ordinal) ? string.Empty : "]";
            return prefix + part.ToLowerInvariant() + suffix;
        });

        // Join by position so that repeated parts (e.g. "dbo.dbo") keep their separator.
        return string.Join(".", parts);
    }

    /// <summary>
    /// Tokenizes the statement with the T-SQL 120 lexer.
    /// </summary>
    /// <param name="sql">Statement text.</param>
    /// <param name="parserErrors">Formatted tokenizer errors, or null when there were none.</param>
    /// <returns>The token stream produced for <paramref name="sql"/>.</returns>
    private List<TSqlParserToken> TokenizeSql(string sql, out List<string> parserErrors)
    {
        using (System.IO.TextReader tReader = new System.IO.StringReader(sql))
        {
            var parser = new TSql120Parser(true);
            IList<ParseError> errors;
            IList<TSqlParserToken> queryTokens = parser.GetTokenStream(tReader, out errors);

            parserErrors = errors.Any()
                ? errors.Select(e => $"Error: {e.Number}; Line: {e.Line}; Column: {e.Column}; Offset: {e.Offset}; Message: {e.Message};").ToList()
                : null;

            return queryTokens.ToList();
        }
    }

    /// <summary>
    /// Collects the tables that directly follow a FROM or JOIN keyword.
    /// </summary>
    /// <returns>
    /// Map from the qualifier to use in the filter (the alias when one is present, otherwise
    /// the normalized table name) to the normalized table name.
    /// </returns>
    private Dictionary<string, string> GetTableNamesFromQueryString(List<TSqlParserToken> queryTokens)
    {
        var output = new Dictionary<string, string>();
        var sb = new StringBuilder();

        for (var i = 0; i < queryTokens.Count; i++)
        {
            if (!TableSourceTokenTypes.Contains(queryTokens[i].TokenType))
            {
                continue;
            }

            // Only an identifier directly after the keyword is treated as a table name.
            int j = SkipWhiteSpace(queryTokens, i + 1);
            if (j >= queryTokens.Count || !IsIdentifier(queryTokens[j]))
            {
                continue;
            }

            sb.Clear();
            GetQuotedIdentifier(queryTokens[j], sb);

            // Consume the remaining parts of a multi-part name, including empty parts ("db..table").
            while (j + 2 < queryTokens.Count
                   && queryTokens[j + 1].TokenType == TSqlTokenType.Dot
                   && (queryTokens[j + 2].TokenType == TSqlTokenType.Dot || IsIdentifier(queryTokens[j + 2])))
            {
                sb.Append(queryTokens[j + 1].Text);
                if (queryTokens[j + 2].TokenType == TSqlTokenType.Dot)
                {
                    if (queryTokens[j - 1].TokenType == TSqlTokenType.Dot)
                    {
                        GetQuotedIdentifier(queryTokens[j + 1], sb);
                    }

                    j++;
                }
                else
                {
                    GetQuotedIdentifier(queryTokens[j + 2], sb);
                    j += 2;
                }
            }

            string tableName = sb.ToString();

            // Indexer instead of Add: the same table/alias may legitimately appear more than once
            // (self-join without alias, subquery) and must not throw on the duplicate key.
            output[ResolveAlias(queryTokens, j, tableName)] = tableName;
        }

        return output;
    }

    /// <summary>
    /// Returns the alias that follows the table name ending at <paramref name="nameIndex"/>
    /// ("table alias" or "table AS alias"), or <paramref name="tableName"/> when there is none.
    /// </summary>
    private static string ResolveAlias(List<TSqlParserToken> queryTokens, int nameIndex, string tableName)
    {
        // The name scan can stop on a dot (trailing empty part); no alias is looked up then.
        if (!IsIdentifier(queryTokens[nameIndex]))
        {
            return tableName;
        }

        int next = SkipWhiteSpace(queryTokens, nameIndex + 1);
        if (next == nameIndex + 1)
        {
            // No whitespace after the name, so whatever follows is not treated as an alias.
            return tableName;
        }

        if (next < queryTokens.Count && queryTokens[next].TokenType == TSqlTokenType.As)
        {
            next = SkipWhiteSpace(queryTokens, next + 1);
        }

        return next < queryTokens.Count && IsIdentifier(queryTokens[next])
            ? queryTokens[next].Text
            : tableName;
    }

    /// <summary>Returns the index of the first non-whitespace token at or after <paramref name="start"/>.</summary>
    private static int SkipWhiteSpace(List<TSqlParserToken> queryTokens, int start)
    {
        int index = start;
        while (index < queryTokens.Count && queryTokens[index].TokenType == TSqlTokenType.WhiteSpace)
        {
            index++;
        }

        return index;
    }

    private static bool IsIdentifier(TSqlParserToken token)
    {
        return IdentifierTokenTypes.Contains(token.TokenType);
    }

    /// <summary>
    /// Appends the lower-cased token text to <paramref name="sb"/>, adding brackets around
    /// plain identifiers.
    /// </summary>
    /// <exception cref="ArgumentException">The token is not an identifier, quoted identifier or dot.</exception>
    private static void GetQuotedIdentifier(TSqlParserToken token, StringBuilder sb)
    {
        switch (token.TokenType)
        {
            case TSqlTokenType.Identifier:
                sb.Append('[').Append(token.Text.ToLowerInvariant()).Append(']');
                break;
            case TSqlTokenType.QuotedIdentifier:
            case TSqlTokenType.Dot:
                sb.Append(token.Text.ToLowerInvariant());
                break;
            default:
                throw new ArgumentException("Error: expected TokenType of token should be TSqlTokenType.Dot, TSqlTokenType.Identifier, or TSqlTokenType.QuotedIdentifier");
        }
    }

    /// <summary>
    /// A table is a tenant table unless its last name part occurs in one of the configured
    /// non-tenant table names.
    /// </summary>
    private bool IsTenantTable(string tableName)
    {
        string lastPart = tableName.Split('.').Last();
        return !_nonTenantTables.Any(nonTenantTable => nonTenantTable.Contains(lastPart));
    }

    /// <summary>
    /// Builds <c>(q1.system_id = N AND q2.system_id = N ...)</c> for all tenant tables.
    /// </summary>
    /// <returns>The parenthesized clause, or an empty string when there is no tenant table.</returns>
    private string BuildTenantClause(Dictionary<string, string> tableAliases)
    {
        // Invariant formatting: the number ends up in SQL text and must not depend on the thread culture.
        string systemId = _systemId.ToString(System.Globalization.CultureInfo.InvariantCulture);

        List<string> checks = tableAliases
            .Where(kvp => IsTenantTable(kvp.Value))
            .Select(kvp => $"{kvp.Key}.system_id = {systemId}")
            .ToList();

        return checks.Count == 0
            ? string.Empty
            : "(" + string.Join(" AND ", checks) + ")";
    }

    private enum Concatenator
    {
        And,
        Or
    }

    /// <summary>
    /// Inserts <paramref name="newClause"/> before the first GROUP token, otherwise before the
    /// first ORDER token, otherwise at the end of the statement. The clause is introduced with
    /// WHERE when no WHERE token precedes the insertion point, and with
    /// <paramref name="concatenator"/> otherwise.
    /// </summary>
    private string AppendClause(string sql, List<TSqlParserToken> queryTokens, string newClause, Concatenator concatenator)
    {
        int indexOfWhere = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Where);
        int indexOfGroup = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Group);
        int indexOfOrder = queryTokens.FindIndex(x => x.TokenType == TSqlTokenType.Order);

        // GROUP BY takes priority over ORDER BY as the insertion point.
        int insertTokenIndex = indexOfGroup >= 0 ? indexOfGroup : indexOfOrder;
        bool appendAtEnd = insertTokenIndex < 0;
        int insertOffset = appendAtEnd ? sql.Length : queryTokens[insertTokenIndex].Offset;
        int endTokenIndex = appendAtEnd ? queryTokens.Count : insertTokenIndex;

        // A WHERE located after the insertion point cannot be the one we are extending.
        bool hasWhere = indexOfWhere >= 0 && indexOfWhere < endTokenIndex;

        string keyword = hasWhere ? concatenator.ToString().ToUpperInvariant() : "WHERE";
        string addition = appendAtEnd
            ? $" {keyword} {newClause}"
            : $" {keyword} {newClause} ";

        bool existingConditionHasOr = hasWhere
            && queryTokens
                .Skip(indexOfWhere + 1)
                .Take(endTokenIndex - indexOfWhere - 1)
                .Any(x => x.TokenType == TSqlTokenType.Or);

        if (!existingConditionHasOr)
        {
            return sql.Insert(insertOffset, addition);
        }

        // AND binds tighter than OR: "a OR b AND (tenant)" would only guard "b".
        // Wrap the existing condition so the new clause applies to all of it.
        TSqlParserToken whereToken = queryTokens[indexOfWhere];
        int openOffset = whereToken.Offset + whereToken.Text.Length;

        // Insert at the later offset first so the earlier offset stays valid.
        return sql
            .Insert(insertOffset, ")" + addition)
            .Insert(openOffset, " (");
    }
}
using System;
using System.Collections.Generic;
using NUnit.Framework;
using xray_svc;
/// <summary>
/// Tests for <c>SystemCheckQueryInterceptor.AddTenantChecksToStatement</c>. Each test feeds a
/// statement to the interceptor and compares the result with the expected rewritten statement,
/// ignoring spaces and line breaks.
/// </summary>
[TestFixture]
internal class SystemCheckQueryInterceptorTests
{
    private int _systemId;
    private List<string> _nonTenantTables;

    [SetUp]
    public void Setup()
    {
        _nonTenantTables = new List<string>();
        _systemId = 450;
    }

    // Strips spaces and line breaks so the assertions do not depend on layout.
    private static string RemoveWhitespace(string s)
    {
        string withoutSpaces = s.Replace(" ", "");
        return withoutSpaces.Replace(Environment.NewLine, "");
    }

    // Runs the statement through a fresh interceptor built from the current fixture state
    // and checks the outcome against the expected rewritten statement.
    private void AssertRewrite(string statement, string rewritten)
    {
        var interceptor = new SystemCheckQueryInterceptor(_systemId, _nonTenantTables);
        string actual = interceptor.AddTenantChecksToStatement(statement);
        Assert.AreEqual(RemoveWhitespace(rewritten), RemoveWhitespace(actual));
    }

    [Test]
    public void QueryWithoutWhereClause_AddWhereClauseWithTenantClause()
    {
        const string statement = @"SELECT TOP 10 * FROM tableA";
        const string rewritten = @"SELECT TOP 10 * FROM tableA WHERE ([tablea].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithWhereClause_AddTenantClauseWithAnd()
    {
        const string statement = @"SELECT TOP 10 * FROM tableA WHERE x = y";
        const string rewritten = @"SELECT TOP 10 * FROM tableA WHERE x = y AND ([tablea].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithJoins_AddTenantClauseForEachTable()
    {
        const string statement = @"SELECT TOP 10 *
FROM tableA
JOIN tableB on tableA.x = tableB.x
JOIN tableC on tableA.x = tableC.x
WHERE tablaA.x = 'y'";
        const string rewritten = @"SELECT TOP 10 *
FROM tableA
JOIN tableB on tableA.x = tableB.x
JOIN tableC on tableA.x = tableC.x
WHERE tablaA.x = 'y'
AND ([tablea].system_id = 450
AND [tableb].system_id = 450
AND [tablec].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void FullyQualifiedTables_WorkAsExpected()
    {
        const string statement = @"SELECT TOP 10 *
FROM [db].[sch].[tableA]
JOIN xxx.tableB on tableA.x = xxx.tableB.x
JOIN tableC on db.sch.tableA.x = tableC.x
WHERE [db].[sch].[tableA].x = 'y'";
        const string rewritten = @"SELECT TOP 10 *
FROM [db].[sch].[tableA]
JOIN xxx.tableB on tableA.x = xxx.tableB.x
JOIN tableC on db.sch.tableA.x = tableC.x
WHERE [db].[sch].[tableA].x = 'y'
AND ([db].[sch].[tablea].system_id = 450
AND [xxx].[tableb].system_id = 450
AND [tablec].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithAliases_AliasesAreUsedRatherThanTableNames()
    {
        const string statement = @"SELECT TOP 10 *
FROM [db].[sch].[tableA] A
JOIN xxx.tableB on A.x = xxx.tableB.x
JOIN tableA A2 on A.x = A2.x
WHERE A.x = 'y' AND A2.x = 'yy'";
        const string rewritten = @"SELECT TOP 10 *
FROM [db].[sch].[tableA] A
JOIN xxx.tableB on A.x = xxx.tableB.x
JOIN tableA A2 on A.x = A2.x
WHERE A.x = 'y' AND A2.x = 'yy'
AND (A.system_id = 450
AND [xxx].[tableb].system_id = 450
AND A2.system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithNonTenantTables_OnlyAddClausesForTenantTables()
    {
        _nonTenantTables.Add("tableB");
        _nonTenantTables.Add("tableD");

        const string statement = @"SELECT TOP 10 *
FROM tableA
JOIN tableB on tableA.x = tableB.x
JOIN tableC on tableB.x = tableC.x
JOIN tableD on tableC.x = tableD.x
WHERE tablaA.x = 'y'";
        const string rewritten = @"SELECT TOP 10 *
FROM tableA
JOIN tableB on tableA.x = tableB.x
JOIN tableC on tableB.x = tableC.x
JOIN tableD on tableC.x = tableD.x
WHERE tablaA.x = 'y'
AND ([tablea].system_id = 450
AND [tablec].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithSchemaDefinedNonTenantTables_WorksAsExpected()
    {
        _nonTenantTables.Add("xxx.tableB");
        _nonTenantTables.Add("db.xxx.tableD");

        const string statement = @"SELECT TOP 10 *
FROM tableA
JOIN xxx.tableB on tableA.x = xxx.tableB.x
JOIN tableC on tableB.x = tableC.x
JOIN db.xxx.tableD on tableC.x = db.xxx.tableD.x
WHERE tablaA.x = 'y'";
        const string rewritten = @"SELECT TOP 10 *
FROM tableA
JOIN xxx.tableB on tableA.x = xxx.tableB.x
JOIN tableC on tableB.x = tableC.x
JOIN db.xxx.tableD on tableC.x = db.xxx.tableD.x
WHERE tablaA.x = 'y'
AND ([tablea].system_id = 450
AND [tablec].system_id = 450)";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithoutWhereClauseWithGroupByClause_WorksAsExpected()
    {
        const string statement = @"SELECT product_id, AVG(sale_price)
FROM Sales
GROUP BY product_id";
        const string rewritten = @"SELECT product_id, AVG(sale_price)
FROM Sales
WHERE ([sales].system_id = 450)
GROUP BY product_id";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithoutWhereClauseWithOrderByClause_WorksAsExpected()
    {
        const string statement = @"SELECT TOP 10 *
FROM tableA
ORDER BY id DESC";
        const string rewritten = @"SELECT TOP 10 *
FROM tableA
WHERE ([tablea].system_id = 450)
ORDER BY id DESC";

        AssertRewrite(statement, rewritten);
    }

    [Test]
    public void QueryWithWhereAndOrderByClause_WorksAsExpected()
    {
        const string statement = @"SELECT TOP 10 *
FROM tableA a
WHERE a.id > 1000
ORDER BY id DESC";
        const string rewritten = @"SELECT TOP 10 *
FROM tableA a
WHERE a.id > 1000
AND (a.system_id = 450)
ORDER BY id DESC";

        AssertRewrite(statement, rewritten);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment