Skip to content

Instantly share code, notes, and snippets.

@yar-shukan
Last active September 3, 2015 23:32
Show Gist options
  • Save yar-shukan/40ca19a26efa88237318 to your computer and use it in GitHub Desktop.
Save yar-shukan/40ca19a26efa88237318 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
namespace SourceCodeSandbox
{
/// <summary>
/// Convenience wrappers around <see cref="Sandbox.TryExecuteCode"/> that run parameterless user code
/// (no <c>string.Format</c> arguments are inlined) against a fixed allow-list of common BCL types.
/// </summary>
public static class SanboxExtensions
{
    // Types user code may reference. Generic types are listed as open definitions
    // (HashSet<>, List<>, ...); Sandbox additionally always permits primitive types.
    private static readonly HashSet<Type> AllowedDefaultTypes = new HashSet<Type>
    {
        typeof(object), typeof(Enumerable), typeof(Guid), typeof(string),
        typeof(DateTime), typeof(DateTimeOffset), typeof(HashSet<>), typeof(List<>),
        typeof(ISet<>), typeof(IQueryable<>), typeof(IQueryable)
    };

    /// <summary>
    /// Compiles and runs <paramref name="userCodeWithoutParams"/> as the body of a method returning
    /// <c>object</c>, and returns what it returned.
    /// </summary>
    /// <param name="userCodeWithoutParams">Method-body source code; must contain a <c>return</c> statement.</param>
    /// <returns>The value returned by the user code (may itself be null).</returns>
    /// <exception cref="ArgumentException">The code is null, empty or whitespace.</exception>
    /// <exception cref="InvalidOperationException">
    /// The code references a type outside the allow-list, fails to compile, or throws while running.
    /// </exception>
    public static object ExecuteOrThrow(string userCodeWithoutParams)
    {
        object executeResult;
        HashSet<Type> restrictedTypes;
        Diagnostic[] compileFailures;
        bool executed = Sandbox.TryExecuteCode(userCodeWithoutParams, null, AllowedDefaultTypes,
            out executeResult, out restrictedTypes, out compileFailures);
        if (!executed)
        {
            // Previously the failure was ignored and null returned, contradicting the "OrThrow" contract.
            throw new InvalidOperationException(DescribeFailure(restrictedTypes, compileFailures));
        }
        return executeResult;
    }

    /// <summary>Tries to compile and run <paramref name="userCodeWithoutParams"/>.</summary>
    /// <param name="userCodeWithoutParams">Method-body source code; must contain a <c>return</c> statement.</param>
    /// <param name="executeResult">The value returned by the user code, or null when this returns false.</param>
    /// <returns>True when the code passed the allow-list check, compiled and ran without throwing.</returns>
    /// <exception cref="ArgumentException">The code is null, empty or whitespace.</exception>
    public static bool TryExecuteCode(string userCodeWithoutParams, out object executeResult)
    {
        HashSet<Type> ignoredRestrictedTypes;
        return TryExecuteCode(userCodeWithoutParams, out executeResult, out ignoredRestrictedTypes);
    }

    /// <summary>Tries to compile and run <paramref name="userCodeWithoutParams"/>.</summary>
    /// <param name="userCodeWithoutParams">Method-body source code; must contain a <c>return</c> statement.</param>
    /// <param name="executeResult">The value returned by the user code, or null when this returns false.</param>
    /// <param name="restrictedTypes">Types the code referenced that are outside the allow-list.</param>
    /// <returns>True when the code passed the allow-list check, compiled and ran without throwing.</returns>
    /// <exception cref="ArgumentException">The code is null, empty or whitespace.</exception>
    public static bool TryExecuteCode(string userCodeWithoutParams, out object executeResult, out HashSet<Type> restrictedTypes)
    {
        Diagnostic[] ignoredCompileFailures;
        return Sandbox.TryExecuteCode(userCodeWithoutParams, null, AllowedDefaultTypes,
            out executeResult, out restrictedTypes, out ignoredCompileFailures);
    }

    // Builds a human-readable reason for a failed Sandbox.TryExecuteCode call from its out values.
    private static string DescribeFailure(HashSet<Type> restrictedTypes, Diagnostic[] compileFailures)
    {
        if (restrictedTypes != null && restrictedTypes.Count > 0)
        {
            return "User code references types that are not allowed: "
                + string.Join(", ", restrictedTypes.Select(type => type.ToString())) + ".";
        }
        if (compileFailures != null && compileFailures.Length > 0)
        {
            return "User code failed to compile: "
                + string.Join("; ", compileFailures.Select(diagnostic => diagnostic.ToString())) + ".";
        }
        // Sandbox.TryExecuteCode swallows the exception thrown by user code, so no detail is available.
        return "User code threw an exception while executing.";
    }
}
}
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
namespace SourceCodeSandbox
{
/// <summary>
/// Compiles a snippet of user-supplied C# (the body of a parameterless method returning object)
/// into an in-memory assembly with Roslyn and runs it, after rejecting snippets that reference
/// types outside a caller-supplied allow-list.
/// </summary>
public static class Sandbox
{
    // Wrapper the user code is spliced into with string.Format: "{0}" receives the user code as
    // the body of Execute(); "{{" / "}}" are escaped literal braces.
    private const string SandboxTemplate =
@"using System;
using System.Collections.Generic;
using System.Linq;
public static class Sandbox
{{
public static object Execute()
{{
{0}
}}
}}";

    /// <summary>
    /// Checks <paramref name="userCodeToExecute"/> against the allow-list, compiles it and, if both
    /// succeed, invokes it.
    /// </summary>
    /// <param name="userCodeToExecute">
    /// Method-body source. When <paramref name="userCodeArgs"/> is non-empty it is first used as a
    /// composite format string (so literal braces must be doubled).
    /// </param>
    /// <param name="userCodeArgs">Optional values inlined into the code; may be null.</param>
    /// <param name="allowedTypesToUse">Types the code may reference (primitives are always allowed).</param>
    /// <param name="executeResult">Value returned by the code, or null when this returns false.</param>
    /// <param name="restrictedTypes">Referenced types that are outside the allow-list.</param>
    /// <param name="compileFailures">Error diagnostics when compilation failed; otherwise null.</param>
    /// <returns>True only when the check passed, compilation succeeded and the code did not throw.</returns>
    /// <exception cref="ArgumentException">The code is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="allowedTypesToUse"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="allowedTypesToUse"/> is empty.</exception>
    public static bool TryExecuteCode(string userCodeToExecute,
        object[] userCodeArgs,
        IEnumerable<Type> allowedTypesToUse,
        out object executeResult,
        out HashSet<Type> restrictedTypes,
        out Diagnostic[] compileFailures)
    {
        string sourceCode = GetSourceCodeOrThrow(userCodeToExecute, userCodeArgs);
        if (allowedTypesToUse == null)
        {
            throw new ArgumentNullException("allowedTypesToUse");
        }
        var allowedTypesToUseSet = new HashSet<Type>(allowedTypesToUse);
        if (allowedTypesToUseSet.Count == 0)
        {
            throw new ArgumentOutOfRangeException("allowedTypesToUse");
        }

        var restrictedTypesAndCompilation = GetAllRestrictedTypesAndCompilation(sourceCode, allowedTypesToUseSet);
        restrictedTypes = restrictedTypesAndCompilation.Item1;
        executeResult = null;
        compileFailures = null;
        if (restrictedTypes.Count > 0)
        {
            return false;
        }

        var assemblyAndErrors = ToAssemblyWithErrorDiagnosticsIfAny(restrictedTypesAndCompilation.Item2);
        if (assemblyAndErrors.Item1 == null)
        {
            compileFailures = assemblyAndErrors.Item2;
            return false;
        }

        Type sandboxType = assemblyAndErrors.Item1.GetType("Sandbox");
        try
        {
            // NOTE(review): the user code runs on the calling thread with no time limit, so an
            // infinite loop in an otherwise allowed snippet never returns.
            executeResult = sandboxType.GetMethod("Execute").Invoke(null, null);
            return true;
        }
        catch (Exception)
        {
            // Deliberate best-effort: any failure while running the user code (reflection wraps it
            // in TargetInvocationException) is reported only through the false return value.
            executeResult = null;
            return false;
        }
    }

    // Emits the compilation to memory and loads it; on failure returns a null assembly together with
    // the error (or warning-as-error) diagnostics.
    private static Tuple<Assembly, Diagnostic[]> ToAssemblyWithErrorDiagnosticsIfAny(CSharpCompilation compilation)
    {
        using (var ms = new MemoryStream())
        {
            EmitResult result = compilation.Emit(ms);
            if (!result.Success)
            {
                Diagnostic[] failures = result.Diagnostics
                    .Where(diagnostic => diagnostic.IsWarningAsError || diagnostic.Severity == DiagnosticSeverity.Error)
                    .ToArray();
                return Tuple.Create((Assembly)null, failures);
            }
            // ToArray copies the whole buffer regardless of the stream position, so no Seek is needed.
            // NOTE(review): Assembly.Load(byte[]) loads into the current AppDomain, where it cannot
            // be unloaded individually; every successful call keeps one more assembly in memory.
            Assembly assembly = Assembly.Load(ms.ToArray());
            return Tuple.Create(assembly, new Diagnostic[0]);
        }
    }

    // Validates the user code, inlines the optional arguments and wraps the result in SandboxTemplate.
    private static string GetSourceCodeOrThrow(string userCodeToExecute, object[] codeArgs)
    {
        if (string.IsNullOrWhiteSpace(userCodeToExecute))
        {
            throw new ArgumentException("User code must not be null, empty or whitespace.", "userCodeToExecute");
        }
        // The arguments become C# source text, so format them culture-invariantly: e.g. 1.5 must not
        // be rendered as "1,5" under a culture with a comma decimal separator.
        string codeWithInlinedParams = codeArgs != null && codeArgs.Length > 0
            ? string.Format(System.Globalization.CultureInfo.InvariantCulture, userCodeToExecute, codeArgs)
            : userCodeToExecute;
        return string.Format(SandboxTemplate, codeWithInlinedParams);
    }

    /// <summary>
    /// Parses and compiles <paramref name="sourceCode"/> against the assemblies of the allowed types and
    /// returns every type referenced by the code that is neither primitive nor in the allow-list.
    /// </summary>
    /// <param name="sourceCode">Complete C# source (normally SandboxTemplate with user code inlined).</param>
    /// <param name="allowedToUseTypes">Types the code may reference.</param>
    /// <returns>The disallowed types found, and the compilation (not yet emitted).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="allowedToUseTypes"/> is null.</exception>
    public static Tuple<HashSet<Type>, CSharpCompilation> GetAllRestrictedTypesAndCompilation(string sourceCode, HashSet<Type> allowedToUseTypes)
    {
        if (allowedToUseTypes == null)
        {
            throw new ArgumentNullException("allowedToUseTypes");
        }

        SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(sourceCode);
        SyntaxNode sourceCodeTreeRoot = syntaxTree.GetRoot();

        // Several allowed types usually live in the same assembly (mscorlib, System.Core), so
        // de-duplicate by assembly rather than by type to reference each file only once.
        var references = new[] { typeof(object) }
            .Concat(allowedToUseTypes)
            .Select(type => type.Assembly)
            .Distinct()
            .Select(assembly => MetadataReference.CreateFromFile(assembly.Location));

        CSharpCompilation compilation = CSharpCompilation.Create("SandboxAssembly")
            .AddSyntaxTrees(syntaxTree)
            .AddReferences(references)
            .WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

        SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree, false);
        var symbolDisplayFormat = new SymbolDisplayFormat(
            typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);

        // NOTE(review): this display format omits generic arity and joins nested types with '.', while
        // Assembly.GetType expects e.g. "List`1" and "Outer+Inner". Such names do not resolve and are
        // silently skipped below, so generic and nested types currently bypass the allow-list.
        List<string> usedTypeNames = sourceCodeTreeRoot.DescendantNodes()
            .Select(node => semanticModel.GetTypeInfo(node).Type)
            .Where(typeSymbol => typeSymbol != null && !(typeSymbol is IErrorTypeSymbol))
            .Select(typeSymbol => typeSymbol.ToDisplayString(symbolDisplayFormat))
            .Distinct()
            .ToList();

        // Snapshot the loaded assemblies once instead of once per referenced name.
        Assembly[] loadedAssemblies = AppDomain.CurrentDomain.GetAssemblies();
        IEnumerable<Type> typesCantUse = usedTypeNames
            .Select(name => loadedAssemblies
                .Select(assembly => assembly.GetType(name))
                .FirstOrDefault(type => type != null))
            .Where(type => type != null)
            .Select(type => type.IsArray ? type.GetElementType() : type)
            .Where(type => !(type.IsPrimitive || allowedToUseTypes.Contains(type)));

        return Tuple.Create(new HashSet<Type>(typesCantUse), compilation);
    }
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using NUnit.Framework;
namespace SourceCodeSandbox
{
/// <summary>
/// NUnit tests for <see cref="SanboxExtensions"/>: snippets that should run under the default
/// allow-list, and snippets that should be rejected.
/// </summary>
public class SandboxTests
{
    // GetSourceCodeOrThrow rejects anything string.IsNullOrWhiteSpace accepts, so whitespace is covered too.
    [TestCase(null)]
    [TestCase("")]
    [TestCase("   ")]
    public void TryExecuteCode_NullOrEmptySourceCode_Exception(string nullOrEmpty)
    {
        object result;
        Assert.Throws<ArgumentException>(
            () => SanboxExtensions.TryExecuteCode(nullOrEmpty, out result));
    }

    [Test]
    public void TryExecuteCode_ValidSourceCode_CorrectResult_()
    {
        var code = @"var ints = new[] { 1, 2, 3, 4, 5 };
return ints.Where(x => x > 3).Take(1).ToArray();";

        object result = ExecuteExpectingSuccess(code);

        Assert.IsInstanceOf<IEnumerable<int>>(result);
        Assert.AreEqual(4, ((IEnumerable<int>)result).Single());
    }

    [Test]
    public void TryExecuteCode_ValidSourceCode_CorrectResult___()
    {
        var code = @"var ints = new List<int> { 1, 2, int.Parse(""3""), 4, 5 };
return ints.Where(x => x > 3).Take(1).ToArray();";

        object result = ExecuteExpectingSuccess(code);

        Assert.IsInstanceOf<IEnumerable<int>>(result);
        Assert.AreEqual(4, ((IEnumerable<int>)result).Single());
    }

    [Test]
    public void TryExecuteCode_ValidSourceCode_CorrectResult__()
    {
        var code = @"var value = new Guid();
return value;;";

        object result = ExecuteExpectingSuccess(code);

        Assert.IsInstanceOf<Guid>(result);
    }

    [Test]
    public void TryExecuteCode_ValidSourceCode_CorrectResult____()
    {
        var code = @" var dateTimes = new HashSet<DateTime>
{
DateTime.MinValue, DateTime.MaxValue, DateTime.UtcNow
};
return dateTimes.AsQueryable().Where(x => x < DateTime.Today).ToArray();";

        object result = ExecuteExpectingSuccess(code);

        Assert.IsInstanceOf<IEnumerable<DateTime>>(result);
    }

