#include <windows.h>
#include <new>
#include <atlbase.h>
#include <statreg.h>
#include "dll_hw.h"
#include "dll_hw_i.c"
#include "resource.h"

void RegisterServer(wchar_t* widePath, bool reg) {
    ATL::CRegObject ro;
    ro.AddReplacement(L"Module", widePath);
    reg ? ro.ResourceRegister(widePath, IDR_REGISTRY, L"REGISTRY") :
        ro.ResourceUnregister(widePath, IDR_REGISTRY, L"REGISTRY");
}

// Module handle of this DLL, recorded by DllMain and read by GetPathName.
HINSTANCE hInstance;

// DLL entry point. For every notification it stores this module's handle in
// the global hInstance and reports success.
int _stdcall DllMain(HINSTANCE instance, DWORD /*reason*/, void*) {
    ::hInstance = instance;
    return 1;
}

void GetPathName(wchar_t* widePath) {
    char ansiPath[MAX_PATH];
    GetModuleFileName(hInstance, ansiPath, MAX_PATH);
    MultiByteToWideChar(CP_ACP, 0, ansiPath, lstrlen(ansiPath) + 1, widePath, MAX_PATH);
}

// Self-registration entry point. Runs the module's registry script with
// %Module% bound to this DLL's path, then registers the type library embedded
// in the DLL.
// Returns:
//   S_OK              - everything was registered.
//   SELFREG_E_CLASS   - this DLL's path could not be determined; nothing was registered.
//   SELFREG_E_TYPELIB - the type library could not be loaded or registered; the
//                       registry-script entries are removed again in that case.
extern "C" HRESULT _stdcall DllRegisterServer() {
    wchar_t widePath[MAX_PATH];
    widePath[0] = L'\0';
    GetPathName(widePath);
    // Without a module path neither the script nor the type library can be registered.
    if(!widePath[0]) return SELFREG_E_CLASS;
    RegisterServer(widePath, true);
    CComPtr<ITypeLib> pTypeLib;
    HRESULT hr(LoadTypeLib(widePath, &pTypeLib));
    if(SUCCEEDED(hr)) hr = RegisterTypeLib(pTypeLib, widePath, 0);
    if(FAILED(hr)) {
        // Roll back so the script's entries are not left behind without the
        // type library the object depends on.
        RegisterServer(widePath, false);
        return SELFREG_E_TYPELIB;
    }
    return S_OK;
}

// Self-unregistration entry point. Undoes the module's registry script
// (with %Module% bound to this DLL's path), then unregisters version 1.0,
// LCID 0 of LIBID_SeanTestLibrary.
// Returns S_OK, or SELFREG_E_TYPELIB if UnRegisterTypeLib fails.
extern "C" HRESULT _stdcall DllUnregisterServer() {
    wchar_t widePath[MAX_PATH];
    GetPathName(widePath);
    RegisterServer(widePath, false);
    // RegisterTypeLib from a 64-bit build records the library under the win64
    // key, so unregistration must name the matching SYSKIND.
#ifdef _WIN64
    const SYSKIND sysKind = SYS_WIN64;
#else
    const SYSKIND sysKind = SYS_WIN32;
#endif
    HRESULT hr(UnRegisterTypeLib(LIBID_SeanTestLibrary, 1, 0, 0, sysKind));
    if(FAILED(hr)) return SELFREG_E_TYPELIB;
    return S_OK;
}


// Module-wide counter, incremented by every AddRef and decremented by every
// Release of this DLL's objects and class factories. DllCanUnloadNow answers
// S_OK only when it is zero.
ULONG globalCount = 0;

// Thrown by CSeanTestObject's constructor when its type information cannot be
// loaded. The class factory translates it into E_UNEXPECTED.
struct EVeryBadThing {};
class CSeanTestObject : public ISeanTestInterface {
	ULONG _refCount;
	CComPtr<ITypeInfo> _pTypeInfo;
	bool _dispatchCall;
public:
	CSeanTestObject() throw(EVeryBadThing) : _refCount(0), _dispatchCall(false) { 
		CComPtr<ITypeLib> pTypeLib;
		HRESULT hr;
		hr = LoadRegTypeLib(LIBID_SeanTestLibrary, 1, 0, 0, &pTypeLib);
		if(hr) throw EVeryBadThing();
		CComPtr<ITypeInfo> pTypeInfo;
		hr = pTypeLib->GetTypeInfoOfGuid(IID_ISeanTestInterface, &_pTypeInfo);
		if(hr) throw EVeryBadThing();
	}
	virtual ULONG _stdcall AddRef() { 
		++globalCount; 
		return ++_refCount; 
	}
	virtual ULONG _stdcall Release() { 
		--globalCount;
		ULONG ret(--_refCount);
		if(!ret) delete this;
		return ret;
	}
	virtual HRESULT _stdcall QueryInterface(REFIID riid, void** ppv) {
		if(!ppv) return E_POINTER;
		if(riid == IID_IUnknown) *ppv = static_cast<IUnknown*>(this);
		else if(riid == IID_IDispatch) *ppv = static_cast<IDispatch*>(this);
		else if(riid == IID_ISeanTestInterface) *ppv = static_cast<ISeanTestInterface*>(this);
		else return *ppv = 0, E_NOINTERFACE;
		return AddRef(), S_OK;
	}

	virtual HRESULT _stdcall GetIDsOfNames(REFIID riid, OLECHAR** rgszNames, UINT cNames,
		LCID lcid, DISPID* rgDispId) {
		if(lcid == 0 || lcid == 9 || lcid == 0x409)
			return _pTypeInfo->GetIDsOfNames(rgszNames, cNames, rgDispId);
		return DISP_E_UNKNOWNLCID;
	}
	virtual HRESULT _stdcall GetTypeInfo(UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo) {
		if(iTInfo) return TYPE_E_ELEMENTNOTFOUND;
		if(!ppTInfo) return E_POINTER;
		if(lcid == 0 || lcid == 9 || lcid == 0x409) 
			return _pTypeInfo.CopyTo(ppTInfo), S_OK;
		return *ppTInfo = 0, DISP_E_UNKNOWNLCID;
	}
	virtual HRESULT _stdcall GetTypeInfoCount(UINT* pctinfo) {
		if(!pctinfo) return E_POINTER;
		return *pctinfo = 1, S_OK;
	}
	virtual HRESULT _stdcall Invoke(DISPID dispIdMember, REFIID riid, LCID lcid, WORD wFlags,
		DISPPARAMS* pDispParams, VARIANT* pVarResult, EXCEPINFO* pExcepInfo, UINT* puArgError) {
		if(lcid == 0 || lcid == 9 || lcid == 0x409) {
			_dispatchCall = true;
			return _pTypeInfo->Invoke(this, dispIdMember, wFlags, pDispParams, pVarResult, 
			pExcepInfo, puArgError);
		}
		return DISP_E_UNKNOWNLCID;
	} 
	virtual HRESULT _stdcall TestMethod() {
		MessageBox(0, "It was people!\nPeople soiled our green!\n\n\t- Ned Flanders", 
			_dispatchCall ? "Dispatch Call" : "Virtual Call", MB_OK | MB_ICONINFORMATION);
		_dispatchCall = false;
		return S_OK;
	}
};

// Class factory for CSeanTestObject, handed out by DllGetClassObject.
// Like the objects it creates, it starts with a reference count of 0, deletes
// itself on the final Release, and mirrors AddRef/Release into globalCount.
class CSeanTestObjectFactory : public IClassFactory {
	ULONG _refCount;
public:
	CSeanTestObjectFactory() : _refCount(0) { }
	virtual ULONG _stdcall AddRef() { 
		++globalCount; 
		return ++_refCount; 
	}
	// Deletes the factory when the count reaches zero; returns the new count.
	virtual ULONG _stdcall Release() {
		--globalCount;
		ULONG ret(--_refCount);
		if(!ret) delete this;
		return ret;
	}
	// Supports IUnknown and IClassFactory; *ppv is nulled on failure.
	virtual HRESULT _stdcall QueryInterface(REFIID riid, void** ppv) {
		if(!ppv) return E_POINTER;
		if(riid == IID_IUnknown) *ppv = static_cast<IUnknown*>(this);
		else if(riid == IID_IClassFactory) *ppv = static_cast<IClassFactory*>(this);
		else return *ppv = 0, E_NOINTERFACE;
		return AddRef(), S_OK;
	}
	// Creates a new CSeanTestObject and returns its riid interface in *ppv.
	// Aggregation is not supported. *ppv is nulled on every failure.
	// Returns E_OUTOFMEMORY if allocation fails, and E_UNEXPECTED if the object
	// cannot load its type information.
	virtual HRESULT _stdcall CreateInstance(IUnknown* pUnk, REFIID riid, void** ppv) {
		if(!ppv) return E_POINTER;
		*ppv = 0;
		if(pUnk) return CLASS_E_NOAGGREGATION;
		try {
			// nothrow new: running out of memory must become an HRESULT, not an
			// exception escaping across the COM boundary.
			CSeanTestObject* seanTestObject = new(std::nothrow) CSeanTestObject;
			if(!seanTestObject) return E_OUTOFMEMORY;
			HRESULT hr(seanTestObject->QueryInterface(riid, ppv));
			// A failed QueryInterface leaves the count at 0, so destroy directly.
			if(FAILED(hr)) delete seanTestObject;
			return hr;
		} catch(const EVeryBadThing&) {
			return E_UNEXPECTED;
		}
	}
	// Server locks are counted in globalCount so DllCanUnloadNow keeps the DLL
	// loaded while a client holds one. Clients must balance TRUE/FALSE calls.
	virtual HRESULT _stdcall LockServer(BOOL lock) {
		if(lock) ++globalCount;
		else --globalCount;
		return S_OK;
	}
};
		
// COM entry point: returns the riid interface of a new class factory for
// CLSID_SeanTestCoClass. *ppv is nulled on every failure.
// Returns CLASS_E_CLASSNOTAVAILABLE for any other CLSID, and E_OUTOFMEMORY if
// the factory cannot be allocated.
extern "C" HRESULT _stdcall DllGetClassObject(REFCLSID rclsid, REFIID riid, void** ppv) {
	if(!ppv) return E_POINTER;
	*ppv = 0;
	if(rclsid != CLSID_SeanTestCoClass) return CLASS_E_CLASSNOTAVAILABLE;
	// nothrow new: allocation failure must not throw across the COM boundary.
	CSeanTestObjectFactory* testObjectFactory = new(std::nothrow) CSeanTestObjectFactory;
	if(!testObjectFactory) return E_OUTOFMEMORY;
	HRESULT hr(testObjectFactory->QueryInterface(riid, ppv));
	// A failed QueryInterface leaves the count at 0, so destroy directly.
	if(FAILED(hr)) delete testObjectFactory;
	return hr;
}

// COM asks whether this DLL may be unloaded: S_OK once globalCount has
// dropped to zero, S_FALSE while anything is still counted in it.
extern "C" HRESULT _stdcall DllCanUnloadNow() {
	if(globalCount == 0) return S_OK;
	return S_FALSE;
}