Skip to content

Instantly share code, notes, and snippets.

@rheone
Created September 17, 2019 12:50
Show Gist options
  • Save rheone/8082182acdc695aa52f2ba1550c7dfea to your computer and use it in GitHub Desktop.
Save rheone/8082182acdc695aa52f2ba1550c7dfea to your computer and use it in GitHub Desktop.
A custom NHibernate IUserType that maps network-ordered address bytes plus an address family to an IPAddress object.
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using NHibernate;
using NHibernate.Engine;
using NHibernate.SqlTypes;
using NHibernate.UserTypes;
namespace My.UserType
{
/// <summary>
/// <para>
/// A user type for persisting an <see cref="IPAddress" /> as two columns in the database:
/// the address family as a string (the <see cref="AddressFamily" /> member name, e.g.
/// <c>InterNetwork</c> or <c>InterNetworkV6</c>) and the address bytes in network order as binary.
/// </para>
/// <para>
/// Only IPv4 and IPv6 addresses can be read back. The IPv6 scope id is not persisted, because
/// <see cref="IPAddress.GetAddressBytes" /> does not include it.
/// </para>
/// <code>
/// Map(x =&gt; x.Address)
///     .CustomType&lt;IpAddressBinaryUserType&gt;()
///     .Columns.Clear()
///     .Columns.Add("address_family", "address")
///     .Not.Nullable();
/// </code>
/// </summary>
public class IpAddressBinaryUserType : IUserType
{
    /// <summary>Position of the address family column within the mapped columns.</summary>
    private const int AddressFamilyColumn = 0;

    /// <summary>Position of the address bytes column within the mapped columns.</summary>
    private const int BytesColumn = 1;

    /// <summary>
    /// SQL types of the mapped columns, in column order: address family (string), address bytes (binary).
    /// </summary>
    public SqlType[] SqlTypes => new[]
    {
        NHibernateUtil.String.SqlType, // Address Family
        NHibernateUtil.Binary.SqlType  // Address
    };

    /// <inheritdoc />
    /// <remarks>Values are treated as immutable; <see cref="DeepCopy" /> returns the same instance.</remarks>
    public bool IsMutable => false;

    /// <inheritdoc />
    public Type ReturnedType => typeof(IPAddress);

    #region From Interface IUserType

    /// <inheritdoc />
    public object Assemble(object cached,
                           object owner)
    {
        return this.DeepCopy(cached);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Writes the address family name and the network-ordered address bytes. Any value that is not an
    /// <see cref="IPAddress" /> (including <see langword="null" />) writes NULL to both columns.
    /// The IPv6 scope id is not written.
    /// </remarks>
    public void NullSafeSet(DbCommand cmd,
                            object value,
                            int index,
                            ISessionImplementor session)
    {
        if (value is IPAddress address)
        {
            NHibernateUtil.String.NullSafeSet(cmd, address.AddressFamily.ToString(), index + AddressFamilyColumn, session);
            NHibernateUtil.Binary.NullSafeSet(cmd, address.GetAddressBytes(), index + BytesColumn, session);
        }
        else
        {
            NHibernateUtil.String.NullSafeSet(cmd, null, index + AddressFamilyColumn, session);
            NHibernateUtil.Binary.NullSafeSet(cmd, null, index + BytesColumn, session);
        }
    }

    /// <inheritdoc />
    public object DeepCopy(object value)
    {
        // The type is declared immutable (IsMutable is false), so sharing the instance is sufficient.
        return value;
    }

    /// <inheritdoc />
    public object Disassemble(object value)
    {
        return this.DeepCopy(value);
    }

    /// <inheritdoc />
    /// <remarks><c>new</c> is required because it shares a signature with the static <see cref="object.Equals(object, object)" />.</remarks>
    public new bool Equals(object x,
                           object y)
    {
        if (ReferenceEquals(x, y))
        {
            return true;
        }

        if (x is null || y is null)
        {
            return false;
        }

        return x.Equals(y);
    }

