Skip to content

Instantly share code, notes, and snippets.

@caiorss
Forked from 43x2/IUnknownPointerTable.cpp
Created August 21, 2018 13:45
Show Gist options
  • Save caiorss/7c821f0363472f0d654f79f5e604b167 to your computer and use it in GitHub Desktop.
Save caiorss/7c821f0363472f0d654f79f5e604b167 to your computer and use it in GitHub Desktop.
// C/C++ Common
#include <cstdio>
#include <map>
// Windows
#include <windows.h>
// COM interface used by this sample.
// The interface GUID is written twice: once in the __declspec(uuid) attribute
// (what __uuidof(ITestClass) reports) and once in IID_ITestClass below.
// The two copies must stay identical.
class __declspec( uuid( "3D48D1EB-1EA5-4D4B-8411-4C0B8DF72CD7" ) ) ITestClass
    : public IUnknown
{
protected:
    // IID that implementations compare against in QueryInterface.
    static const IID IID_ITestClass;
};

// {3D48D1EB-1EA5-4D4B-8411-4C0B8DF72CD7}
const IID ITestClass::IID_ITestClass =
{
    0x3D48D1EB,                                        // Data1
    0x1EA5,                                            // Data2
    0x4D4B,                                            // Data3
    { 0x84, 0x11, 0x4C, 0x0B, 0x8D, 0xF7, 0x2C, 0xD7 } // Data4
};
// Minimal reference-counted COM object implementing ITestClass.
// Instances are created only through CreateInstance() and destroyed only by
// the Release() call that drops the count to zero. The constructor and
// destructor are private, so the object cannot be created on the stack or
// deleted directly.
class TestClass final : public ITestClass
{
public:
    // Creates a new instance whose reference count already starts at 1.
    // That initial reference belongs to the caller, who must balance it with
    // Release().
    static TestClass * CreateInstance()
    {
        return new TestClass();
    }

    // Atomically increments the reference count and returns the new count.
    virtual ULONG STDMETHODCALLTYPE AddRef() override
    {
        const ULONG count = InterlockedIncrement( &m_refCount );
        return count;
    }

    // Atomically decrements the reference count and deletes the object when
    // the count reaches zero. Returns the new count.
    virtual ULONG STDMETHODCALLTYPE Release() override
    {
        const ULONG count = InterlockedDecrement( &m_refCount );
        // Decide from the value returned by the atomic decrement; never
        // re-read m_refCount. Another thread may decrement in between, which
        // would let two callers both observe 0 and delete the object twice.
        if ( count == 0 )
        {
            delete this;
        }
        return count;
    }

    // Supports IID_IUnknown and IID_ITestClass.
    // On success, writes the interface pointer, adds a reference for the
    // caller, and returns S_OK.
    // Returns E_POINTER if ppvObject is null.
    // Otherwise sets *ppvObject to nullptr and returns E_NOINTERFACE.
    virtual HRESULT STDMETHODCALLTYPE QueryInterface( REFIID riid, void ** ppvObject ) override
    {
        if ( !ppvObject ) return E_POINTER;
        if ( riid == IID_IUnknown || riid == IID_ITestClass )
        {
            // ITestClass derives singly from IUnknown, so one interface
            // pointer serves both IIDs. The cast makes that explicit.
            *ppvObject = static_cast< ITestClass * >( this );
            AddRef();
            return S_OK;
        }
        *ppvObject = nullptr;
        return E_NOINTERFACE;
    }

private:
    TestClass()
        : m_refCount( 1 ) // the CreateInstance() caller owns this first reference
    {
    }

    ~TestClass()
    {
    }

    // The Interlocked* functions require a 32-bit aligned operand.
    __declspec( align( 4 ) ) ULONG m_refCount;
};
template< typename KeyType > class IUnknownPointerTable
{
private:
typedef std::map< KeyType, IUnknown * > map_type;
public:
IUnknownPointerTable()
: m_table()
{
}
private:
IUnknownPointerTable( const IUnknownPointerTable & );
IUnknownPointerTable & operator=( const IUnknownPointerTable & );
public:
~IUnknownPointerTable()
{
for ( auto it = m_table.begin(); it != m_table.end(); ++it )
{
// map からの参照を解除
it->second->Release();
}
}
public:
// 追加
bool Add( const KeyType key, IUnknown * pointer )
{
if ( !pointer ) return false;
auto result = m_table.insert( typename map_type::value_type( key, pointer ) );
if ( result.second )
{
pointer->AddRef(); // map が参照するので
}
return result.second;
}
// 削除 (pointer が nullptr なら削除したポインタは返さない)
bool Remove( const KeyType key, IUnknown ** pointer )
{
map_type::iterator it = m_table.find( key );
if ( it == m_table.end() ) return false;
IUnknown * unknown = it->second;
m_table.erase( it ); // 削除
unknown->Release(); // map から参照されていたぶん
if ( pointer ) *pointer = unknown;
return true;
}
// 検索 (pointer が nullptr ならキーに値があるかどうかだけ返す)
bool Find( const KeyType key, IUnknown ** pointer )
{
map_type::iterator it = m_table.find( key );
if ( it == m_table.end() ) return false;
if ( pointer )
{
*pointer = it->second;
( *pointer )->AddRef(); // 呼び出し元から参照されるので
}
return true;
}
// 検索 (別インターフェース経由)
template< typename PointerType >
bool Find( const KeyType key, PointerType ** pointer )
{
// better: assert( pointer )
IUnknown * unknown = nullptr;
if ( !Find( key, &unknown ) ) return false;
PointerType * pt = nullptr;
if ( SUCCEEDED( unknown->QueryInterface( &pt ) ) )
{
unknown->Release(); // Find のぶん
*pointer = pt;
return true;
}
return false;
}
private:
map_type m_table;
};
int main( void )
{
// 追加テスト
std::printf( "Add:\n" );
{
TestClass * tc1 = TestClass::CreateInstance();
ITestClass * tc2 = TestClass::CreateInstance();
IUnknown * tc3 = TestClass::CreateInstance();
{
IUnknownPointerTable< int > pt;
std::printf( "%d (Expected: 1)\n", pt.Add( 0, tc1 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 1, tc2 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 2, tc3 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 3, tc1 ) );
std::printf( "%d (Expected: 0)\n", pt.Add( 0, tc2 ) );
}
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc1->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc2->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc3->Release() ) );
}
std::printf( "\n" );
// 削除テスト
std::printf( "Remove:\n" );
{
TestClass * tc1 = TestClass::CreateInstance();
ITestClass * tc2 = TestClass::CreateInstance();
IUnknown * tc3 = TestClass::CreateInstance();
{
IUnknownPointerTable< int > pt;
std::printf( "%d (Expected: 1)\n", pt.Add( 0, tc1 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 1, tc2 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 2, tc3 ) );
std::printf( "%d (Expected: 1)\n", pt.Remove( 0, nullptr ) );
std::printf( "%d (Expected: 0)\n", pt.Remove( 0, nullptr ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 0, tc2 ) );
std::printf( "%d (Expected: 1)\n", pt.Remove( 1, nullptr ) );
std::printf( "%d (Expected: 0)\n", pt.Remove( 1, nullptr ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 1, tc3 ) );
std::printf( "%d (Expected: 1)\n", pt.Remove( 2, nullptr ) );
std::printf( "%d (Expected: 0)\n", pt.Remove( 2, nullptr ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 2, tc1 ) );
IUnknown * unknown = nullptr;
std::printf( "%d (Expected: 1)\n", pt.Remove( 0, &unknown ) );
std::printf( "%p (Expected: %p)\n", unknown, tc2 );
std::printf( "%d (Expected: 1)\n", pt.Remove( 1, &unknown ) );
std::printf( "%p (Expected: %p)\n", unknown, tc3 );
std::printf( "%d (Expected: 1)\n", pt.Remove( 2, &unknown ) );
std::printf( "%p (Expected: %p)\n", unknown, tc1 );
}
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc1->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc2->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc3->Release() ) );
}
std::printf( "\n" );
// 検索テスト
std::printf( "Find:\n" );
{
TestClass * tc1 = TestClass::CreateInstance();
ITestClass * tc2 = TestClass::CreateInstance();
IUnknown * tc3 = TestClass::CreateInstance();
{
IUnknownPointerTable< int > pt;
std::printf( "%d (Expected: 1)\n", pt.Add( 0, tc1 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 1, tc2 ) );
std::printf( "%d (Expected: 1)\n", pt.Add( 2, tc3 ) );
IUnknown * piu = nullptr;
ITestClass * pitc = nullptr;
std::printf( "%d (Expected: 1)\n", pt.Find( 0, &piu ) );
std::printf( "%p (Expected: %p)\n", piu, tc1 );
std::printf( "%d (Expected: 1)\n", pt.Find( 0, &pitc ) );
std::printf( "%p (Expected: %p)\n", pitc, tc1 );
piu->Release();
pitc->Release();
std::printf( "%d (Expected: 1)\n", pt.Find( 1, &piu ) );
std::printf( "%p (Expected: %p)\n", piu, tc2 );
std::printf( "%d (Expected: 1)\n", pt.Find( 1, &pitc ) );
std::printf( "%p (Expected: %p)\n", pitc, tc2 );
piu->Release();
pitc->Release();
std::printf( "%d (Expected: 1)\n", pt.Find( 2, &piu ) );
std::printf( "%p (Expected: %p)\n", piu, tc3 );
std::printf( "%d (Expected: 1)\n", pt.Find( 2, &pitc ) );
std::printf( "%p (Expected: %p)\n", pitc, tc3 );
piu->Release();
pitc->Release();
std::printf( "%d (Expected: 1)\n", pt.Find( 0, nullptr ) );
}
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc1->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc2->Release() ) );
std::printf( "%u (Expected: 0)\n", static_cast< unsigned int >( tc3->Release() ) );
}
std::printf( "\n" );
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment