Some more work to prevent DLL pre-loading attacks. Full protection is only enabled in "static" builds. Non-static builds require that we allow DLL loading from application install directory (e.g. to load the Qt plug-ins).

This commit is contained in:
LoRd_MuldeR 2018-02-18 13:17:17 +01:00
parent 4671aadcea
commit aa98a2157b

View File

@ -54,6 +54,33 @@ static LONG WINAPI my_exception_handler(struct _EXCEPTION_POINTERS *ExceptionInf
return LONG_MAX;
}
///////////////////////////////////////////////////////////////////////////////
// DEFAULT DLL DIRECTORIES
///////////////////////////////////////////////////////////////////////////////
//Flags
#define MY_LOAD_LIBRARY_SEARCH_APPLICATION_DIR 0x200
#define MY_LOAD_LIBRARY_SEARCH_USER_DIRS 0x400
#define MY_LOAD_LIBRARY_SEARCH_SYSTEM32 0x800
#ifdef MUTILS_STATIC_LIB
#define MY_LOAD_LIBRARY_FLAGS (MY_LOAD_LIBRARY_SEARCH_SYSTEM32 | MY_LOAD_LIBRARY_SEARCH_USER_DIRS)
#else
#define MY_LOAD_LIBRARY_FLAGS (MY_LOAD_LIBRARY_SEARCH_SYSTEM32 | MY_LOAD_LIBRARY_SEARCH_USER_DIRS | MY_LOAD_LIBRARY_SEARCH_APPLICATION_DIR)
#endif
static void set_default_dll_directories(void)
{
typedef BOOL(__stdcall *MySetDefaultDllDirectories)(const DWORD DirectoryFlags);
if (const HMODULE kernel32 = GetModuleHandleW(L"kernel32"))
{
if (const MySetDefaultDllDirectories pSetDefaultDllDirectories = (MySetDefaultDllDirectories)GetProcAddress(kernel32, "SetDefaultDllDirectories"))
{
pSetDefaultDllDirectories(MY_LOAD_LIBRARY_FLAGS);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// SETUP ERROR HANDLERS
///////////////////////////////////////////////////////////////////////////////
@ -64,15 +91,17 @@ void MUtils::ErrorHandler::initialize(void)
SetUnhandledExceptionFilter(my_exception_handler);
SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
_set_invalid_parameter_handler(my_invalid_param_handler);
SetDllDirectoryW(L""); /*don'tload DLL from "current" directory*/
/*to prevent DLL pre-loading attacks*/
set_default_dll_directories();
SetDllDirectoryW(L"");
static const int signal_num[6] = { SIGABRT, SIGFPE, SIGILL, SIGINT, SIGSEGV, SIGTERM };
for(size_t i = 0; i < 6; i++)
{
signal(signal_num[i], my_signal_handler);
}
}
///////////////////////////////////////////////////////////////////////////////