    /// <inheritdoc />
    /// <remarks>Returns 0 for <see langword="null" />, consistent with the null handling in <see cref="Equals(object, object)" />.</remarks>
    public int GetHashCode(object x)
    {
        return x?.GetHashCode() ?? 0;
    }

    /// <inheritdoc />
    /// <exception cref="HibernateException">on inability to do get based on parsing or column mapping</exception>
    public object NullSafeGet(DbDataReader rs,
                              string[] names,
                              ISessionImplementor session,
                              object owner)
    {
        var addressFamilyString = NHibernateUtil.String.NullSafeGet(rs, names[AddressFamilyColumn], session) as string;
        var bytes = NHibernateUtil.Binary.NullSafeGet(rs, names[BytesColumn], session) as byte[];

        // Both columns empty: the mapped property is null.
        if (string.IsNullOrWhiteSpace(addressFamilyString)
            && bytes is null)
        {
            return null;
        }

        // Parse the address family by member name (case-insensitive, surrounding whitespace ignored).
        if (!Enum.TryParse((addressFamilyString ?? string.Empty).Trim(),
                           true,
                           out AddressFamily addressFamily))
        {
            throw new HibernateException($"Address family of \"{addressFamilyString}\" named by mapping as column \"{names[AddressFamilyColumn]}\" is not recognized");
        }

        // A family without bytes cannot be turned into an address.
        if (bytes is null)
        {
            throw new HibernateException($"Bytes named by mapping as column \"{names[BytesColumn]}\" are null while address family is \"{addressFamilyString}\"");
        }

        var address = new IPAddress(GetBytes(addressFamily, bytes));

        // Defensive: GetBytes sizes the buffer to 4 or 16 bytes, which should always yield the requested family.
        if (address.AddressFamily != addressFamily)
        {
            throw new HibernateException($"Expecting IP Addresses to be of address family {addressFamily}");
        }

        return address;
    }

    /// <inheritdoc />
    public object Replace(object original,
                          object target,
                          object owner)
    {
        // Immutable values can be shared between the detached and the managed entity.
        return original;
    }

    #endregion

    /// <summary>
    /// Get the bytes of an IP address, sized to the length implied by the address family
    /// (4 bytes for IPv4, 16 bytes for IPv6).
    /// </summary>
    /// <param name="addressFamily">the desired address family</param>
    /// <param name="input">the stored bytes, which may be shorter or longer than the family requires</param>
    /// <returns>a new array of exactly the length required by <paramref name="addressFamily" /></returns>
    /// <exception cref="HibernateException">when <paramref name="addressFamily" /> is neither IPv4 nor IPv6</exception>
    private static byte[] GetBytes(AddressFamily addressFamily,
                                   IEnumerable<byte> input)
    {
        switch (addressFamily)
        {
            case AddressFamily.InterNetwork:
                return AffixByteLength(input, 4);
            case AddressFamily.InterNetworkV6:
                return AffixByteLength(input, 16);
            default:
                throw new HibernateException($"Address family \"{addressFamily}\" is unsupported");
        }
    }

    /// <summary>
    /// Transform an <see cref="IEnumerable{T}" /> of <see langword="byte" /> to a given length.
    /// Surplus trailing bytes (the least significant in network order) are dropped, and a short
    /// input is padded with trailing 0x00 bytes.
    /// </summary>
    /// <param name="input">the bytes to transform; <see langword="null" /> is treated as empty</param>
    /// <param name="desiredLength">the length of the result</param>
    /// <returns>a new array of exactly <paramref name="desiredLength" /> bytes</returns>
    private static byte[] AffixByteLength(IEnumerable<byte> input,
                                          int desiredLength)
    {
        // ToArray copies, so resizing below never touches the caller's buffer.
        var bytes = (input ?? Enumerable.Empty<byte>()).ToArray();

        // Array.Resize truncates the tail or zero-fills the new tail; it is a no-op when the length already matches.
        Array.Resize(ref bytes, desiredLength);
        return bytes;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment