Skip to content

Instantly share code, notes, and snippets.

@tombatron
Created October 11, 2019 19:48
Show Gist options
  • Save tombatron/0bb50f2c07701f12acfa7eab9084069c to your computer and use it in GitHub Desktop.
Save tombatron/0bb50f2c07701f12acfa7eab9084069c to your computer and use it in GitHub Desktop.
using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
namespace HtmlRemover
{
/// <summary>
/// Console entry point that runs the sanitizer over a fixed sample document and prints
/// the input, the sanitized output and the expected output for visual comparison.
/// </summary>
class Program
{
    // Sample markup containing every kind of element the sanitizer is configured to handle.
    static string SampleHtml = @"<link rel=stylesheet href='insertUrlHere' ><style> h2 {color:red;}</style><div>This is a div</div><div> </div><script src = 'InsertUrl' > And stuff here</script><div></div ><p style='color: blue'>Text in p tag</p>";

    static void Main(string[] args)
    {
        // Visual divider printed after each section.
        const string Separator = "---------------------------------------";

        // Sanitize first so that all output is produced only after the work has completed.
        string sanitized = Sanitizer.Sanitize(SampleHtml);

        Console.WriteLine("Input String");
        Console.WriteLine(SampleHtml);
        Console.WriteLine(Separator);

        Console.WriteLine("Output String");
        Console.WriteLine(sanitized);
        Console.WriteLine(Separator);

        Console.WriteLine("Expected Output");
        Console.WriteLine("<div>This is a div</div><p data-style='color: blue'>Text in p tag</p>");
    }
}
/// <summary>
/// A rule describing how one kind of HTML construct is recognised and processed.
/// </summary>
public interface IHtmlElement
{
/// <summary>
/// Processes the construct found at the start of <paramref name="html"/>.
/// </summary>
/// <param name="html">The remaining input, beginning at the '&lt;' that was matched.</param>
/// <returns>
/// A tuple of: a length relative to the start of <paramref name="html"/> (the collection's
/// HandleTag adds it to the current position), optional replacement characters to emit
/// (<c>null</c> when there are none), and a flag reporting whether input was removed.
/// </returns>
(int fastForwardLength, char[] replacement, bool replaced) FastForward(ReadOnlySpan<char> html);
/// <summary>
/// Reports whether the start of <paramref name="html"/> is a construct this rule handles.
/// </summary>
/// <param name="html">The remaining input, beginning at a '&lt;' character.</param>
bool IsMatch(ReadOnlySpan<char> html);
}
public abstract class BaseHtmlElement : IHtmlElement
{
// Character treated as whitespace in addition to the literal ' ' by the tag scanners.
// NOTE(review): (char)32 IS the ordinary space, so every "x == ' ' || x == NonPrintedSpace"
// test checks the same character twice. Possibly U+00A0 (160) was intended - confirm.
private const char NonPrintedSpace = (char)32;
/// <summary>The tag name, without angle brackets, that this rule recognises.</summary>
protected abstract char[] TagSymbol { get; }
/// <summary>
/// When true, FastForward skips everything up to and including the closing tag.
/// </summary>
public virtual bool RemoveEntireElement { get; set; } = false;
/// <summary>
/// When true, FastForward removes the element only if nothing but spaces sits between its tags.
/// </summary>
public virtual bool RemoveEmptyElement { get; set; } = false;
/// <summary>
/// Attribute text to look for; only consulted when <see cref="ReplaceAttribute"/> is also set.
/// </summary>
public char[] FindAttribute { get; set; }
/// <summary>Text substituted for <see cref="FindAttribute"/> when it is found.</summary>
public char[] ReplaceAttribute { get; set; }
/// <summary>
/// Decides what to do with the tag at the start of <paramref name="html"/> according to
/// <see cref="RemoveEntireElement"/>, <see cref="RemoveEmptyElement"/> and the
/// <see cref="FindAttribute"/>/<see cref="ReplaceAttribute"/> pair.
/// </summary>
/// <param name="html">The remaining input, beginning at the '&lt;' of a tag this rule matched.</param>
/// <returns>
/// <list type="bullet">
/// <item>Removal: <c>(charactersToSkip, null, true)</c>.</item>
/// <item>Attribute rewrite: <c>(openTagLength - 1, rewrittenOpenTag, false)</c>.</item>
/// <item>Leave untouched: <c>(0, null, false)</c>.</item>
/// </list>
/// </returns>
public virtual (int fastForwardLength, char[] replacement, bool replaced) FastForward(ReadOnlySpan<char> html)
{
    if (RemoveEntireElement)
    {
        var elementLength = FindClosingTag(html);

        // No closing tag: drop the rest of the input. Returning 0 here would make the
        // caller re-process the same '<' forever.
        return (elementLength > 0 ? elementLength : html.Length, default, true);
    }

    if (!RemoveEmptyElement)
    {
        // Neither flag set: strip just the tag itself and keep the element's content.
        var (isOpen, openLength) = IsOpenTag(html);
        if (isOpen)
        {
            return (openLength, default, true);
        }

        var (_, closeLength) = IsClosingTag(html);
        return (closeLength, default, true);
    }

    var (closing, _) = IsClosingTag(html);
    if (closing)
    {
        // If we're at a closing tag let's bail; it is copied through unchanged.
        return (default, default, false);
    }

    var (opening, openTagLength) = IsOpenTag(html);
    if (!opening)
    {
        return (default, default, false);
    }

    // Skip the spaces that follow the opening tag, never reading past the end of the input.
    var contentIndex = openTagLength;
    while (contentIndex < html.Length
        && (html[contentIndex] == ' ' || html[contentIndex] == NonPrintedSpace))
    {
        contentIndex++;
    }

    // The element is empty only when the next thing is ITS OWN closing tag. A child element
    // also starts with '<' but must not cause the parent to be deleted.
    if (contentIndex < html.Length && html[contentIndex] == '<')
    {
        var (closesHere, closingLength) = IsClosingTag(html.Slice(contentIndex));
        if (closesHere)
        {
            return (contentIndex + closingLength, default, true);
        }
    }

    // NOT EMPTY: optionally rewrite the attribute. Only the opening tag is searched and
    // replaced, so text content is never altered and nested markup is still sanitized by
    // the caller afterwards. The match is a plain substring search, so the text can also
    // match inside another attribute's name or value.
    if (FindAttribute != null && ReplaceAttribute != null)
    {
        var openTag = html.Slice(0, openTagLength);
        var replacementLocation = openTag.IndexOf(FindAttribute);
        if (replacementLocation > 0)
        {
            var rewritten = new char[openTag.Length - FindAttribute.Length + ReplaceAttribute.Length];
            openTag.Slice(0, replacementLocation).CopyTo(rewritten);
            ReplaceAttribute.CopyTo(rewritten, replacementLocation);
            openTag.Slice(replacementLocation + FindAttribute.Length)
                .CopyTo(rewritten.AsSpan(replacementLocation + ReplaceAttribute.Length));

            // Sanitizer.Sanitize assigns this value to its loop index and then 'continue's,
            // so the loop increment adds one more. Report one less than the consumed source
            // length so scanning resumes at the first character after the opening tag.
            return (openTagLength - 1, rewritten, false);
        }
    }

    return (default, default, false);
}
/// <summary>
/// Finds the first closing tag for this element's <see cref="TagSymbol"/> in
/// <paramref name="html"/>.
/// </summary>
/// <param name="html">The input to scan.</param>
/// <returns>
/// The index just past the '&gt;' of the first matching closing tag, or 0 when the input
/// contains none. Nesting is not tracked: the first closing tag wins.
/// </returns>
protected int FindClosingTag(ReadOnlySpan<char> html)
{
    var searchFrom = 0;
    while (searchFrom < html.Length)
    {
        var offset = html.Slice(searchFrom).IndexOf('<');
        if (offset < 0)
        {
            // No further tags at all.
            break;
        }

        // IndexOf is relative to the slice; convert back to an index into html so the
        // scan always moves forward.
        var candidate = searchFrom + offset;
        var (match, length) = IsClosingTag(html.Slice(candidate));
        if (match)
        {
            return candidate + length;
        }

        searchFrom = candidate + 1;
    }

    return default;
}
/// <summary>
/// Tests whether <paramref name="html"/> begins with an opening tag for <see cref="TagSymbol"/>.
/// </summary>
/// <param name="html">The input; a match requires it to start with '&lt;'.</param>
/// <returns>
/// <c>(true, n)</c> where n is the number of characters up to and including the tag's first
/// '&gt;', or <c>(false, 0)</c> when the input is not such a tag. The comparison is
/// case-sensitive and '&gt;' inside quoted attribute values is not recognised.
/// </returns>
public (bool isMatch, int length) IsOpenTag(ReadOnlySpan<char> html)
{
    if (html.IsEmpty || html[0] != '<')
    {
        return (false, default);
    }

    // Read once: the property may allocate on each access.
    var tagSymbol = TagSymbol;

    // Spaces are tolerated between '<' and the tag name.
    var nameStart = 1;
    while (nameStart < html.Length
        && (html[nameStart] == ' ' || html[nameStart] == NonPrintedSpace))
    {
        nameStart++;
    }

    // The name must be followed by at least one more character (a space or '>'), so the
    // name may not run up to the very end of the input.
    var nameEnd = nameStart + tagSymbol.Length;
    if (nameEnd >= html.Length)
    {
        return (false, default);
    }

    if (!html.Slice(nameStart, tagSymbol.Length).SequenceEqual(tagSymbol))
    {
        return (false, default);
    }

    // Reject longer names that merely start with the symbol, e.g. "<pre>" for "p".
    var nextChar = html[nameEnd];
    if (nextChar != NonPrintedSpace && nextChar != ' ' && nextChar != '>')
    {
        return (false, default);
    }

    var closingBracket = html.Slice(nameEnd).IndexOf('>');
    if (closingBracket < 0)
    {
        return (false, default);
    }

    return (true, nameEnd + closingBracket + 1);
}
/// <summary>
/// Tests whether <paramref name="html"/> begins with a closing tag for <see cref="TagSymbol"/>.
/// </summary>
/// <param name="html">The input; a match requires it to start with '&lt;'.</param>
/// <returns>
/// <c>(true, n)</c> where n is the number of characters up to and including the '&gt;', or
/// <c>(false, 0)</c> otherwise. Spaces are tolerated before the '/', before the tag name and
/// before the '&gt;'; any other stray character rejects the match.
/// </returns>
public (bool isMatch, int length) IsClosingTag(ReadOnlySpan<char> html)
{
    if (html.IsEmpty || html[0] != '<')
    {
        return (false, default);
    }

    // Read once: the property may allocate on each access.
    var tagSymbol = TagSymbol;

    var index = SkipSpaces(html, 1);
    if (index >= html.Length || html[index] != '/')
    {
        return (false, default);
    }

    // Only spaces may sit between '/' and the name; otherwise "</xdiv>" would be accepted
    // as a closing tag for "div".
    index = SkipSpaces(html, index + 1);
    if (index + tagSymbol.Length > html.Length
        || !html.Slice(index, tagSymbol.Length).SequenceEqual(tagSymbol))
    {
        return (false, default);
    }

    index = SkipSpaces(html, index + tagSymbol.Length);
    if (index < html.Length && html[index] == '>')
    {
        return (true, index + 1);
    }

    // Something wrong here... this probably isn't a closing tag (e.g. "</pre>" for "p").
    return (false, default);

    // Returns the index of the first non-space character at or after start.
    static int SkipSpaces(ReadOnlySpan<char> text, int start)
    {
        while (start < text.Length && (text[start] == ' ' || text[start] == NonPrintedSpace))
        {
            start++;
        }

        return start;
    }
}
/// <summary>
/// Reports whether <paramref name="html"/> starts with either an opening or a closing tag
/// for this element.
/// </summary>
/// <param name="html">The input to test.</param>
/// <returns><c>true</c> when the opening-tag or the closing-tag check succeeds.</returns>
public bool IsMatch(ReadOnlySpan<char> html)
{
    // Try the opening form first; the closing check only runs when that fails.
    if (IsOpenTag(html).isMatch)
    {
        return true;
    }

    return IsClosingTag(html).isMatch;
}
}
/// <summary>
/// The fixed, ordered set of element rules applied by the sanitizer.
/// </summary>
public class HtmlElementCollection : IEnumerable<IHtmlElement>
{
    // Rules are tried in this order; the first whose IsMatch succeeds handles the tag.
    private static IEnumerable<IHtmlElement> _htmlElements = new IHtmlElement[]
    {
        new HtmlLink(),
        new HtmlStyle() { RemoveEntireElement = true },
        new HtmlDiv() { RemoveEmptyElement = true },
        new HtmlScript() { RemoveEntireElement = true },
        new HtmlParagraph()
        {
            RemoveEmptyElement = true,
            FindAttribute = "style".ToCharArray(),
            ReplaceAttribute = "data-style".ToCharArray()
        }
    };

    /// <summary>Enumerates the configured rules in priority order.</summary>
    public IEnumerator<IHtmlElement> GetEnumerator()
    {
        return _htmlElements.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>
    /// Hands the tag at the start of <paramref name="html"/> to the first rule that matches it.
    /// </summary>
    /// <param name="html">The remaining input, beginning at a '&lt;'.</param>
    /// <param name="currentPosition">Index of that '&lt;' in the full document.</param>
    /// <returns>
    /// The matching rule's result with <paramref name="currentPosition"/> added to its length,
    /// or <c>(currentPosition, null, false)</c> when no rule matches.
    /// </returns>
    public (int fastForwardLength, char[] replacement, bool replaced) HandleTag(ReadOnlySpan<char> html, int currentPosition)
    {
        foreach (var rule in this)
        {
            if (!rule.IsMatch(html))
            {
                continue;
            }

            var outcome = rule.FastForward(html);
            return (currentPosition + outcome.fastForwardLength, outcome.replacement, outcome.replaced);
        }

        return (currentPosition, default, false);
    }
}
/// <summary>
/// Rule that recognises an HTML comment (<c>&lt;!-- ... --&gt;</c>) and removes it.
/// </summary>
public class HtmlComment : IHtmlElement
{
    // Allocated once instead of on every property access.
    private static readonly char[] BeginSequence = "<!--".ToCharArray();
    private static readonly char[] EndSequence = "-->".ToCharArray();

    /// <summary>
    /// Computes how many characters the comment at the start of <paramref name="html"/> spans.
    /// </summary>
    /// <param name="html">The remaining input, beginning at the comment opener.</param>
    /// <returns>
    /// <c>(length, null, true)</c>, where length runs through the terminating <c>--&gt;</c>,
    /// or to the end of the input when the comment is never terminated.
    /// </returns>
    public (int fastForwardLength, char[] replacement, bool replaced) FastForward(ReadOnlySpan<char> html)
    {
        var end = html.IndexOf(EndSequence);

        // IndexOf yields -1 for an unterminated comment; treat the rest of the input as
        // comment rather than reporting a bogus length of 2.
        var length = end < 0 ? html.Length : end + EndSequence.Length;
        return (length, default, true);
    }

    /// <summary>
    /// Reports whether <paramref name="html"/> starts with <c>&lt;!--</c>.
    /// </summary>
    /// <param name="html">The input to test; may be shorter than the opener.</param>
    public bool IsMatch(ReadOnlySpan<char> html) =>
        html.Length >= BeginSequence.Length
        && html.Slice(0, BeginSequence.Length).SequenceEqual(BeginSequence);
}
/// <summary>Rule for the <c>html</c> element.</summary>
public class HtmlHtml : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "html".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>link</c> element.</summary>
public class HtmlLink : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "link".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>body</c> element.</summary>
public class HtmlBody : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "body".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>head</c> element.</summary>
public class HtmlHead : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "head".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>title</c> element.</summary>
public class HtmlTitle : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "title".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>style</c> element.</summary>
public class HtmlStyle : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "style".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>div</c> element.</summary>
public class HtmlDiv : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "div".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>script</c> element.</summary>
public class HtmlScript : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "script".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>Rule for the <c>p</c> element.</summary>
public class HtmlParagraph : BaseHtmlElement
{
    // Cached: the tag scanners read TagSymbol inside their loops, so allocating a fresh
    // array on every access is wasted work.
    private static readonly char[] Symbol = "p".ToCharArray();

    protected override char[] TagSymbol => Symbol;
}
/// <summary>
/// Applies the rules in <see cref="HtmlElementCollection"/> to an HTML fragment.
/// </summary>
public static class Sanitizer
{
    private static readonly HtmlElementCollection HtmlElements = new HtmlElementCollection();

    /// <summary>
    /// Copies <paramref name="html"/> to a new string, letting the element rules remove or
    /// rewrite any tag they match along the way.
    /// </summary>
    /// <param name="html">The markup to sanitize; may be empty.</param>
    /// <returns>The sanitized markup.</returns>
    public static string Sanitize(ReadOnlySpan<char> html)
    {
        // Output is accumulated as characters. Nothing is encoded to bytes, so non-ASCII
        // text survives intact and no pooled buffer has to be rented or returned.
        var output = new StringBuilder(html.Length);

        var i = 0;
        while (i < html.Length)
        {
            if (html[i] == '<')
            {
                var (next, replacement, replaced) = HtmlElements.HandleTag(html.Slice(i), i);

                if (replacement != null)
                {
                    // Rules report one less than the consumed length when they supply a
                    // replacement, hence the extra step.
                    output.Append(replacement);
                    i = next + 1;
                    continue;
                }

                // Only jump when the rule actually consumed something; a zero-length
                // removal would otherwise revisit this same '<' forever.
                if (replaced && next > i)
                {
                    // 'next' is the first character after the removed span. It may be
                    // another tag, or the end of the input, which the loop condition handles.
                    i = next;
                    continue;
                }
            }

            output.Append(html[i]);
            i++;
        }

        return output.ToString();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment