28 January 2013

Reimplementace dynamic_cast v C++

Článek uvede jednu z možností, jak efektivně nahradit funkcionalitu dynamic_cast v C++. Prezentovaný kód bude využívat nejnovější specifikace jazyka, C++11, ovšem s malým omezením a trochou snahy není problém kód přepsat pro C++03.

Přetypování bázové třídy na třídu odvozenou lze v C++ docílit dvěma způsoby - pomocí static_cast, pokud patříme k těm šťastnějším a přetypování je možné provést ještě v době kompilace, či pomocí dynamic_cast, kdy máme fakt pech a musíme sáhnout po RTTI a přetypování provést za běhu programu. Jedním z takových případů je pokus o přetypování z virtuální bázové třídy.

Zlé jazyky tvrdí, že potřeba dynamic_cast je známka špatného návrhu aplikace. Dovolím si lehce nesouhlasit a poukázat na existenci několika případů, kdy je dynamic_cast užitečný. Představte si, že pracujete s nějakou logovací funkcí, která má za úkol ukládat informace o jednotlivých událostech. Událost může mít svou bázovou třídu, řekněme Event, ze které dědí všechny typy událostí. Díky dynamic_cast můžeme o konkrétních událostech získat bližší informace - událost přetypujeme z bázové třídy Event na konkrétní třídu události, a poté z ní můžeme získat více informací.


V úvodu jsem psal o přetypování z virtuální bázové třídy na třídu odvozenou. Mnoho lidí se virtuální dědičnosti v C++ vyhýbá, ovšem může být užitečná v případě, kdy píšete nějaký framework bohatý na rozhraní. Osobně se mi v praxi osvědčilo dodržovat pravidlo, aby každé rozhraní (či abstraktní třída, chcete-li), které využívá vlastností jiného rozhraní, z něj dědilo virtuálně. Zároveň je dobrou praxí mít základní třídu pro všechny ostatní třídy, včetně rozhraní, např. Object. Toho mnoho známých projektů v C++ skutečně využívá.

Důvodem virtuální dědičnosti je zamezení vícenásobného vytvoření instance bázových tříd. Modelovou situací může být případ, kdy v projektu existují třídy ObjectInputStream, OutputStream a IOStream, který využívá vlastností dvou předchozích tříd, které zároveň dědí ze třídy Object. Pokud bychom nevyužili virtuální dědičnosti u InputStream a OutputStream, objekt IOStream by obsahoval dvě instance třídy Object. Pokud bychom se poté z objektu IOStream pokusili zavolat metodu GetHandle(), kompilátor by zahlásil chybu, jelikož by nevěděl, jakou ze dvou existujících metod zavolat. Díky virtuální dědičnosti existuje instance jen jedna, sdílená napříč všemi objekty uvnitř každé ze tříd. Tomuto případu se také říká diamond inheritance.

Příklad nevyužití virtuální dědičnosti


Příklad využití virtuální dědičnosti

Jak již bylo uvedeno, problém nastane ve chvíli, když bude potřeba přetypovat virtuální třídu na některou z odvozených tříd - Object na jakoukoliv jinou třídu - právě kvůli virtuální dědičnosti. Tento problém lze jednoduše vyřešit právě použitím dynamic_cast. Pro ty, kteří se dynamic_cast z jakéhokoliv důvodu štítí, uvedu alternativní způsob, jak chování dynamic_cast při splnění určitých pravidel reimplementovat.

Ukázka, jak by mohla v paměti vypadat vytvořená instance třídy IOStream - z obrázku je zřejmé, že neexistuje žádný primitivní způsob, jak Object přetypovat na nadtřídu.

Jedním z těchto pravidel právě je, že musí existovat základní bázová třída, ze které budou dědit všechny ostatní třídy, na které budeme chtít aplikovat dynamické přetypování. V našem případě třída Object. Účel této třídy bude zejména poskytnutí informací o odvozené třídě (její název) a dvojice metod To(const std::string&), které budou využity k přetypovávání na základě názvu třídy - jedna const a druhá non-const. Srdcem celého principu bude hash-tabulka uložená ve třídě Object, která má za úkol udržovat ukazatele na jednotlivé odvozené podtřídy na základě jejich názvů. Tabulka musí být označena klíčovým slovem mutable, abychom ji mohli editovat i z const metod.
class Object
{
    public:
        virtual ~Object();

        static const std::string& GetClassNameStatic();
        virtual const std::string& GetClassName() const;
    
        void* To(const std::string& className);
        const void* To(const std::string& className) const;

    private:
        mutable std::map<std::string, void*> _bases;
};
Tento samotný návrh samozřejmě není dostačující. Je potřeba zajistit metodu pro zaregistrování odvozené třídy a zvlášť metodu pro rekurzivní procházení celou hierarchií dědičnosti. Kód jsem rozdělil na dvě části z důvodu snazšího pochopení.
class Object
{
    public:
        virtual ~Object() { }

        static const std::string& GetClassNameStatic()
        {
            static std::string className("Object");
            return className;
        }

        virtual const std::string& GetClassName() const
        {
            return GetClassNameStatic();
        }
    
        void* To(const std::string& className)
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

        const void* To(const std::string& className) const
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

    protected:
        void RegisterSubclass(const void* ptr, const std::string& className) const
        {
            _bases[className] = const_cast<void*>(ptr);
        }

        virtual void RegisterAllSubclasses() const
        {
            RegisterSubclass(static_cast<const void*>(this), Object::GetClassName());
        }

    private:
        mutable std::map<std::string, void*> _bases;
};
Na první pohled je zřejmé, jaký bude princip. Hierarchie dědičnosti se nevytvoří do té doby, dokud nebude zavolána metoda To(), z čehož plyne výhoda, že tabulka nebude vytvořena do té doby, dokud to nebude potřeba. Použití konstrukce if (_bases.size() == 0) je bezpečné, jelikož v tabulce je vždy minimálně přítomen sám Object.

Pro donucení odvozené třídy k registraci sebe sama a svých předků, využijeme překrytí metod RegisterAllSubclasses() a GetClassName() současně s novou definicí statické metody GetClassNameStatic(). Význam vytvoření nové definice této statické metody bude vysvětlen později. Třída InputStream by pak mohla vypadat nějak takto:
class InputStream : virtual public Object
{
    public:
        static const std::string& GetClassNameStatic()
        {
            static std::string className("InputStream");
            return className;
        }

        virtual const std::string& GetClassName() const override
        {
            return GetClassNameStatic();
        }

        void Read();

    protected:
        virtual void RegisterAllSubclasses() const override
        {
            RegisterSubclass(static_cast<const void*>(this), InputStream::GetClassName());
            Object::RegisterAllSubclasses();
        }
};
V metodě RegisterAllSubclasses() je u volaných metod důležité uvádět konkrétní názvy tříd. V případě, kdy by nebyly uvedeny, by se zavolaly metody překryté nejvíce derivovanou třídou (kdyby ze třídy InputStream dědil IOStream, metoda GetClassName() by vždy vracela název třídy IOStream). Poté zavoláme metodu RegisterAllSubclasses() na všech předcích, které opět rekurzivně zaregistrují své předky. Metoda RegisterAllSubclasses() by ve třídě IOStream vypadala následovně:
        virtual void RegisterAllSubclasses() const override
        {
            RegisterSubclass(static_cast<const void*>(this), IOStream::GetClassName());
            InputStream::RegisterAllSubclasses();
            OutputStream::RegisterAllSubclasses();
        }
Opakovaný zápis těchto metod v každé třídě zvlášť není samozřejmě pohodlný, proto je výhodnější přenechat tuto práci makru. Makro bude přebírat název aktuální třídy a za ním bude následovat seznam všech předků (variadic macro). Samotné registrování předků bude využívat malého triku s variadic templates, které C++11 zavedlo.
#define DEFINE_BASES(class, ...)                                                \
    static const std::string& GetClassNameStatic()                              \
    {                                                                           \
        static std::string className(#class);                                   \
        return className;                                                       \
    }                                                                           \
                                                                                \
    const std::string& GetClassName() const override                            \
    {                                                                           \
        return GetClassNameStatic();                                            \
    }                                                                           \
                                                                                \
    template <typename _empty>                                                  \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
                                                                                \
    }                                                                           \
                                                                                \
    template <typename _empty, typename T, typename... Args>                    \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
        T::RegisterAllSubclasses();                                             \
        RegisterAllSubclasses<void, Args...>();                                 \
    }                                                                           \
                                                                                \
    virtual void RegisterAllSubclasses() const override                         \
    {                                                                           \
        RegisterSubclass(static_cast<const void*>(this), class::GetClassName());\
        RegisterAllSubclasses<void, __VA_ARGS__>();                             \
    }
Princip je jednoduchý - bázové třídy se budou rekurzivně registrovat. Až se všechny zaregistrují a parametr Args bude prázdný, rekurze se ukončí voláním "prázdné" metody template <typename _empty> void RegisterAllSubclasses(). V případě výše uvedeného případu s třídou IOStream bude průběh volání následující:


Nakonec zbývá jediná věc - samotná funkce pro dynamické přetypování. Nyní přichází na řadu výhoda definované statické metody pro zjištění názvu třídy. Funkce totiž požádá o název třídy předané přes šablonu právě skrze statickou funkci GetClassNameStatic() a předá ji metodě To(const std::string&). Ta se v hash-tabulce pokusí najít ukazatel na konkrétní podtřídu, kterou pokud najde, vrátí ji.
template <typename T>
T my_dynamic_cast(Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}

template <typename T>
T my_dynamic_cast(const Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}
std::remove_pointer<T> někteří znají z knihovny Boost, v C++11 je ale součástí standardní knihovny ve hlavičkovém souboru type_traits. Důvod, proč se ve funkci nachází, je zřejmý - funkce potřebuje získat původní datový typ třídy, zatímco samotné funkci je předáván pouze ukazatel na ni.

Na závěr dodám, že návrhů na rozšíření je mnoho. Mimo urychlení funkcionality zaměněním indexování na základě samotného hashe textového řetězce (díky čemuž je tato alternativa až 10x rychlejší než standardní dynamic_cast) lze tuto kostru projektu rozšířit na plnohodnotné zavedení vlastního RTTI do projektu.

Nyní už jen samotný zdrojový kód:
#include <iostream>
#include <string>
#include <map>
#include <type_traits>

#define DEFINE_BASES(class, ...)                                                \
    static const std::string& GetClassNameStatic()                              \
    {                                                                           \
        static std::string className(#class);                                   \
        return className;                                                       \
    }                                                                           \
                                                                                \
    const std::string& GetClassName() const override                            \
    {                                                                           \
        return GetClassNameStatic();                                            \
    }                                                                           \
                                                                                \
    template <typename _empty>                                                  \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
                                                                                \
    }                                                                           \
                                                                                \
    template <typename _empty, typename T, typename... Args>                    \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
        T::RegisterAllSubclasses();                                             \
        RegisterAllSubclasses<void, Args...>();                                 \
    }                                                                           \
                                                                                \
    virtual void RegisterAllSubclasses() const override                         \
    {                                                                           \
        RegisterSubclass(static_cast<const void*>(this), class::GetClassName());\
        RegisterAllSubclasses<void, __VA_ARGS__>();                             \
    }

class Object
{
    public:
        virtual ~Object() { }

        static std::string& GetClassNameStatic()
        {
            static std::string className("Object");
            return className;
        }

        virtual const std::string& GetClassName() const
        {
            return GetClassNameStatic();
        }

        void* To(const std::string& className) 
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

        const void* To(const std::string& className) const
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

    protected:
        void RegisterSubclass(const void* ptr, const std::string& className) const
        {
            _bases[className] = const_cast<void*>(ptr);
        }

        virtual void RegisterAllSubclasses() const
        {
            RegisterSubclass(static_cast<const void*>(this), Object::GetClassName());
        }

    private:
        mutable std::map<std::string, void*> _bases;
};

////////////////////////////////////////////////////////////////////////////////
 
template <typename T>
T my_dynamic_cast(Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}

template <typename T>
T my_dynamic_cast(const Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}
 
////////////////////////////////////////////////////////////////////////////////
 
class InputStream : virtual public Object
{
    public:
        DEFINE_BASES(InputStream, Object);
        void Read() { }
};
 
class OutputStream : virtual public Object
{
    public:
        DEFINE_BASES(OutputStream, Object);
        void Write() { }
};
 
class IOStream : public InputStream, public OutputStream
{
    int _value;

    public:
        DEFINE_BASES(IOStream, InputStream, OutputStream);
        IOStream() : _value(0) { }

        int GetValue() const { return _value; }
        void SetValue(int value) { _value = value; }
};
 
int main()
{
    const Object*   co = new IOStream;
    const IOStream* cd = my_dynamic_cast<const IOStream*>(co);
    
    Object*   o = new IOStream;
    IOStream* d = my_dynamic_cast<IOStream*>(o);

    d->SetValue(42);
    
    printf("const:     %i, %p, %p\n", cd->GetValue(), co, cd);
    printf("non-const: %i, %p, %p\n", d->GetValue(), o, d);
    
    delete cd;
    delete d;

    return 0;
}

No comments:

Post a Comment