Skip to content

Instantly share code, notes, and snippets.

@maxmcguire
Created January 13, 2017 05:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxmcguire/b7ed7954d6271bac2b49ddea2fc3d87a to your computer and use it in GitHub Desktop.
Save maxmcguire/b7ed7954d6271bac2b49ddea2fc3d87a to your computer and use it in GitHub Desktop.
// Copies at most maxLength - 1 characters of src into dst and always
// NUL-terminates dst. Does nothing when maxLength <= 0 (no room at all).
// Unlike a bare wcsncpy, the result is never left unterminated on truncation.
static void CorrectFileCase_CopyTruncated(const wchar_t* src, wchar_t* dst, int maxLength)
{
    if (maxLength <= 0)
    {
        return;
    }
    wcsncpy(dst, src, (size_t)maxLength - 1);
    dst[maxLength - 1] = 0;
}

// Fallback for CorrectFileCase: round-trips srcFileName through
// GetShortPathNameW / GetLongPathNameW. If either call fails or its result does
// not fit, dstFileName receives a (possibly truncated) copy of srcFileName.
// The result keeps whatever separators GetLongPathNameW produces.
// Requires maxLength > 0.
static void CorrectFileCase_Fallback(const wchar_t* srcFileName, wchar_t* dstFileName, int maxLength)
{
    wchar_t shortFileName[1025];
    const DWORD shortCapacity = (DWORD)(sizeof(shortFileName) / sizeof(shortFileName[0]));

    // Both APIs return 0 on failure and, when the buffer is too small, the
    // required size (>= the capacity passed in) without producing a usable path.
    // Only 0 < result < capacity means the output buffer holds the converted path.
    const DWORD shortLength = GetShortPathNameW(srcFileName, shortFileName, shortCapacity);
    if (shortLength != 0 && shortLength < shortCapacity)
    {
        const DWORD longCapacity = (DWORD)maxLength;
        const DWORD longLength = GetLongPathNameW(shortFileName, dstFileName, longCapacity);
        if (longLength != 0 && longLength < longCapacity)
        {
            return;
        }
    }
    CorrectFileCase_CopyTruncated(srcFileName, dstFileName, maxLength);
}

// Rewrites srcFileName into dstFileName using the letter case reported by the
// file system for the existing file or directory.
//
// Primary method: open the path (FILE_FLAG_BACKUP_SEMANTICS so directories can
// be opened too) and query FileNameInformation through ntdll's
// NtQueryInformationFile. That name does not include the volume, so the prefix
// of srcFileName up to and including the first ':' is copied verbatim.
// Backslashes in the queried name become '/', and a trailing '/' is appended
// when srcFileName ended in '/' or '\\'.
//
// If the handle cannot be opened for a reason other than ERROR_FILE_NOT_FOUND,
// or the query is unavailable or fails, a GetShortPathNameW/GetLongPathNameW
// round trip is used instead (its result is not slash-converted).
// If the file is not found, or every method fails, dstFileName receives a copy
// of srcFileName.
//
// srcFileName: NUL-terminated path to correct.
// dstFileName: output buffer of maxLength characters; always NUL-terminated
//              when maxLength > 0 (the result is truncated to fit).
// maxLength:   capacity of dstFileName in characters, including the terminator.
void CorrectFileCase(const wchar_t* srcFileName, wchar_t* dstFileName, int maxLength)
{
    if (maxLength <= 0)
    {
        return; // No room even for a terminator.
    }

    enum { kMaxQueriedNameChars = 1024 };
    struct FILE_NAME_INFORMATION
    {
        ULONG FileNameLength;                     // In bytes, not characters.
        WCHAR FileName[kMaxQueriedNameChars + 1]; // +1 keeps room for the terminator added below.
    };
    typedef NTSTATUS (NTAPI *NtQueryInformationFileFn)(HANDLE, PIO_STATUS_BLOCK, PVOID, ULONG, FILE_INFORMATION_CLASS);

    // FILE_FLAG_BACKUP_SEMANTICS allows us to open directories.
    HANDLE hFile = CreateFileW(srcFileName, GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
                               nullptr, OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS, nullptr);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        if (GetLastError() == ERROR_FILE_NOT_FOUND)
        {
            // Nothing to correct; hand back the caller's name rather than
            // leaving dstFileName unwritten.
            CorrectFileCase_CopyTruncated(srcFileName, dstFileName, maxLength);
        }
        else
        {
            CorrectFileCase_Fallback(srcFileName, dstFileName, maxLength);
        }
        return;
    }

    // Resolved lazily and cached; retried on later calls while unresolved.
    static NtQueryInformationFileFn pNtQueryInformationFile = nullptr;
    if (pNtQueryInformationFile == nullptr)
    {
        HMODULE hDll = LoadLibraryW(L"ntdll.dll");
        pNtQueryInformationFile = (NtQueryInformationFileFn)GetProcAddress(hDll, "NtQueryInformationFile");
        if (pNtQueryInformationFile == nullptr)
        {
            Log_Error(2, "Couldn't get NtQueryInformationFile function");
        }
    }

    FILE_NAME_INFORMATION nameInformation;
    bool haveName = false;
    if (pNtQueryInformationFile != nullptr)
    {
        IO_STATUS_BLOCK iosb;
        // Advertise only kMaxQueriedNameChars characters of FileName (sizeof would
        // also count trailing padding) so the terminator written below always
        // lands inside the array.
        const ULONG queryLength =
            (ULONG)(FIELD_OFFSET(FILE_NAME_INFORMATION, FileName) + kMaxQueriedNameChars * sizeof(WCHAR));
        NTSTATUS status = pNtQueryInformationFile(hFile, &iosb, &nameInformation, queryLength,
                                                  (FILE_INFORMATION_CLASS)9); // FileNameInformation
        haveName = (status == 0);
    }
    CloseHandle(hFile);

    if (!haveName)
    {
        Log_Message(2, "Using fallback method for testing file name case");
        CorrectFileCase_Fallback(srcFileName, dstFileName, maxLength);
        return;
    }

    // FileNameLength counts bytes; convert to characters before indexing.
    size_t nameChars = nameInformation.FileNameLength / sizeof(WCHAR);
    if (nameChars > kMaxQueriedNameChars)
    {
        nameChars = kMaxQueriedNameChars;
    }
    nameInformation.FileName[nameChars] = 0;

    // Fix up the slashes.
    for (size_t i = 0; nameInformation.FileName[i] != 0; ++i)
    {
        if (nameInformation.FileName[i] == L'\\')
        {
            nameInformation.FileName[i] = L'/';
        }
    }

    // Characters that fit in dstFileName before the terminator.
    const size_t capacity = (size_t)maxLength - 1;

    // We don't get the volume label, so just use that from the original file name.
    // There are ways of getting the proper volume label case, but it's more expensive and complex.
    size_t length = 0;
    const wchar_t* volumeLabelEnd = wcschr(srcFileName, L':');
    if (volumeLabelEnd != nullptr)
    {
        length = (size_t)(volumeLabelEnd - srcFileName) + 1;
        if (length > capacity)
        {
            length = capacity;
        }
        wcsncpy(dstFileName, srcFileName, length);
    }

    size_t appendLength = nameChars;
    if (appendLength > capacity - length)
    {
        appendLength = capacity - length;
    }
    wcsncpy(dstFileName + length, nameInformation.FileName, appendLength);
    // Terminate after the prefix *and* the appended name.
    dstFileName[length + appendLength] = 0;

    // Make sure we have a trailing slash if the request had one.
    const size_t srcLength = wcslen(srcFileName);
    if (srcLength > 0 && (srcFileName[srcLength - 1] == L'/' || srcFileName[srcLength - 1] == L'\\'))
    {
        const size_t dstLength = wcslen(dstFileName);
        if (dstLength > 0 && dstFileName[dstLength - 1] != L'/' && dstLength + 1 < (size_t)maxLength)
        {
            dstFileName[dstLength] = L'/';
            dstFileName[dstLength + 1] = 0;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment