Skip to content

Instantly share code, notes, and snippets.

@meitinger
Last active October 2, 2023 20:27
Show Gist options
  • Save meitinger/8cea0d1902dc86e76d3f31d20d9cadfc to your computer and use it in GitHub Desktop.
Save meitinger/8cea0d1902dc86e76d3f31d20d9cadfc to your computer and use it in GitHub Desktop.
Utility that rotates Windows Defender Firewall log files.
/* Copyright (C) 2023, Manuel Meitinger
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#nullable enable
using Microsoft.Win32;
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Security.AccessControl;
using System.Security.Principal;
using System.ServiceProcess;
using System.Text;
using static System.FormattableString;
[assembly: AssemblyTitle("AufBauWerk FirewallLogRotate")]
[assembly: AssemblyDescription("Rotates Windows Defender Firewall log files.")]
[assembly: AssemblyCompany("AufBauWerk - Unternehmen für junge Menschen")]
[assembly: AssemblyCopyright("Copyright © 2023 by Manuel Meitinger")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: ComVisible(false)]
namespace System.ServiceProcess
{
/// <summary>
/// Helpers for creating, deleting, probing and configuring Windows services, complementing
/// <see cref="ServiceController"/>.
/// </summary>
/// <remarks>
/// Service control manager and service handles are obtained by invoking non-public members of
/// <see cref="ServiceController"/> through reflection, so this class depends on the internals of
/// the System.ServiceProcess implementation it runs against; a missing member surfaces as a
/// <see cref="MissingMethodException"/> from the type initializer.
/// </remarks>
public static class ServiceHelper
{
    private const string AdvapiDll = "advapi32.dll";

    [DllImport(AdvapiDll, CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
    private static extern bool ChangeServiceConfig2W(IntPtr service, int infoLevel, ServiceFailureActions info);

    [DllImport(AdvapiDll, CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
    private static extern bool CloseServiceHandle(IntPtr handle);

    [DllImport(AdvapiDll, CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
    private static extern IntPtr CreateServiceW(IntPtr manager, string serviceName, string? displayName, int desiredAccess, ServiceType serviceType, ServiceStartMode startType, int errorControl, string binaryPathName, IntPtr loadOrderGroup, IntPtr tagId, string? dependencies, string? servicesStartName, string? password);

    [DllImport(AdvapiDll, CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
    private static extern bool DeleteService(IntPtr service);

    // Non-public ServiceController members. The private wrappers below deliberately share their
    // names so that nameof() yields the name of the reflected member.
    private static readonly MethodInfo CreateSafeWin32ExceptionMethod = typeof(ServiceController).GetMethod(nameof(CreateSafeWin32Exception), BindingFlags.NonPublic | BindingFlags.Static) ?? throw new MissingMethodException(nameof(ServiceController), nameof(CreateSafeWin32Exception));
    private static readonly MethodInfo GetDataBaseHandleWithAccessMethod = typeof(ServiceController).GetMethod(nameof(GetDataBaseHandleWithAccess), BindingFlags.NonPublic | BindingFlags.Static) ?? throw new MissingMethodException(nameof(ServiceController), nameof(GetDataBaseHandleWithAccess));
    private static readonly MethodInfo GetServiceHandleMethod = typeof(ServiceController).GetMethod(nameof(GetServiceHandle), BindingFlags.NonPublic | BindingFlags.Instance) ?? throw new MissingMethodException(nameof(ServiceController), nameof(GetServiceHandle));

    // Wraps the thread's last Win32 error in an InvalidOperationException naming the failing public member.
    private static Exception BuildException([CallerMemberName] string memberName = "") => new InvalidOperationException($"{memberName} failed.", CreateSafeWin32Exception());

    // Invokes a reflected method; if it throws, rethrows the original exception with its stack trace intact.
    private static object InvokeUnwrapped(MethodInfo method, object? target, params object[] arguments)
    {
        try
        {
            return method.Invoke(target, arguments)!;
        }
        catch (TargetInvocationException ex) when (ex.InnerException is not null)
        {
            // `throw ex.InnerException` would reset the inner stack trace; ExceptionDispatchInfo preserves it
            ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
            throw; // unreachable: Throw() never returns, but the compiler cannot prove that
        }
    }

    private static Win32Exception CreateSafeWin32Exception() => (Win32Exception)InvokeUnwrapped(CreateSafeWin32ExceptionMethod, null);

    private static IntPtr GetDataBaseHandleWithAccess(string machineName, int serviceControlManagerAccess) => (IntPtr)InvokeUnwrapped(GetDataBaseHandleWithAccessMethod, null, machineName, serviceControlManagerAccess);

    private static IntPtr GetServiceHandle(this ServiceController controller, int desiredAccess) => (IntPtr)InvokeUnwrapped(GetServiceHandleMethod, controller, desiredAccess);

    /// <summary>Creates a service on the local computer.</summary>
    /// <param name="serviceName">Name under which the service is registered.</param>
    /// <param name="parameters">Binary path, start mode, dependencies and account of the new service.</param>
    /// <returns>A controller for the newly created service.</returns>
    /// <exception cref="InvalidOperationException">The service could not be created.</exception>
    public static ServiceController Create(string serviceName, ServiceCreateParameters parameters) => Create(serviceName, ".", parameters);

    /// <summary>Creates a service on the given computer.</summary>
    /// <param name="serviceName">Name under which the service is registered.</param>
    /// <param name="machineName">Target computer; <c>"."</c> is the local computer.</param>
    /// <param name="parameters">Binary path, start mode, dependencies and account of the new service.</param>
    /// <returns>A controller for the newly created service.</returns>
    /// <exception cref="InvalidOperationException">
    /// The service could not be created; the inner <see cref="Win32Exception"/> carries the Win32 error.
    /// </exception>
    public static ServiceController Create(string serviceName, string machineName, ServiceCreateParameters parameters)
    {
        // 0x0002 = SC_MANAGER_CREATE_SERVICE
        var manager = GetDataBaseHandleWithAccess(machineName, 0x0002);
        try
        {
            var service = CreateServiceW
            (
                manager,
                serviceName,
                parameters.DisplayName,
                desiredAccess: 0,
                parameters.ServiceType,
                parameters.StartType,
                errorControl: 1, // SERVICE_ERROR_NORMAL
                parameters.BinPath,
                IntPtr.Zero,
                IntPtr.Zero,
                // dependencies must be a double-NUL-terminated list: the empty trailing entry
                // contributes one NUL and the string marshaler appends the final one
                parameters.DependentServices?.Length is null or 0 ? null : string.Join("\0", parameters.DependentServices.Append(string.Empty)),
                parameters.UserName,
                parameters.Password
            );
            if (service == IntPtr.Zero)
            {
                throw BuildException();
            }
            CloseServiceHandle(service);
        }
        finally { CloseServiceHandle(manager); }
        return new(serviceName, machineName);
    }

    /// <summary>Marks the service for deletion and refreshes <paramref name="controller"/>.</summary>
    /// <exception cref="InvalidOperationException">The service could not be opened or deleted.</exception>
    public static void Delete(this ServiceController controller)
    {
        // 0x10000 = DELETE
        var service = controller.GetServiceHandle(0x10000);
        try
        {
            if (!DeleteService(service))
            {
                throw BuildException();
            }
        }
        finally { CloseServiceHandle(service); }
        controller.Refresh();
    }

    /// <summary>Reports whether the service described by <paramref name="controller"/> is installed.</summary>
    /// <returns><see langword="false"/> only when opening fails with ERROR_SERVICE_DOES_NOT_EXIST; other failures propagate.</returns>
    public static bool Exists(this ServiceController controller)
    {
        try
        {
            // 0x0001 = SERVICE_QUERY_CONFIG; the handle is only a probe and is closed immediately
            CloseServiceHandle(controller.GetServiceHandle(0x0001));
        }
        catch (InvalidOperationException ex) when (ex.InnerException is Win32Exception { NativeErrorCode: 1060 })
        {
            // 1060 = ERROR_SERVICE_DOES_NOT_EXIST
            return false;
        }
        return true;
    }

    /// <summary>Applies <paramref name="actions"/> as the service's failure actions and refreshes <paramref name="controller"/>.</summary>
    /// <exception cref="InvalidOperationException">The service could not be opened or reconfigured.</exception>
    public static void SetFailureActions(this ServiceController controller, ServiceFailureActions actions)
    {
        // 0x0012 = SERVICE_CHANGE_CONFIG | SERVICE_START (SERVICE_START is required for restart actions)
        var service = controller.GetServiceHandle(0x0012);
        try
        {
            // info level 2 = SERVICE_CONFIG_FAILURE_ACTIONS
            bool succeeded = ChangeServiceConfig2W(service, 2, actions);
            // `actions` owns the pinned action array whose raw address was just marshaled; keep it
            // reachable so its finalizer cannot unpin that memory while the native call reads it
            GC.KeepAlive(actions);
            if (!succeeded)
            {
                throw BuildException();
            }
        }
        finally { CloseServiceHandle(service); }
        controller.Refresh();
    }
}
/// <summary>
/// Settings used by <c>ServiceHelper.Create</c> when registering a new service.
/// Only the binary path is mandatory; everything else is optional.
/// </summary>
public class ServiceCreateParameters
{
    /// <summary>Initializes the settings with the mandatory binary path.</summary>
    /// <param name="binPath">Command line the service control manager runs to start the service.</param>
    public ServiceCreateParameters(string binPath)
    {
        BinPath = binPath;
    }

    /// <summary>Command line of the service executable.</summary>
    public string BinPath { get; set; }

    /// <summary>Names of services that must be running before this one starts; null or empty for none.</summary>
    public string[]? DependentServices { get; set; }

    /// <summary>Display name shown in service management tools; null to leave it unset.</summary>
    public string? DisplayName { get; set; }

    /// <summary>Kind of service being registered.</summary>
    public ServiceType ServiceType { get; set; }

    /// <summary>How the service is started.</summary>
    public ServiceStartMode StartType { get; set; }

    /// <summary>Account the service runs under; null for the default.</summary>
    public string? UserName { get; set; }

    /// <summary>Password of <see cref="UserName"/>; null when not needed.</summary>
    public string? Password { get; set; }
}
/// <summary>
/// Managed mirror of the native <c>SC_ACTION</c> structure: one entry of the failure-action
/// array that <c>ServiceFailureActions</c> pins and hands to <c>ChangeServiceConfig2W</c>.
/// </summary>
/// <remarks>
/// Marshaled with sequential layout: the field order below IS the native layout and must not change.
/// </remarks>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
public struct ServiceFailureAction
{
// what the service control manager should do on this failure
public ServiceFailureActionType Type;
// wait before performing Type; per the Win32 SC_ACTION documentation this is in milliseconds
public int Delay;
}
/// <summary>
/// Action codes for a service failure action. The numeric values are the native
/// <c>SC_ACTION_TYPE</c> constants and are passed to Win32 unchanged, so they must not be altered.
/// </summary>
public enum ServiceFailureActionType
{
    /// <summary>No action (<c>SC_ACTION_NONE</c>).</summary>
    None = 0x0,

    /// <summary>Restart the service (<c>SC_ACTION_RESTART</c>).</summary>
    Restart = 0x1,

    /// <summary>Reboot the computer (<c>SC_ACTION_REBOOT</c>).</summary>
    Reboot = 0x2,

    /// <summary>Run a command (<c>SC_ACTION_RUN_COMMAND</c>).</summary>
    RunCommand = 0x3,
}
/// <summary>
/// Managed mirror of the native <c>SERVICE_FAILURE_ACTIONSW</c> structure passed to
/// <c>ChangeServiceConfig2W</c> (info level 2).
/// </summary>
/// <remarks>
/// The class is marshaled with sequential layout, so the field order at the bottom is the native
/// layout and must not change. The action array handed to the constructor stays pinned while this
/// object holds it; call <see cref="Dispose"/> to unpin it deterministically, otherwise the
/// finalizer releases it.
/// </remarks>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
public class ServiceFailureActions : IDisposable
{
    /// <summary>Creates an instance that carries no actions.</summary>
    public ServiceFailureActions()
    {
    }

    /// <summary>Creates an instance that pins and references <paramref name="actions"/>.</summary>
    /// <param name="actions">Actions in failure order; the array is pinned, not copied.</param>
    /// <exception cref="ArgumentNullException"><paramref name="actions"/> is null.</exception>
    public ServiceFailureActions(params ServiceFailureAction[] actions)
    {
        if (actions is null) throw new ArgumentNullException(nameof(actions));
        // pin the array so the raw address marshaled as lpsaActions stays valid
        GCHandle handle = GCHandle.Alloc(actions, GCHandleType.Pinned);
        ActionCount = actions.Length;
        ActionsArray = handle.AddrOfPinnedObject();
        Handle = GCHandle.ToIntPtr(handle);
    }

    ~ServiceFailureActions() => ReleaseActions();

    /// <summary>Unpins the action array; afterwards the instance carries no actions.</summary>
    public void Dispose()
    {
        ReleaseActions();
        GC.SuppressFinalize(this);
    }

    // Frees the pinning handle at most once and clears the native view of the array,
    // so a disposed instance never marshals a dangling pointer.
    private void ReleaseActions()
    {
        if (Handle != IntPtr.Zero)
        {
            GCHandle.FromIntPtr(Handle).Free();
            Handle = IntPtr.Zero;
            ActionsArray = IntPtr.Zero;
            ActionCount = 0;
        }
    }

    // ---- native layout (SERVICE_FAILURE_ACTIONSW); do not reorder ----

    /// <summary>dwResetPeriod: per the Win32 docs, seconds without failures after which the failure count resets.</summary>
    public int ResetPeriod;

    /// <summary>lpRebootMsg: per the Win32 docs, null leaves the existing message unchanged.</summary>
    public string? RebootMessage;

    /// <summary>lpCommand: per the Win32 docs, null leaves the existing command unchanged.</summary>
    public string? Command;

    // cActions
    private int ActionCount;

    // lpsaActions: address of the pinned ServiceFailureAction array
    private IntPtr ActionsArray;

    // extra trailing field (ignored by the native API): the GCHandle that pins the action array
    private IntPtr Handle;
}
}
namespace Aufbauwerk.Tools.FirewallLogRotate
{
/// <summary>
/// Windows service that watches the Windows Defender Firewall log directory for
/// <c>pfirewall.log.old</c> and replaces it with a gzip-compressed <c>pfirewall.N.gz</c> file.
/// Started with <c>/install</c>, the executable instead (re)installs itself as that service.
/// </summary>
/// <remarks>
/// Failures are written to the service's event log and terminate the process
/// (see <see cref="HandleException"/>); <see cref="Install"/> registers a restart failure action,
/// and a restarted instance rotates any leftover file from <see cref="OnStart"/>.
/// </remarks>
public class Program : ServiceBase
{
    // Win32 E_FAIL (0x80004005). The previous literal, -2147483640 (0x80000008), is the
    // non-Win32 (16-bit/Mac) definition of E_FAIL in winerror.h.
    private const int EFailHResult = unchecked((int)0x80004005);

    // the single service instance; the static helpers also use it for event logging
    private static readonly Program _instance = new();

    /// <summary>Writes <paramref name="ex"/> to the service's event log as an error and terminates the process.</summary>
    /// <param name="ex">The failure to report; null when the unhandled object was not an <see cref="Exception"/>.</param>
    /// <param name="eventId">Event ID identifying where the failure occurred.</param>
    private static void HandleException(Exception? ex, int eventId)
    {
        try
        {
            _instance.EventLog.WriteEntry(ex?.ToString(), EventLogEntryType.Error, eventId);
        }
        finally
        {
            // exit even if logging failed; the exit code is the exception's HRESULT, or E_FAIL if it has none
            Environment.Exit(ex?.HResult switch
            {
                null or 0 => EFailHResult,
                int code => code,
            });
        }
    }

    /// <summary>
    /// Replaces any existing installation: stops and deletes the old service, copies this binary into
    /// system32, creates and starts the new service with a restart failure action, and records the version.
    /// </summary>
    private static void Install()
    {
        TimeSpan statusTimeout = TimeSpan.FromMinutes(1);
        // stop and delete any existing service
        using (ServiceController oldService = new(_instance.ServiceName))
        {
            if (oldService.Exists())
            {
                ServiceControllerStatus status = oldService.Status;
                if (status is not ServiceControllerStatus.Stopped)
                {
                    // a service that is already stopping rejects another stop request, so just wait for it
                    if (status is not ServiceControllerStatus.StopPending)
                    {
                        oldService.Stop();
                    }
                    oldService.WaitForStatus(ServiceControllerStatus.Stopped, statusTimeout);
                }
                oldService.Delete();
            }
        }
        // copy the binary to the system directory
        // NOTE(review): if this installer runs as a 32-bit process on 64-bit Windows, file-system
        // redirection sends the copy to SysWOW64 while the SCM resolves binPath to the native
        // system32 — confirm the build targets the OS architecture.
        const string binPath = @"%SystemRoot%\system32\FirewallLogRotate.exe";
        string expandedBinPath = Environment.ExpandEnvironmentVariables(binPath);
        string sourcePath = typeof(Program).Assembly.Location;
        // when re-run from the installed copy, the target is our own running image: it cannot be deleted and is already current
        if (!string.Equals(Path.GetFullPath(sourcePath), Path.GetFullPath(expandedBinPath), StringComparison.OrdinalIgnoreCase))
        {
            if (File.Exists(expandedBinPath))
            {
                File.Delete(expandedBinPath);
            }
            File.Copy(sourceFileName: sourcePath, destFileName: expandedBinPath);
        }
        // create and start the service
        using ServiceController newService = ServiceHelper.Create(_instance.ServiceName, new(binPath)
        {
            DependentServices = new[] { "mpssvc" },
            StartType = ServiceStartMode.Automatic,
            ServiceType = ServiceType.Win32OwnProcess,
        });
        newService.SetFailureActions(new(new ServiceFailureAction()
        {
            Delay = 1,
            Type = ServiceFailureActionType.Restart,
        }));
        if (newService.Status is not ServiceControllerStatus.Running)
        {
            newService.Start();
            newService.WaitForStatus(ServiceControllerStatus.Running, statusTimeout);
        }
        // register the new version
        Registry.SetValue(@"HKEY_LOCAL_MACHINE\SOFTWARE\AufBauWerk\FirewallLogRotate", "Version", typeof(Program).Assembly.GetName().Version.ToString());
    }

    /// <summary>
    /// Entry point: installs the service when <c>/install</c> is given (case-insensitive);
    /// otherwise runs as the service and reports unhandled exceptions with event ID 4.
    /// </summary>
    private static void Main(string[] args)
    {
        if (args.Contains("/install", StringComparer.OrdinalIgnoreCase))
        {
            Install();
        }
        else
        {
            AppDomain.CurrentDomain.UnhandledException += (s, e) => HandleException(e.ExceptionObject as Exception, eventId: 4);
            Run(_instance);
        }
    }

    private readonly string _logDirectoryPath;
    private readonly string _logFileName;
    private readonly string _logFilePath;
    // serializes rotations: watcher callbacks and the startup check in OnStart can otherwise overlap
    private readonly object _rotateLock = new();
    // next candidate N for pfirewall.N.gz; only used while _rotateLock is held
    private int _nextCompressedFileId = 1;
    private FileSystemWatcher? _watcher = null;

    private Program()
    {
        ServiceName = "FirewallLogRotate";
        CanStop = true;
        _logDirectoryPath = Environment.ExpandEnvironmentVariables(@"%systemroot%\system32\LogFiles\Firewall");
        _logFileName = "pfirewall.log.old";
        _logFilePath = Path.Combine(_logDirectoryPath, _logFileName);
    }

    /// <summary>Gzip-compresses <paramref name="stream"/> into the next unused <c>pfirewall.N.gz</c> in the log directory.</summary>
    /// <returns>Full path of the created archive.</returns>
    private string CompressStreamToFile(Stream stream)
    {
        // gzip the stream to a temporary file
        string tempPath = Path.GetTempFileName();
        try
        {
            using (FileStream outStream = new(tempPath, FileMode.Truncate, FileAccess.Write, FileShare.None))
            using (GZipStream gzipStream = new(outStream, CompressionLevel.Optimal))
            {
                stream.CopyTo(gzipStream);
            }
            // find the next free file name
            string gzPath;
            do { gzPath = Path.Combine(_logDirectoryPath, Invariant($"pfirewall.{_nextCompressedFileId++}.gz")); }
            while (File.Exists(gzPath));
            // move the temporary file to that location
            File.Move(sourceFileName: tempPath, destFileName: gzPath);
            return gzPath;
        }
        catch
        {
            // best-effort removal of the temporary file; a failed delete must not hide the original error
            try { File.Delete(tempPath); }
            catch (IOException) { }
            catch (UnauthorizedAccessException) { }
            throw;
        }
    }

    /// <summary>
    /// Creates the firewall log directory if it is missing, granting full control to the firewall
    /// service account, LocalSystem and Administrators.
    /// </summary>
    private void EnsureLogDirectoryExists()
    {
        DirectoryInfo logDirectory = new(_logDirectoryPath);
        if (!logDirectory.Exists)
        {
            DirectorySecurity security = new();
            security.AddAccessRule(new(@"NT SERVICE\mpssvc", FileSystemRights.FullControl, InheritanceFlags.ObjectInherit, PropagationFlags.None, AccessControlType.Allow));
            security.AddAccessRule(new(new SecurityIdentifier(WellKnownSidType.LocalSystemSid, null), FileSystemRights.FullControl, InheritanceFlags.ObjectInherit, PropagationFlags.None, AccessControlType.Allow));
            security.AddAccessRule(new(new SecurityIdentifier(WellKnownSidType.BuiltinAdministratorsSid, null), FileSystemRights.FullControl, InheritanceFlags.ObjectInherit, PropagationFlags.None, AccessControlType.Allow));
            logDirectory.Create(security);
        }
    }

    /// <summary>Starts watching for the rotated firewall log and rotates one that already exists.</summary>
    protected override void OnStart(string[] args)
    {
        // reported if start fails; OnStop resets it to 0
        ExitCode = ~0;
        if (_watcher is not null) throw new InvalidOperationException("Service already started.");
        EnsureLogDirectoryExists();
        _watcher = new()
        {
            Path = _logDirectoryPath,
            Filter = _logFileName,
        };
        // attach handlers before enabling events so no notification is raised without a handler
        _watcher.Created += (s, e) => RotateFile();
        _watcher.Renamed += (s, e) => RotateFile();
        _watcher.Error += (s, e) => HandleException(e.GetException(), eventId: 3);
        _watcher.EnableRaisingEvents = true;
        // handle a file that appeared before the watcher was active (e.g. left over from a crash-restart)
        if (File.Exists(_logFilePath))
        {
            RotateFile();
        }
    }

    /// <summary>Stops watching the log directory.</summary>
    protected override void OnStop()
    {
        if (_watcher is null) throw new InvalidOperationException("Service already stopped.");
        _watcher.Dispose();
        _watcher = null;
        ExitCode = 0;
    }

    /// <summary>
    /// Compresses the rotated log file into a new archive and deletes it; any failure terminates the
    /// process via <see cref="HandleException"/> with event ID 2.
    /// </summary>
    private void RotateFile()
    {
        try
        {
            lock (_rotateLock)
            {
                // an overlapping caller may already have rotated and deleted the file
                if (!File.Exists(_logFilePath))
                {
                    return;
                }
                // FileShare.Delete lets the file be deleted below while this stream is still open
                using FileStream inStream = new(_logFilePath, FileMode.Open, FileAccess.Read, FileShare.Read | FileShare.Delete);
                CompressStreamToFile(inStream);
                EventLog.WriteEntry("Rotated log file.", EventLogEntryType.Information, eventID: 1);
                File.Delete(_logFilePath);
            }
        }
        catch (Exception ex) { HandleException(ex, eventId: 2); }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment