A simple example to try/learn the Visitor Pattern in C++
-
The original visitor pattern code written in C++: visitor.cpp. Running it produces:
virtual int Visitor1::visit(Visitable1&)
virtual int BaseVisitor::visit(Visitable1&)
virtual int Visitor1::visit(Visitable2&)
virtual int Visitor2::visit(Visitable2&)
Perfect!
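For readers without visitor.cpp at hand, the double dispatch behind that output can be sketched roughly as follows. This is a reconstruction from the printed signatures, not the file itself; trace() and run_reference_demo() are hypothetical helpers added so the call order can be inspected without compiler-specific function-name macros.

```cpp
#include <string>
#include <vector>

class Visitable1;
class Visitable2;

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

class BaseVisitor {
 public:
  virtual ~BaseVisitor() = default;
  // Defaults, used when a concrete visitor does not override an overload.
  virtual int visit(Visitable1&) { trace().push_back("BaseVisitor::visit(Visitable1&)"); return 0; }
  virtual int visit(Visitable2&) { trace().push_back("BaseVisitor::visit(Visitable2&)"); return 0; }
};

class BaseVisitable {
 public:
  virtual ~BaseVisitable() = default;
  virtual int accept(BaseVisitor& v) = 0;
};

// First dispatch: the virtual accept() selects the concrete visitable.
// Second dispatch: *this has the concrete static type, which picks the
// overload, and the virtual visit() then picks the concrete visitor.
class Visitable1 : public BaseVisitable {
 public:
  int accept(BaseVisitor& v) override { return v.visit(*this); }
};

class Visitable2 : public BaseVisitable {
 public:
  int accept(BaseVisitor& v) override { return v.visit(*this); }
};

class Visitor1 : public BaseVisitor {
 public:
  int visit(Visitable1&) override { trace().push_back("Visitor1::visit(Visitable1&)"); return 1; }
  int visit(Visitable2&) override { trace().push_back("Visitor1::visit(Visitable2&)"); return 1; }
};

class Visitor2 : public BaseVisitor {
 public:
  using BaseVisitor::visit;  // keep the inherited Visitable1 overload visible
  int visit(Visitable2&) override { trace().push_back("Visitor2::visit(Visitable2&)"); return 2; }
};

// Applies both visitors to both visitables through base-class pointers only.
std::vector<std::string> run_reference_demo() {
  trace().clear();
  Visitable1 a;
  Visitable2 b;
  Visitor1 v1;
  Visitor2 v2;
  BaseVisitable* items[] = {&a, &b};
  for (BaseVisitable* item : items) {
    item->accept(v1);
    item->accept(v2);
  }
  return trace();
}
```

Calling run_reference_demo() and printing each entry gives the same four-step order as above: Visitor2 falls back to the BaseVisitor default for Visitable1 because it only overrides the Visitable2 overload.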
-
Now try to modify it to use smart pointers: wrong_shared_ptr_visitor.cpp. However, running it only produces:
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
This happens because only BaseVisitable inherits std::enable_shared_from_this<BaseVisitable>, so calling shared_from_this() from inside class Visitable1 or Visitable2 returns shared_ptr<BaseVisitable>. A downcast is needed to get shared_ptr<Visitable1> or shared_ptr<Visitable2>.
-
Therefore, a correct implementation of the visitor pattern using smart pointers looks like this: shared_ptr_visitor.cpp. Running it produces:
virtual int Visitor1::visit(std::shared_ptr<Visitable1>)
virtual int BaseVisitor::visit(std::shared_ptr<Visitable1>)
virtual int Visitor1::visit(std::shared_ptr<Visitable2>)
virtual int Visitor2::visit(std::shared_ptr<Visitable2>)
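The difference between the wrong and the correct smart-pointer version comes down to one cast. The following is a minimal sketch, not the contents of either file: NaiveVisitable is a hypothetical class that forwards shared_from_this() unchanged, while Visitable1 applies std::static_pointer_cast first. trace() and run_shared_ptr_demo() are illustrative helpers.

```cpp
#include <memory>
#include <string>
#include <vector>

class BaseVisitable;
class Visitable1;

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

class BaseVisitor {
 public:
  virtual ~BaseVisitor() = default;
  // Fallback overload: chosen when only the base pointer type is known.
  virtual int visit(std::shared_ptr<BaseVisitable>) {
    trace().push_back("BaseVisitor::visit(shared_ptr<BaseVisitable>)");
    return 0;
  }
  virtual int visit(std::shared_ptr<Visitable1>) {
    trace().push_back("BaseVisitor::visit(shared_ptr<Visitable1>)");
    return 0;
  }
};

class BaseVisitable : public std::enable_shared_from_this<BaseVisitable> {
 public:
  virtual ~BaseVisitable() = default;
  virtual int accept(BaseVisitor* v) = 0;
};

class Visitable1 : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    // The downcast restores the static type; it is safe here because
    // *this really is a Visitable1.
    return v->visit(std::static_pointer_cast<Visitable1>(shared_from_this()));
  }
};

// Hypothetical class showing the mistake: shared_from_this() comes from
// enable_shared_from_this<BaseVisitable>, so it yields
// shared_ptr<BaseVisitable> and selects the fallback overload.
class NaiveVisitable : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override { return v->visit(shared_from_this()); }
};

class Visitor1 : public BaseVisitor {
 public:
  using BaseVisitor::visit;  // keep the other overloads visible
  int visit(std::shared_ptr<Visitable1>) override {
    trace().push_back("Visitor1::visit(shared_ptr<Visitable1>)");
    return 1;
  }
};

std::vector<std::string> run_shared_ptr_demo() {
  trace().clear();
  Visitor1 v;
  std::shared_ptr<BaseVisitable> naive = std::make_shared<NaiveVisitable>();
  std::shared_ptr<BaseVisitable> fixed = std::make_shared<Visitable1>();
  naive->accept(&v);  // lands in the BaseVisitable fallback
  fixed->accept(&v);  // lands in Visitor1's Visitable1 overload
  return trace();
}
```

std::static_pointer_cast is used here on the assumption that the dynamic type is already known inside accept(); std::dynamic_pointer_cast would add a runtime check at some extra cost.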
-
Here is an example of using the visitor pattern developed so far in a real program: example_usage_visitor.cpp. Running it produces:
virtual int RemovePass::visit(std::shared_ptr<BatchOp>)
virtual int Pass::visit(std::shared_ptr<BatchOp>)
virtual int RemovePass::visit(std::shared_ptr<ShuffleOp>)
virtual int PrintPass::visit(std::shared_ptr<ShuffleOp>)
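As a rough illustration of how such a program could be organized (a sketch under assumptions, not the contents of example_usage_visitor.cpp), a list of passes can be run over a list of operators. The base class name Op, trace(), and run_passes_demo() are hypothetical; Pass, PrintPass, RemovePass, BatchOp, and ShuffleOp are taken from the output above.

```cpp
#include <memory>
#include <string>
#include <vector>

class BatchOp;
class ShuffleOp;

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

class Pass {
 public:
  virtual ~Pass() = default;
  // Defaults for operators a concrete pass does not care about.
  virtual int visit(std::shared_ptr<BatchOp>) { trace().push_back("Pass::visit(BatchOp)"); return 0; }
  virtual int visit(std::shared_ptr<ShuffleOp>) { trace().push_back("Pass::visit(ShuffleOp)"); return 0; }
};

// Op is an assumed name for the common operator base class.
class Op : public std::enable_shared_from_this<Op> {
 public:
  virtual ~Op() = default;
  virtual int accept(Pass* p) = 0;
};

class BatchOp : public Op {
 public:
  int accept(Pass* p) override {
    return p->visit(std::static_pointer_cast<BatchOp>(shared_from_this()));
  }
};

class ShuffleOp : public Op {
 public:
  int accept(Pass* p) override {
    return p->visit(std::static_pointer_cast<ShuffleOp>(shared_from_this()));
  }
};

class RemovePass : public Pass {
 public:
  int visit(std::shared_ptr<BatchOp>) override { trace().push_back("RemovePass::visit(BatchOp)"); return 1; }
  int visit(std::shared_ptr<ShuffleOp>) override { trace().push_back("RemovePass::visit(ShuffleOp)"); return 1; }
};

class PrintPass : public Pass {
 public:
  using Pass::visit;  // BatchOp keeps the default from Pass
  int visit(std::shared_ptr<ShuffleOp>) override { trace().push_back("PrintPass::visit(ShuffleOp)"); return 2; }
};

// Runs every pass over every operator, in order.
std::vector<std::string> run_passes_demo() {
  trace().clear();
  std::vector<std::shared_ptr<Op>> ops = {std::make_shared<BatchOp>(),
                                          std::make_shared<ShuffleOp>()};
  RemovePass remove_pass;
  PrintPass print_pass;
  std::vector<Pass*> passes = {&remove_pass, &print_pass};
  for (const auto& op : ops) {
    for (Pass* pass : passes) {
      op->accept(pass);
    }
  }
  return trace();
}
```

Adding a new pass in this layout means writing one more Pass subclass; the operator classes stay untouched.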
A deep learning framework might need to deal with various operators, such as BatchOp and ShuffleOp here. We can use this visitor pattern to create different node/tree passes (PrintPass and RemovePass here), with each pass handling certain operators and performing a certain action on them.
-
Now, there is a lot of duplicate code. How can we clean it up a bit? To clean up the Visitor, we may use the CRTP pattern, and the code becomes this: wrong_visitor_clean.cpp. Running it produces:
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
virtual int BaseVisitor::visit(std::shared_ptr<BaseVisitable>)
The problem is that in accept(BaseVisitor *v), the parameter v cannot be automatically cast from class BaseVisitor to class Visitor<Visitable1> or Visitor<Visitable2>.
-
Therefore, a correct cleaned-up visitor implementation looks like this: visitor_clean.cpp. Running it produces:
virtual int Visitor1::visit(std::shared_ptr<Visitable1>)
virtual int Visitor1::visit(std::shared_ptr<Visitable2>)
virtual int Visitor2::visit(std::shared_ptr<Visitable2>)
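One way to obtain that three-line output is sketched below. It is an assumption about the approach, not necessarily what visitor_clean.cpp does: a templated Visitor<T> interface per visitable type, plus an explicit dynamic_cast in accept(), since BaseVisitor and Visitor<T> are unrelated types and no implicit conversion exists between them. trace() and run_clean_visitor_demo() are hypothetical helpers.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

// Non-template root so that accept() can take any visitor.
class BaseVisitor {
 public:
  virtual ~BaseVisitor() = default;
};

// One small interface per visitable type; a concrete visitor inherits
// only the ones it wants to handle.
template <typename T>
class Visitor {
 public:
  virtual ~Visitor() = default;
  virtual int visit(std::shared_ptr<T> node) = 0;
};

class BaseVisitable : public std::enable_shared_from_this<BaseVisitable> {
 public:
  virtual ~BaseVisitable() = default;
  virtual int accept(BaseVisitor* v) = 0;
};

class Visitable1 : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    // Cross-cast: succeeds only if the concrete visitor also derives
    // from Visitor<Visitable1>.
    auto* typed = dynamic_cast<Visitor<Visitable1>*>(v);
    if (typed == nullptr) return 0;  // this visitor skips Visitable1
    return typed->visit(std::static_pointer_cast<Visitable1>(shared_from_this()));
  }
};

class Visitable2 : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    auto* typed = dynamic_cast<Visitor<Visitable2>*>(v);
    if (typed == nullptr) return 0;  // this visitor skips Visitable2
    return typed->visit(std::static_pointer_cast<Visitable2>(shared_from_this()));
  }
};

class Visitor1 : public BaseVisitor, public Visitor<Visitable1>, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable1>) override { trace().push_back("Visitor1::visit(Visitable1)"); return 1; }
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor1::visit(Visitable2)"); return 1; }
};

class Visitor2 : public BaseVisitor, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor2::visit(Visitable2)"); return 2; }
};

std::vector<std::string> run_clean_visitor_demo() {
  trace().clear();
  std::vector<std::shared_ptr<BaseVisitable>> items = {std::make_shared<Visitable1>(),
                                                       std::make_shared<Visitable2>()};
  Visitor1 v1;
  Visitor2 v2;
  for (const auto& item : items) {
    item->accept(&v1);
    item->accept(&v2);
  }
  return trace();
}
```

In this sketch Visitor2 does not derive from Visitor<Visitable1>, so the cast fails for that pair and nothing is recorded, which is consistent with the output having three lines rather than four.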
-
There is still a lot of duplicate code for Visitable. Use a similar CRTP approach to also clean up the Visitable code: visitor_visitable_clean.cpp. Running it produces:
virtual int Visitor1::visit(std::shared_ptr<Visitable1>)
virtual int Visitor1::visit(std::shared_ptr<Visitable2>)
virtual int Visitor2::visit(std::shared_ptr<Visitable2>)
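A CRTP layer for the visitables might look like the sketch below, which is an illustration under assumptions rather than the contents of visitor_visitable_clean.cpp. The intermediate template Visitable<Derived> implements accept() once, and each concrete visitable names itself as the template argument. trace() and run_crtp_visitable_demo() are hypothetical helpers.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

class BaseVisitor {
 public:
  virtual ~BaseVisitor() = default;
};

template <typename T>
class Visitor {
 public:
  virtual ~Visitor() = default;
  virtual int visit(std::shared_ptr<T> node) = 0;
};

class BaseVisitable : public std::enable_shared_from_this<BaseVisitable> {
 public:
  virtual ~BaseVisitable() = default;
  virtual int accept(BaseVisitor* v) = 0;
};

// CRTP layer: Derived is the concrete class itself, so accept() is
// written once and still knows the exact type to cast to.
template <typename Derived>
class Visitable : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    auto* typed = dynamic_cast<Visitor<Derived>*>(v);
    if (typed == nullptr) return 0;  // visitor does not handle Derived
    return typed->visit(std::static_pointer_cast<Derived>(shared_from_this()));
  }
};

// The concrete visitables no longer repeat any dispatch code.
class Visitable1 : public Visitable<Visitable1> {};
class Visitable2 : public Visitable<Visitable2> {};

class Visitor1 : public BaseVisitor, public Visitor<Visitable1>, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable1>) override { trace().push_back("Visitor1::visit(Visitable1)"); return 1; }
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor1::visit(Visitable2)"); return 1; }
};

class Visitor2 : public BaseVisitor, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor2::visit(Visitable2)"); return 2; }
};

std::vector<std::string> run_crtp_visitable_demo() {
  trace().clear();
  std::vector<std::shared_ptr<BaseVisitable>> items = {std::make_shared<Visitable1>(),
                                                       std::make_shared<Visitable2>()};
  Visitor1 v1;
  Visitor2 v2;
  for (const auto& item : items) {
    item->accept(&v1);
    item->accept(&v2);
  }
  return trace();
}
```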
-
Another way to clean up the Visitable code without CRTP looks like this: visitable_clean_without_crtp.cpp. Running it produces:
virtual int Visitor1::visit(std::shared_ptr<Visitable1>)
virtual int Visitor1::visit(std::shared_ptr<Visitable2>)
virtual int Visitor2::visit(std::shared_ptr<Visitable2>)
The trick here is to create a template function shared_from in the Base class, which accepts a Derived class pointer and casts the pointer into std::shared_ptr<Derived>.
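The shared_from idea can be sketched as follows. The exact signature in visitable_clean_without_crtp.cpp is not shown here, so the one below is an assumption: a protected member template on BaseVisitable that deduces Derived from the pointer passed in (typically this). trace() and run_shared_from_demo() are hypothetical helpers.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical helper: records which overload ran, in call order.
inline std::vector<std::string>& trace() {
  static std::vector<std::string> calls;
  return calls;
}

class BaseVisitor {
 public:
  virtual ~BaseVisitor() = default;
};

template <typename T>
class Visitor {
 public:
  virtual ~Visitor() = default;
  virtual int visit(std::shared_ptr<T> node) = 0;
};

class BaseVisitable : public std::enable_shared_from_this<BaseVisitable> {
 public:
  virtual ~BaseVisitable() = default;
  virtual int accept(BaseVisitor* v) = 0;

 protected:
  // Derived is deduced from the argument, so callers just write
  // shared_from(this) and get a std::shared_ptr<Derived> back.
  template <typename Derived>
  std::shared_ptr<Derived> shared_from(Derived* self) {
    return std::static_pointer_cast<Derived>(self->shared_from_this());
  }
};

class Visitable1 : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    auto* typed = dynamic_cast<Visitor<Visitable1>*>(v);
    return typed != nullptr ? typed->visit(shared_from(this)) : 0;
  }
};

class Visitable2 : public BaseVisitable {
 public:
  int accept(BaseVisitor* v) override {
    auto* typed = dynamic_cast<Visitor<Visitable2>*>(v);
    return typed != nullptr ? typed->visit(shared_from(this)) : 0;
  }
};

class Visitor1 : public BaseVisitor, public Visitor<Visitable1>, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable1>) override { trace().push_back("Visitor1::visit(Visitable1)"); return 1; }
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor1::visit(Visitable2)"); return 1; }
};

class Visitor2 : public BaseVisitor, public Visitor<Visitable2> {
 public:
  int visit(std::shared_ptr<Visitable2>) override { trace().push_back("Visitor2::visit(Visitable2)"); return 2; }
};

std::vector<std::string> run_shared_from_demo() {
  trace().clear();
  std::vector<std::shared_ptr<BaseVisitable>> items = {std::make_shared<Visitable1>(),
                                                       std::make_shared<Visitable2>()};
  Visitor1 v1;
  Visitor2 v2;
  for (const auto& item : items) {
    item->accept(&v1);
    item->accept(&v2);
  }
  return trace();
}
```

Compared with a CRTP layer, this variant keeps a plain inheritance hierarchy at the price of a short accept() in every concrete visitable.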