    [Test]
    public void TryExecuteCode_NotSecureSourceCode_False()
    {
        var code = @"var dateTimes = new HashSet<DateTime>
{
DateTime.MinValue, DateTime.MaxValue, DateTime.UtcNow
};
return dateTimes.Select(x =>
{
var b = x < DateTime.Today;
System.Diagnostics.Debug.Fail(""message"");
return b;
}).ToArray();";

        HashSet<Type> restrictedTypes = ExecuteExpectingRejection(code);

        CollectionAssert.Contains(restrictedTypes, typeof(System.Diagnostics.Debug));
    }

    [Test]
    public void TryExecuteCode_NotSecureSourceCode_False__()
    {
        // The sandbox template only imports System, System.Collections.Generic and System.Linq, so
        // "Thread" does not bind: this snippet is rejected by a compile error, not by the allow-list.
        var code = @" Thread.Sleep(int.MaxValue);
return ""haha"";";

        ExecuteExpectingRejection(code);
    }

    [Test]
    [Timeout(1000)]
    public void TryExecuteCode_SafeCode_ButWithDenialOfServiceAttack_False()
    {
        // NOTE(review): Sandbox.TryExecuteCode invokes the user code with no time limit, so this test
        // documents desired behaviour and is expected to fail on the [Timeout] until one is added.
        var code = @"int i = 0;
while (true) { i++; }
return ""haha"";";

        ExecuteExpectingRejection(code);
    }

    [Test]
    public void TryExecuteCode_NotSecureSourceCode_False_()
    {
        var code = @"var currentDirectory = Environment.CurrentDirectory;
return ""fool"";";

        ExecuteExpectingRejection(code);
    }

    // Runs code through the default sandbox, asserts it executed (listing any restricted types on
    // failure) and returned a non-null value, and returns that value.
    private static object ExecuteExpectingSuccess(string code)
    {
        object result;
        HashSet<Type> restrictedTypes;
        var executed = SanboxExtensions.TryExecuteCode(code, out result, out restrictedTypes);
        Assert.IsTrue(executed, String.Join(", ", restrictedTypes));
        Assert.IsNotNull(result);
        return result;
    }

    // Runs code through the default sandbox, asserts it was rejected with a null result, and
    // returns the restricted types the sandbox reported.
    private static HashSet<Type> ExecuteExpectingRejection(string code)
    {
        object result;
        HashSet<Type> restrictedTypes;
        var executed = SanboxExtensions.TryExecuteCode(code, out result, out restrictedTypes);
        Assert.IsFalse(executed);
        Assert.IsNull(result);
        return restrictedTypes;
    }

    // The methods below are never called: they mirror the snippets above as ordinary C# so the
    // compiler checks that each snippet is valid code in a normal context.

    private object Execute1()
    {
        var ints = new[] { 1, 2, 3, 4, 5 };
        return ints.Where(x => x > 3).Take(1).ToArray();
    }

    private object Execute2()
    {
        var ints = new List<int> { 1, 2, int.Parse("3"), 4, 5 };
        return ints.Where(x => x > 3).Take(1).ToArray();
    }

    private object Execute3()
    {
        var value = new Guid();
        return value;
    }

    private object Execute4()
    {
        var dateTimes = new HashSet<DateTime>
        {
            DateTime.MinValue, DateTime.MaxValue, DateTime.UtcNow
        };
        return dateTimes.AsQueryable().Where(x => x < DateTime.Today).ToArray();
    }

    private object Fail1()
    {
        var dateTimes = new HashSet<DateTime>
        {
            DateTime.MinValue, DateTime.MaxValue, DateTime.UtcNow
        };
        return dateTimes.Select(x =>
        {
            var b = x < DateTime.Today;
            System.Diagnostics.Debug.Fail("message");
            return b;
        }).ToArray();
    }

    private object Fail2()
    {
        Thread.Sleep(int.MaxValue);
        return "haha";
    }

    private object Fail3()
    {
        int i = 0;
        while (true) { i++; }
        return "haha";
    }

    private object Fail4()
    {
        var currentDirectory = Environment.CurrentDirectory;
        return "fool";
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment