Skip to content

Instantly share code, notes, and snippets.

@dylanlangston
Last active April 5, 2024 21:27
Show Gist options
  • Save dylanlangston/f233121d699459f6ccfb01fd695585cb to your computer and use it in GitHub Desktop.
Save dylanlangston/f233121d699459f6ccfb01fd695585cb to your computer and use it in GitHub Desktop.
Set Default Browser on Windows 10/11 in C#
using System;
using System.IO;
using System.Diagnostics;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Runtime.InteropServices;
using Microsoft.Win32;
namespace Utilities
{
    // Sets the default browser programmatically.
    // Based on this PowerShell script: https://github.com/DanysysTeam/PS-SFTA
    // Relies on undocumented Windows behavior (the UserChoice hash) that Microsoft may change without warning.
    public static class DefaultBrowser
    {
        /// <summary>
        /// Makes <paramref name="browser"/> the current user's handler for the http and https protocols.
        /// The first call snapshots the existing http/https UserChoice values so that
        /// <see cref="ResetToOriginal"/> can restore them later.
        /// </summary>
        /// <exception cref="NotImplementedException">No ProgId is known for <paramref name="browser"/>.</exception>
        public static void SetTo(Browser browser)
        {
            original ??= new[] { ReadProtocolKeys("http"), ReadProtocolKeys("https") };

            var progId = browser switch
            {
                Browser.Firefox => "FirefoxHTML-308046B0AF4A39CB",
                Browser.Chrome => "ChromeHTML",
                Browser.Edge => "MSEdgeHTM",
                _ => throw new NotImplementedException("Browser not implemented"),
            };

            SetPTA(progId, "http");
            SetPTA(progId, "https");
        }

        // Snapshot taken on the first SetTo call: [0] = http, [1] = https; each entry is { Hash, ProgId }
        // as returned by ReadProtocolKeys. Individual values are null when no UserChoice existed.
        static string[][] original = null;

        /// <summary>
        /// Re-applies the http/https ProgIds captured by the first <see cref="SetTo"/> call.
        /// Does nothing if <see cref="SetTo"/> was never called. A protocol that had no ProgId at
        /// snapshot time is left unchanged, because a null ProgId cannot be written to the registry.
        /// </summary>
        internal static void ResetToOriginal()
        {
            if (original == null) return;
            RestoreProtocol(original[0], "http");
            RestoreProtocol(original[1], "https");
        }

        // Re-applies one saved { Hash, ProgId } pair; skips it when the saved ProgId is missing.
        // NOTE(review): skipping means the browser chosen by SetTo stays in place for that protocol;
        // deleting the UserChoice key instead may be the more faithful "reset" — confirm desired behavior.
        static void RestoreProtocol(string[] saved, string protocol)
        {
            var progId = saved?[1];
            if (string.IsNullOrEmpty(progId)) return;
            SetPTA(progId, protocol);
        }

        /// <summary>
        /// Writes <paramref name="progID"/> as the UserChoice for <paramref name="protocol"/>, together with
        /// the hash computed from protocol, user SID, ProgId, current minute and the Shell32 "user experience" string.
        /// Then broadcasts an association-change notification.
        /// </summary>
        static void SetPTA(string progID, string protocol)
        {
            var userSID = GetUserSid();
            var userExperience = GetUserExperience();
            var userDateTime = GetHexDateTime();
            Log.Write(userSID);
            Log.Write(userExperience);
            Log.Write(userDateTime);
            // Invariant lowercasing, like GetUserSid/GetHexDateTime: culture-sensitive ToLower()
            // maps 'I' to dotless 'ı' under Turkish cultures and would corrupt the hash input.
            var baseInfo = $"{protocol}{userSID}{progID}{userDateTime}{userExperience}".ToLowerInvariant();
            var hash = GetHash(baseInfo);
            Log.Write(hash);
            // NOTE(review): the hash embeds the current minute (GetHexDateTime); if the minute rolls over
            // before the registry write, the stored hash may not match — confirm whether this matters.
            WriteProtocolKeys(progID, protocol, hash);
            Refresh();
        }

        // Current user's SID string (e.g. "s-1-5-21-..."), lowercased invariantly.
        static string GetUserSid()
            => System.Security.Principal.WindowsIdentity.GetCurrent().User.Value.ToLowerInvariant();

        /// <summary>
        /// Extracts the text starting at "User Choice set via Windows User Experience" up to and including
        /// the next '}' from the first 5 MiB of Shell32.dll, decoded as UTF-16LE.
        /// </summary>
        /// <exception cref="InvalidOperationException">The marker text (or its closing '}') was not found.</exception>
        static string GetUserExperience()
        {
            const string userExperienceSearch = "User Choice set via Windows User Experience";
            var shell32Path = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.SystemX86), "Shell32.dll");

            byte[] bytesData;
            using (var fileStream = File.Open(shell32Path, FileMode.Open, FileAccess.Read, FileShare.ReadWrite))
            using (var binaryReader = new BinaryReader(fileStream))
            {
                bytesData = binaryReader.ReadBytes(5242880); // 5 MiB
            }

            var dataString = Encoding.Unicode.GetString(bytesData);

            // Ordinal search: the decoded text is binary residue, and culture-aware matching (ICU on .NET 5+)
            // can skip "ignorable" characters and is far slower.
            var start = dataString.IndexOf(userExperienceSearch, StringComparison.Ordinal);
            if (start < 0)
                throw new InvalidOperationException($"Could not find \"{userExperienceSearch}\" in {shell32Path}.");

            var end = dataString.IndexOf('}', start);
            if (end < 0)
                throw new InvalidOperationException($"Could not find the end of \"{userExperienceSearch}\" in {shell32Path}.");

            return dataString.Substring(start, end - start + 1);
        }

        /// <summary>
        /// Local time truncated to the minute, converted with <see cref="DateTime.ToFileTime"/>,
        /// formatted as 16 lowercase hex digits (high 32 bits, then low 32 bits).
        /// </summary>
        static string GetHexDateTime()
        {
            var time = DateTime.Now;
            time = new DateTime(time.Year, time.Month, time.Day, time.Hour, time.Minute, 0);
            var fileTime = time.ToFileTime();
            var high = fileTime >> 32;
            var low = fileTime & 0xFFFFFFFFL;
            return (high.ToString("X8") + low.ToString("X8")).ToLowerInvariant();
        }

        // Port of PS-SFTA's Get-ShiftRight: arithmetic shift, XOR-ed with 0xFFFF0000 when bit 31 is set.
        // Inside GetHash's loops every result is truncated back to 32 bits before use.
        static long GetShiftRight(long value, int count)
        {
            if ((value & 0x80000000) == 0)
            {
                return (value >> count);
            }
            else // Negative number
            {
                return (value >> count) ^ 0xFFFF0000;
            }
        }

        // Scratch state for GetHash's two mixing passes (field names follow PS-SFTA's hashtable keys).
        // PDATA is the byte offset into the input, COUNTER the remaining 8-byte rounds,
        // MD51/MD52 the MD5-derived multipliers, OUTHASH1/OUTHASH2 the running outputs.
        class HashMap
        {
            public long PDATA = 0,
            CACHE = 0,
            COUNTER = 0,
            INDEX = 0,
            MD51 = 0,
            MD52 = 0,
            OUTHASH1 = 0,
            OUTHASH2 = 0,
            R0 = 0,
            R3 = 0;
            public Dictionary<int, long>
            R1 = new() { { 0, 0 }, { 1, 0 } },
            R2 = new() { { 0, 0 }, { 1, 0 } },
            R4 = new() { { 0, 0 }, { 1, 0 } },
            R5 = new() { { 0, 0 }, { 1, 0 } },
            R6 = new() { { 0, 0 }, { 1, 0 } };
        }

        /// <summary>
        /// Computes the UserChoice hash of <paramref name="baseInfo"/>. Takes the MD5 of its UTF-16LE bytes
        /// plus two zero bytes, runs two 32-bit mixing passes over those bytes seeded from the MD5,
        /// XORs the two 8-byte results, and returns the Base64 of the resulting 8 bytes.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="baseInfo"/> is too short to hash.</exception>
        static string GetHash(string baseInfo)
        {
            var bytes = Encoding.Unicode.GetBytes(baseInfo);
            bytes = bytes.Append((byte)0x00).Append((byte)0x00).ToArray();

            byte[] md5;
            using (var md5Algorithm = System.Security.Cryptography.MD5.Create())
            {
                md5 = md5Algorithm.ComputeHash(bytes);
            }
            Log.Write(string.Join(", ", md5.Select(b => b.ToString())));

            var lengthBase = (baseInfo.Length * 2) + 2;
            // Number of 4-byte words processed, rounded down to an even count (each round consumes 8 bytes).
            var length = (((lengthBase & 4) <= 1) ? 1 : 0) + (GetShiftRight(lengthBase, 2)) - 1;
            if (length <= 1)
                throw new ArgumentException("Missing base info", nameof(baseInfo));

            // Pass 1.
            HashMap map = new();
            map.CACHE = 0;
            map.OUTHASH1 = 0;
            map.PDATA = 0;
            map.MD51 = (BitConverter.ToInt32(md5, 0) | 1) + (int)0x69FB0000L;
            map.MD52 = (BitConverter.ToInt32(md5, 4) | 1) + (int)0x13DB0000L;
            map.INDEX = (int)GetShiftRight(length - 2, 1);
            map.COUNTER = map.INDEX + 1;
            while (map.COUNTER > 0)
            {
                map.R0 = BitConverter.ToInt32(BitConverter.GetBytes(BitConverter.ToInt32(bytes, (int)map.PDATA) + map.OUTHASH1));
                map.R1[0] = BitConverter.ToInt32(BitConverter.GetBytes(BitConverter.ToInt32(bytes, (int)map.PDATA + 4)));
                map.PDATA = map.PDATA + 8;
                map.R2[0] = BitConverter.ToInt32(BitConverter.GetBytes((map.R0 * map.MD51) - (0x10FA9605L * GetShiftRight(map.R0, 16))));
                map.R2[1] = BitConverter.ToInt32(BitConverter.GetBytes((0x79F8A395L * map.R2[0]) + (0x689B6B9FL * GetShiftRight(map.R2[0], 16))));
                map.R3 = BitConverter.ToInt32(BitConverter.GetBytes((0xEA970001L * map.R2[1]) - (0x3C101569L * GetShiftRight(map.R2[1], 16))));
                map.R4[0] = BitConverter.ToInt32(BitConverter.GetBytes(map.R3 + map.R1[0]));
                map.R5[0] = BitConverter.ToInt32(BitConverter.GetBytes(map.CACHE + map.R3));
                map.R6[0] = BitConverter.ToInt32(BitConverter.GetBytes((map.R4[0] * map.MD52) - (0x3CE8EC25L * GetShiftRight(map.R4[0], 16))));
                map.R6[1] = BitConverter.ToInt32(BitConverter.GetBytes((0x59C3AF2DL * map.R6[0]) - (0x2232E0F1L * GetShiftRight(map.R6[0], 16))));
                map.OUTHASH1 = BitConverter.ToInt32(BitConverter.GetBytes((0x1EC90001L * map.R6[1]) + (0x35BD1EC9L * GetShiftRight(map.R6[1], 16))));
                map.OUTHASH2 = BitConverter.ToInt32(BitConverter.GetBytes(map.R5[0] + map.OUTHASH1));
                map.CACHE = map.OUTHASH2;
                map.COUNTER = map.COUNTER - 1;
            }
            var outHash = new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
            var buffer = BitConverter.GetBytes(map.OUTHASH1).Take(4).ToArray();
            buffer.CopyTo(outHash, 0);
            buffer = BitConverter.GetBytes(map.OUTHASH2).Take(4).ToArray();
            buffer.CopyTo(outHash, 4);

            // Pass 2.
            map = new();
            map.CACHE = 0;
            map.OUTHASH1 = 0;
            map.PDATA = 0;
            map.MD51 = BitConverter.ToInt32(md5) | 1;
            map.MD52 = BitConverter.ToInt32(md5, 4) | 1;
            map.INDEX = (int)GetShiftRight(length - 2, 1);
            map.COUNTER = map.INDEX + 1;
            while (map.COUNTER > 0)
            {
                map.R0 = BitConverter.ToInt32(BitConverter.GetBytes(BitConverter.ToInt32(bytes, (int)map.PDATA) + map.OUTHASH1));
                map.PDATA = map.PDATA + 8;
                map.R1[0] = BitConverter.ToInt32(BitConverter.GetBytes(map.R0 * map.MD51));
                map.R1[1] = BitConverter.ToInt32(BitConverter.GetBytes((0xB1110000L * map.R1[0]) - (0x30674EEFL * GetShiftRight(map.R1[0], 16))));
                map.R2[0] = BitConverter.ToInt32(BitConverter.GetBytes((0x5B9F0000L * map.R1[1]) - (0x78F7A461L * GetShiftRight(map.R1[1], 16))));
                map.R2[1] = BitConverter.ToInt32(BitConverter.GetBytes((0x12CEB96DL * GetShiftRight(map.R2[0], 16)) - (0x46930000L * map.R2[0])));
                map.R3 = BitConverter.ToInt32(BitConverter.GetBytes((0x1D830000L * map.R2[1]) + (0x257E1D83L * GetShiftRight(map.R2[1], 16))));
                map.R4[0] = BitConverter.ToInt32(BitConverter.GetBytes(map.MD52 * (map.R3 + (BitConverter.ToInt32(bytes, (int)map.PDATA - 4)))));
                map.R4[1] = BitConverter.ToInt32(BitConverter.GetBytes((0x16F50000L * map.R4[0]) - (0x5D8BE90BL * GetShiftRight(map.R4[0], 16))));
                map.R5[0] = BitConverter.ToInt32(BitConverter.GetBytes((0x96FF0000L * map.R4[1]) - (0x2C7C6901L * GetShiftRight(map.R4[1], 16))));
                map.R5[1] = BitConverter.ToInt32(BitConverter.GetBytes((0x2B890000L * map.R5[0]) + (0x7C932B89L * GetShiftRight(map.R5[0], 16))));
                map.OUTHASH1 = BitConverter.ToInt32(BitConverter.GetBytes((0x9F690000L * map.R5[1]) - (0x405B6097L * GetShiftRight(map.R5[1], 16))));
                map.OUTHASH2 = BitConverter.ToInt32(BitConverter.GetBytes(map.OUTHASH1 + map.CACHE + map.R3));
                map.CACHE = map.OUTHASH2;
                map.COUNTER = map.COUNTER - 1;
            }
            buffer = BitConverter.GetBytes(map.OUTHASH1).Take(4).ToArray();
            buffer.CopyTo(outHash, 8);
            buffer = BitConverter.GetBytes(map.OUTHASH2).Take(4).ToArray();
            buffer.CopyTo(outHash, 12);

            // Final 8 bytes = pass-2 output XOR pass-1 output.
            var outHashBase = new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
            var hashValue1 = BitConverter.ToInt32(outHash, 8) ^ BitConverter.ToInt32(outHash);
            var hashValue2 = BitConverter.ToInt32(outHash, 12) ^ BitConverter.ToInt32(outHash, 4);
            buffer = BitConverter.GetBytes(hashValue1);
            buffer.CopyTo(outHashBase, 0);
            buffer = BitConverter.GetBytes(hashValue2);
            buffer.CopyTo(outHashBase, 4);
            return Convert.ToBase64String(outHashBase);
        }

        // Registry path of the current user's UserChoice key for the given URL protocol.
        static string UserChoiceKeyPath(string protocol)
            => $@"HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\{protocol}\UserChoice";

        // Writes the Hash and ProgId values of the protocol's UserChoice key (creating the key if needed).
        // NOTE(review): if Windows has locked this key, SetValue will throw — confirm whether it must be deleted first.
        static void WriteProtocolKeys(string progId, string protocol, string progHash)
        {
            var keyPath = UserChoiceKeyPath(protocol);
            Registry.SetValue(keyPath, "Hash", progHash);
            Registry.SetValue(keyPath, "ProgId", progId);
        }

        // Returns { Hash, ProgId } from the protocol's UserChoice key. An element is null when the key or
        // value is missing, or when the value is not a string.
        static string[] ReadProtocolKeys(string protocol)
        {
            var keyPath = UserChoiceKeyPath(protocol);
            return new string[]
            {
                Registry.GetValue(keyPath, "Hash", null) as string,
                Registry.GetValue(keyPath, "ProgId", null) as string
            };
        }

        // SHChangeNotify(SHCNE_ASSOCCHANGED = 0x08000000, SHCNF_IDLIST = 0) so the shell picks up the new association.
        static void Refresh() => pInvokes.SHChangeNotify(0x8000000, 0, IntPtr.Zero, IntPtr.Zero);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment