Implementing iterator blocks in C++ (part 3: LINQ)

Now that we know how to implement coroutines with fibers, we can use them to port a few familiar C# constructs, like iterator blocks and LINQ operators, to the unmanaged world of C++. The result will be in the form of a small template class library. A first code drop of the cppLinq library is available in this source repository on Google Code.

The cppLinq library compiles with the Developer Preview of Visual Studio 2011; I have used a few features from the latest C++11 standard that were not well implemented in VS2010. (For example, VS2010 had a few problems with nested lambda expressions.) I am also using the new C++ unit test framework to implement the unit tests.

Small disclaimer: this is still a work in progress; I have not yet implemented all the LINQ operators, and the code could certainly use some refactoring and cleaning up. Also, using fibers in a Windows application is a bit dangerous and should be done carefully. In my opinion, this code is mostly useful as a reference for how C# iterators and LINQ actually work. I doubt I will ever use this code in a real program, but I certainly know better how to use these constructs in C# now!

Iterator blocks

We have seen that in C# an iterator is a function that returns an IEnumerable<T> or an IEnumerator<T>. When the C# compiler compiles such a function, it generates a hidden iterator class that implements the IEnumerator<T> and IEnumerable<T> interfaces and generates code for a state machine that produces the sequence of values “yielded” by the iterator. (See part 1, or, even better, this article for a detailed description.)
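To make the state-machine idea a bit more concrete, here is a hand-written sketch of the kind of class a compiler could generate for a small Fibonacci iterator, transliterated to portable C++. The name FibonacciStateMachine, the helper drain and the state numbering are my own assumptions for illustration; the real C# compiler output differs in many details.

```cpp
#include <cassert>
#include <vector>

// Hypothetical hand-written equivalent of a compiler-generated iterator.
// The local variables of the iterator body become data members, and the
// position inside the body is recorded in _state between calls.
class FibonacciStateMachine
{
public:
    explicit FibonacciStateMachine(int limit)
        : _state(0), _limit(limit), _i(0), _a(0), _b(1), _current(0) {
        assert(limit >= 2);   // the first two values are always produced
    }

    // Each call resumes at the point recorded in _state.
    bool MoveNext() {
        switch (_state) {
        case 0:               // start of the body: "yield return a"
            _current = _a;
            _state = 1;
            return true;
        case 1:               // resumed after the first yield: "yield return b"
            _current = _b;
            _i = 2;
            _state = 2;
            return true;
        case 2:               // body of the for loop
            if (_i < _limit) {
                long tmp = _a + _b;
                _a = _b;
                _b = tmp;
                _current = _b;
                ++_i;
                return true;
            }
            _state = -1;      // fell off the end of the body
            return false;
        default:              // sequence already finished
            return false;
        }
    }

    long get_Current() const { return _current; }

private:
    int _state;
    int _limit;
    int _i;
    long _a;
    long _b;
    long _current;
};

// Collects every value produced by the state machine.
std::vector<long> drain(FibonacciStateMachine& e) {
    std::vector<long> values;
    while (e.MoveNext()) {
        values.push_back(e.get_Current());
    }
    return values;
}
```

Calling drain on a FibonacciStateMachine built with a limit of 10 collects the first ten Fibonacci numbers. The fiber-based version replaces all of this bookkeeping with an ordinary function body that is suspended and resumed.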

We can base the C++ implementation of iterator blocks on the equivalent autogenerated C# code, using fibers in place of the state machine. In C++, we can define IEnumerator and IEnumerable as follows:

template <typename T>
class IEnumerator
{
public:
    virtual void Reset() = 0;
    virtual bool MoveNext() = 0;
    virtual T& get_Current() = 0;
};

template <typename T>
class IEnumerable
{
public:
    virtual std::shared_ptr<IEnumerator<T>> GetEnumerator() = 0;
};

Collecting our garbage

Note that GetEnumerator returns a std::shared_ptr. Managing the lifespan of iterator blocks turns out to be an interesting problem: iterators can be generated inside other iterators and can be composed in a data pipeline, so it is not easy to manually dispose of them. The code of iterator blocks is naturally more elegant when written in languages that support garbage collection. We don’t have GC in C++, but we can use shared pointers; as long as we guarantee that we always use shared pointers to manage the lifetime of our IEnumerables and IEnumerators (and that there are no circular references between them), we can be sure that they will be automatically deleted when the reference count associated with the shared pointer goes to zero.
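As a small illustration of this lifetime rule, the sketch below chains three objects through shared pointers and counts how many are alive. Stage, liveCount and makePipeline are hypothetical names invented for this example; they are not part of cppLinq.

```cpp
#include <cassert>
#include <memory>

// Hypothetical pipeline stage: each stage keeps its upstream source alive.
struct Stage
{
    static int liveCount;             // number of Stage objects currently alive
    std::shared_ptr<Stage> source;    // upstream stage (null for the first one)

    explicit Stage(std::shared_ptr<Stage> src) : source(src) { ++liveCount; }
    ~Stage() { assert(liveCount > 0); --liveCount; }
};
int Stage::liveCount = 0;

// Builds a three-stage chain and hands back only the last stage.
// The first two stages stay alive because each stage owns its source.
std::shared_ptr<Stage> makePipeline()
{
    std::shared_ptr<Stage> first  = std::make_shared<Stage>(std::shared_ptr<Stage>());
    std::shared_ptr<Stage> second = std::make_shared<Stage>(first);
    return std::make_shared<Stage>(second);
}
```

Holding the pointer returned by makePipeline keeps all three stages alive; releasing it destroys the whole chain, last stage first. If two stages pointed at each other, neither count would ever reach zero, which is why circular references must be avoided (or broken with std::weak_ptr).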

By the way, I think that the standard implementation of shared_ptr is very cool. It also provides ways to generate a shared_ptr from this and, in case of multiple inheritance, to convert a shared_ptr<Base1> into a shared_ptr<Base2>.
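Both features can be tried in a few lines. In this sketch Base1, Base2, Derived, self and toBase2 are placeholder names chosen for the example; only std::enable_shared_from_this and std::dynamic_pointer_cast come from the standard library.

```cpp
#include <cassert>
#include <memory>

// Base1 can hand out a shared_ptr to itself.
struct Base1 : std::enable_shared_from_this<Base1>
{
    virtual ~Base1() {}
    std::shared_ptr<Base1> self() { return shared_from_this(); }
};

// An unrelated second base class.
struct Base2
{
    virtual ~Base2() {}
};

// Derived implements both, like IteratorBlock implements two interfaces.
struct Derived : Base1, Base2 {};

// Cross-cast from one base to the other; the result shares ownership
// with the argument, so the reference count goes up by one.
std::shared_ptr<Base2> toBase2(const std::shared_ptr<Base1>& p)
{
    assert(p);
    return std::dynamic_pointer_cast<Base2>(p);
}
```

Note that shared_from_this only works on an object that is already owned by a shared_ptr; calling it on a stack object or on a raw new'ed pointer is an error.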

The IteratorBlock class

Going on with the implementation, we can write a template class IteratorBlock<T> that implements an iterator that produces a sequence of T objects through the IEnumerator and IEnumerable interfaces:

template <typename TSource>
class IteratorBlock : public IEnumerable<TSource>,
                      public IEnumerator<TSource>,
                      public Fiber
{
public:
    // IEnumerable
    virtual std::shared_ptr<IEnumerator<TSource>> GetEnumerator() {
        if (::GetCurrentThreadId() == _threadId && ! _enumeratorCreated) {
            _enumeratorCreated = true;
            return std::dynamic_pointer_cast<IEnumerator<TSource>>(shared_from_this());
        }

        std::shared_ptr<IteratorBlock<TSource>> cloned = clone();
        return cloned->GetEnumerator();
    }

    // IEnumerator
    virtual void Reset() {
        throw InvalidOperationException();
    }

    virtual bool MoveNext() {
        return resume();
    }

    virtual TSource& get_Current() {
        return _current;
    }

    void yieldReturn(TSource returnValue) {
        _current = returnValue;
        yield(true);
    }

    void yieldBreak() {
        yield(false);
    }

protected:
    IteratorBlock() :
        _current(),
        _enumeratorCreated(false),
        _threadId(::GetCurrentThreadId()) {
    }
    virtual ~IteratorBlock() {}

    virtual std::shared_ptr<IteratorBlock<TSource>> clone() const = 0;

private:
    TSource _current;
    bool _enumeratorCreated;
    DWORD _threadId;
};

Class IteratorBlock is designed to be used as the base class of an iterator class. Its implementation works as a coroutine, based on the Fiber class from which it inherits.

In order to implement an iterator block, we need to inherit from IteratorBlock<T> and provide an implementation for the abstract methods clone() and run().

As in C# iterators, the method Reset should never be called and throws an exception.

The method MoveNext simply restarts the coroutine, calling Fiber::resume. From here, the execution goes to the overridden method run. Inside run we can suspend the execution of the coroutine by calling yieldReturn and yieldBreak which behave, respectively, like the C# yield return and yield break keywords. A call to yieldReturn also specifies the current value to be returned by the iterator, stored in the data member _current.

The implementation of GetEnumerator deserves a few more words. An iterator block is an IEnumerable, from which we can ask for one or more instances of an IEnumerator. The first time we call GetEnumerator the function can return a pointer to the object itself, since it implements both interfaces, but following calls to GetEnumerator return a pointer to a copy of the iterator block object. The abstract method clone needs to be implemented to create this copy by cloning an instance of the iterator class.
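The "return myself the first time, a copy afterwards" rule can be seen in isolation with a much simpler class. The Counter class below is a hypothetical stand-in that counts from 0 to count-1 without any fiber, and it leaves out the thread-id check that IteratorBlock performs.

```cpp
#include <cassert>
#include <memory>

// Hypothetical sequence 0, 1, ..., count-1 that plays both roles:
// it is the "enumerable" and also its own first "enumerator".
class Counter : public std::enable_shared_from_this<Counter>
{
public:
    explicit Counter(int count)
        : _count(count), _next(0), _current(0), _enumeratorCreated(false) {
        assert(count >= 0);
    }

    std::shared_ptr<Counter> GetEnumerator() {
        if (!_enumeratorCreated) {
            // First request: reuse this object as the enumerator.
            _enumeratorCreated = true;
            return shared_from_this();
        }
        // Later requests: hand out an independent copy that starts over.
        std::shared_ptr<Counter> cloned = std::make_shared<Counter>(_count);
        return cloned->GetEnumerator();
    }

    bool MoveNext() {
        if (_next >= _count) {
            return false;
        }
        _current = _next++;
        return true;
    }

    int get_Current() const { return _current; }

private:
    int _count;
    int _next;
    int _current;
    bool _enumeratorCreated;
};
```

Two enumerators obtained from the same Counter advance independently, which is what nested foreach loops over the same sequence rely on.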

Example: a Fibonacci generator

As an example, we can implement a generator for the Fibonacci sequence by writing a class like the following, putting all the logic in the overridden run function.

class FibonacciIterator : public IteratorBlock<long>
{
public:
    virtual void run() {
        long a = 0;
        long b = 1;
        yieldReturn(a);
        yieldReturn(b);
        for (int i = 2; i < 10; i++) {
            long tmp = a + b;
            a = b;
            b = tmp;
            yieldReturn(b);
        }
    }

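    virtual std::shared_ptr<IteratorBlock<long>> clone() const {
        // The iterator has no data members, so a fresh instance is a valid copy.
        return std::shared_ptr<IteratorBlock<long>>(new FibonacciIterator());
    }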
};

The foreach loop

Client code can interact with an iterator with code like this:

auto source = std::shared_ptr<IEnumerable<long>>(new FibonacciIterator());
auto e = source->GetEnumerator();
while (e->MoveNext())
{
    long value = e->get_Current();
    // use ‘value’
}

This is the equivalent of a foreach statement, and can be encapsulated in a foreach function, templatized on the type of a function that will be applied to each item in the sequence:

template <typename T, typename Func>
static void foreach(std::shared_ptr<IEnumerable<T>> enumerable, Func f)
{
    std::shared_ptr<IEnumerator<T>> e = enumerable->GetEnumerator();
    while (e->MoveNext()) {
        f(e->get_Current());
    }
}

This allows us to write code like the following:

std::vector<std::string> greek;
greek.push_back("alpha"); // initializer lists not supported in VS11
greek.push_back("beta");
greek.push_back("gamma");

std::shared_ptr<IEnumerable<std::string>> it(
    new StlEnumerable<std::vector<std::string>, std::string>(greek));

foreach<std::string>(it, [](std::string& s) -> void {
    std::cout << s << std::endl;
});

Here we assume we have an iterator class StlEnumerable<Container, T>, which generates a sequence by iterating over the items of an STL container of T’s. (You can find this class in the sources.)

The code simply creates an iterator from an STL vector, and then applies a lambda expression to each item in the vector, with a foreach loop. Nothing very useful, but we are now almost ready to extend this with composition and LINQ-like operators.

Closures

A class like the FibonacciIterator is a closure; it encapsulates some behavior (the function run) and also the context (the data members) on which the function operates.

Writing code like this (that is, a class that derives from IteratorBlock and overrides the method run) works well, but can be a little “verbose” since it forces us to define a different class for each different iterator. In C++11 we can instead implement our coroutine with a lambda expression, which will automatically implement a closure for us.

The following class, _IteratorBlock, inherits from the previous IteratorBlock class and adds the capability of passing a lambda expression to the constructor, which will be executed by the method run.

template <typename TSource>
class _IteratorBlock : public IteratorBlock<TSource>
{
protected:
    struct IF {
        virtual void run(IteratorBlock<TSource>* pThis) = 0;
    };

    template <typename Func>
    struct F : public IF
    {
        F(Func func) : _func(func) {
        }

        virtual void run(IteratorBlock<TSource>* pThis) {
            _func(pThis);
        }

        Func _func;
    };

public:
    template <typename _F>
    _IteratorBlock(_F f) :
        _f(std::shared_ptr<IF>(new F<_F>(f))) {
    }

    _IteratorBlock(const _IteratorBlock& rhs) :
        _f(rhs._f) {
    }

protected:
    virtual void run() {
        _f->run(this);
    }

    virtual std::shared_ptr<IteratorBlock<TSource>> clone() const {
        return std::shared_ptr<IteratorBlock<TSource>>(new _IteratorBlock<TSource>(*this));
    }

private:
    std::shared_ptr<IF> _f;
};

Using this specialized _IteratorBlock class, we can create an iterator simply by defining a lambda. (Interestingly, it is the compiler that creates a closure class for us, storing captured variables in data members, but we don’t need to worry about this implementation detail.)
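To see what "the compiler creates a closure class" means, we can write such a class by hand and compare it with a lambda. This is only a conceptual sketch: AddOffset, applyTwice and demo are hypothetical names, and the class a real compiler generates is unnamed and may be laid out differently.

```cpp
#include <cassert>

// Hand-written closure: roughly what a compiler generates for the
// lambda [offset](int x) { return x + offset; }
class AddOffset
{
public:
    explicit AddOffset(int offset) : _offset(offset) {}
    int operator()(int x) const { return x + _offset; }

private:
    int _offset;   // the captured variable becomes a data member
};

// Accepts any callable: a lambda, a functor class, a function pointer.
template <typename Func>
int applyTwice(Func f, int x)
{
    return f(f(x));
}

// Runs the same computation through a lambda and through the
// hand-written class, and checks that both give the same answer.
int demo()
{
    int offset = 5;
    auto lambda = [offset](int x) { return x + offset; };
    int viaLambda = applyTwice(lambda, 1);
    int viaClass = applyTwice(AddOffset(offset), 1);
    assert(viaLambda == viaClass);
    return viaLambda;
}
```

The template parameter Func plays the same role as the _F parameter of the _IteratorBlock constructor: the caller never has to name the closure type.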

So, we can more simply define our Fibonacci iterator as follows:

auto fn = [](IteratorBlock<long>* it) {
    long a = 0;
    long b = 1;
    it->yieldReturn(a);
    it->yieldReturn(b);
    for (int i = 2; i < 10; i++) {
        long tmp = a + b;
        a = b;
        b = tmp;
        it->yieldReturn(b);
    }
};
return std::shared_ptr<IEnumerable<long>>(new _IteratorBlock<long>(fn));

Reimplementing LINQ to Objects

We now finally have all the pieces ready to reimplement LINQ to Objects.

How to proceed? I could have tried to reimplement all the LINQ clauses/methods myself. I could have used Reflector to find out how they are actually implemented in System.Linq.dll. But since I am very lazy, I decided to just “reuse” the work of Jon Skeet, who wrote a long and interesting series of articles about reimplementing the whole of LINQ to Objects in C#, some time ago.

What is better, he also wrote an exhaustive test suite for his reimplementation. All I had to do was to convert his code and tests from C# to C++… :-)

Of course, C++ does not have extension methods, so I just added new methods to the IEnumerable<T> interface. For example, this is the code for the Where LINQ clause:

template <typename T>
class IEnumerable : public std::enable_shared_from_this<IEnumerable<T>>
{
public:
    virtual std::shared_ptr<IEnumerator<T>> GetEnumerator() = 0;

    // Linq operators
    …
    template <typename Predicate>
    std::shared_ptr<IEnumerable<T>> Where(Predicate predicate) {
        if (nullptr == predicate) {
            throw ArgumentNullException();
        }

        // deferred
        std::shared_ptr<IEnumerable<T>> source = shared_from_this();
        auto fn =  [source, predicate](IteratorBlock<T>* it) {
            foreach<T>(source, [it, predicate](T& item) {
                if (predicate(item)) {
                    it->yieldReturn(item);
                }
            });
        };
        return std::shared_ptr<IEnumerable<T>>(new _IteratorBlock<T>(fn));
    }
};

LINQ methods use deferred execution – until a client starts trying to fetch items from the output sequence, they won’t start fetching items from the input sequence. Consequently, a method like Where is divided into two parts. Argument validation can be executed immediately, and the actual iterative code gets executed later, “on demand”.

The Where method is templatized on the type of the predicate function passed as its argument. This allows us to pass as predicate a lambda expression, a std::function, or any “functor” class that implements operator(). The implementation of Where is simply a lambda expression that is passed to an instance of the _IteratorBlock class. The lambda keeps fetching items from its source enumerator until the sequence is over or until an item is found that satisfies the predicate. In the latter case, the item is yielded to the caller.

Again, LINQ operators are lazily evaluated: expressions are evaluated only when their value is effectively required.
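Deferred execution is easy to observe by counting how often the predicate runs. The sketch below does not use fibers or cppLinq at all: Lazy, fromVector and where are hypothetical helpers built on a plain pull-style closure, just to make the laziness visible in portable code.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical pull-based sequence: next() writes the next item into its
// argument and returns true, or returns false when the sequence is over.
template <typename T>
struct Lazy
{
    std::function<bool(T&)> next;
};

// Wraps a copy of a vector as a lazy sequence.
template <typename T>
Lazy<T> fromVector(const std::vector<T>& v)
{
    std::size_t index = 0;
    Lazy<T> result;
    result.next = [v, index](T& out) mutable -> bool {
        if (index >= v.size()) {
            return false;
        }
        out = v[index++];
        return true;
    };
    return result;
}

// Lazy filter: nothing is pulled from the source until next() is called.
template <typename T, typename Predicate>
Lazy<T> where(Lazy<T> source, Predicate predicate)
{
    assert(static_cast<bool>(source.next));   // validated eagerly
    Lazy<T> result;
    result.next = [source, predicate](T& out) mutable -> bool {
        while (source.next(out)) {
            if (predicate(out)) {
                return true;                  // "yield" this item
            }
        }
        return false;                         // source exhausted
    };
    return result;
}
```

Creating the filtered sequence only validates the arguments; the predicate runs once per source item, and only as far as the consumer actually pulls. The fiber-based Where in cppLinq gives the same observable behavior while letting the body be written as an ordinary loop.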

Data pipelines

What makes LINQ operators particularly useful is the ability to compose them into data pipelines, using the output of an iterator as the source of the next iterator in a pipe.

For example, the code in the following (quite contrived) example prints the squares of all even integers smaller than 10:

auto source = IEnumerable<int>::Range(0, 10);

auto it = source->Where([](int val) { return ((val % 2) == 0); })
                ->Select<double>([](int val) -> double { return (val * val); });

foreach<double>(it, [](double& val){
    printf("%.2f\n", val);
});

Unit Tests

Unit tests are written using the new “CppUnitTest” framework that ships with the preview of VS11. As said, the tests are based on (and by “based on” I mean “copied from” :-) ) the unit tests that Jon Skeet wrote for his “EduLinq” blog series.

For example, this is one of the unit tests for the Where method:

TEST_CLASS(WhereTest)
{
public:
    ...
    TEST_METHOD(WhereTest_SimpleFiltering2)
    {
        int v[] = { 1, 3, 4, 2, 8, 1 };
        std::shared_ptr<IEnumerable<int>> source(new Vector<int>(v, ARRAYSIZE(v)));
        std::function<bool(const int&)> predicate = [](int x) { return x < 4; };

        auto result = source->Where(predicate);

        int expected[] = { 1, 3, 2, 1 };
        Assert::IsTrue(result->SequenceEqual(expected, ARRAYSIZE(expected)));
    }
    ...
};

Finally…

This ends my quick excursus into the world of unmanaged iterator blocks. The sources of the small cppLinq library can be found here. A few final comments:

  • As I said, this is still a work in progress. There are a few LINQ methods still left to implement (mainly the OrderBy method, which is the most laborious to write), but I hope to be able to complete them soon.
  • Basing the implementation on Win32 fibers has many disadvantages, but also one advantage: in C#, iterators are implemented as state machines and it is not possible to yield from more than one stack frame deep. Here, we don’t have that limitation.
  • It is very interesting to play with some of the features of the new C++ standard. Used together, templates and lambdas are a very powerful tool and allow us to make the C++ code very similar to its C# equivalent. But I found that this kind of C++ code is a bit too complicated to write correctly, and when something is wrong the resulting error messages are not always easy to “decrypt”.
  • I learned a lot about the actual LINQ by writing this.

Implementing iterator blocks in C++ (part 2: Win32 Fibers)

Fibers were added to Windows NT to support cooperative multitasking. They can be thought of as lightweight threads that must be manually scheduled by the application.

When a fiber is created, it is passed a fiber-start function. The OS then assigns it a separate stack and sets up execution to begin at this fiber-start function. To schedule a fiber, you need to switch to it manually; while running, a fiber can suspend itself by yielding execution to another fiber, or back to the “calling” fiber. In other words, fibers are a perfect tool to implement coroutines.

These two articles, from Dr. Dobb’s and MSDN Magazine, explain how to implement coroutines using the Fiber API. The MSDN Magazine article also shows how to do this in .NET (it was written before the release of .NET 2.0, and so before iterators were available in C#; in practice, though, fibers don’t get along well with the CLR). Together with an old series of Raymond Chen’s blog posts, these articles have been the main source of inspiration for this small project. (In this article Duffy proposes another possible implementation of coroutines, based on threads, which is however quite inefficient.)

Interestingly, a Fiber-based implementation does not have the limitations of (state machine based) C# iterators; with fibers it is possible to yield the control from any function in a stack frame. There are other problems, though.

Fibers are like dynamite

The Win32 Fiber API is quite simple. The first thing to do is to convert the thread on which the fibers will run into a fiber, by calling ConvertThreadToFiber. After this, additional fibers are created using CreateFiber, passing as parameter the address of the function to execute, just like the threadProc for real threads. Then, a fiber can suspend itself and start the execution of another fiber by calling SwitchToFiber. Finally, when the application is done using fibers, it can convert the “main” fiber back to a normal thread with ConvertFiberToThread.

There are a few important caveats to consider:

  • It is difficult to write a library or framework that uses fibers, because the entire program must be designed to support them. For example, the function ConvertThreadToFiber must be called only once in a thread.
  • Since each fiber has its own stack, it also has its own SEH exception chain. This means that if a fiber throws an exception, only that fiber can catch it. The same is true for C++ exceptions. Exceptions cannot pass across fibers’ “boundaries”.
  • The default stack size is 1MB, so using many fibers can consume a lot of memory.
  • The code must be “fiber-safe”, but most code is designed to be just “thread-safe”. For example, using thread-local storage does not work with fibers, and in fact Windows Vista introduced an API for Fiber local storage. More importantly, the CRT was not completely fiber-safe in the past, and I am not sure it is now. There are also compiler options to set in Visual Studio, like /GT, which enables only Fiber-safe code optimizations.

In other words, as Raymond Chen put it, “Fibers are like dynamite. Mishandle them and your process explodes”. Therefore, the framework for iterators and LINQ-like operators that we’ll define should be used VERY CAREFULLY in real programs.
(And if you think things are bad in C++, consider that fibers are practically unusable with .NET managed code!)

Implementation details

Putting all these caveats aside, let’s see how coroutines can be implemented with fibers. This is the declaration of a Fiber class I created as a thin wrapper over the Win32 Fiber API.

class Fiber
{
public:
    Fiber();
    virtual ~Fiber();

    static void enableFibersInCurrentThread();
    static bool disableFibersInCurrentThread();

    void* main();
    bool resume();

protected:
    virtual void run() = 0;
    void yield(bool goOn);

private:
    static void WINAPI fiberProc(void* lpFiberParameter);

    PFIBER_START_ROUTINE _fiber;
    PFIBER_START_ROUTINE _previousFiber;

    enum FiberState {
        FiberCreated, FiberRunning, FiberStopPending, FiberStopped
    };
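    FiberState _state;

    // Used to forward an exception from the fiber to the code that resumed it.
    std::exception _exception;
    bool _exceptionCaught;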
};

The Fiber class exposes the static functions enableFibersInCurrentThread and disableFibersInCurrentThread that wrap the initialization/termination functions.
The constructor creates a new fiber object, specifying Fiber::fiberProc as the function to execute.

Fiber::Fiber() :
    _state(FiberCreated),
    _previousFiber(nullptr),
    _exception(),
    _exceptionCaught(false)
{
    _fiber = (PFIBER_START_ROUTINE)::CreateFiber(256 * 1024, fiberProc, this);
}

As said, fibers use cooperative multitasking: the execution of a fiber must be explicitly scheduled by the application by calling SwitchToFiber. This is encapsulated by the method resume below:

bool Fiber::resume()
{
    if (nullptr == _fiber || _state == FiberStopped) {
        return false;
    }

    _previousFiber = (PFIBER_START_ROUTINE)::GetCurrentFiber();
    assert(_previousFiber != _fiber);

    ::SwitchToFiber(_fiber);

    if (_exceptionCaught) {
        throw _exception;
    }

    return (FiberRunning == _state);
}

When the fiber is started with SwitchToFiber, it begins executing from the fiberProc method, which calls main:

void CALLBACK Fiber::fiberProc(void* pObj)
{
    Fiber* pFiber = (Fiber*)pObj;
    void* previousFiber = pFiber->main();
    ::SwitchToFiber(previousFiber);
}

What main does is simply to call the abstract function run, which any class derived from Fiber needs to implement.
Inside run we can at some point yield the execution to a different Fiber object, calling its resume method, so effectively creating a “stack” of fibers nested into each other. Or, we can yield the execution to the previously running fiber (the one that launched the current one), calling the method yield (equivalent to yield return and yield break in C#):

void Fiber::yield(bool goOn)
{
    if (! goOn) {
        _state = FiberStopped; // yield break
    }
    ::SwitchToFiber(_previousFiber);
}

Logically this works like the sequence of nested function calls in a thread stack, but of course there is no physical stack here, and we can return back to the caller only by storing a pointer to the previous fiber and explicitly yielding control to it.

Since exceptions cannot travel outside a fiber, the call to run is wrapped in a try-catch clause which tries to catch any kind of exceptions (even Win32 exceptions). If caught, data about an exception is stored, the fiber is stopped and the fiberProc ends restarting the execution of the previous fiber, inside the resume function. Here exceptions can be re-created and re-thrown in the context of the previous fiber, so effectively forwarding them back, up on the “fiber stack”. Not sure this is a very elegant workaround, but I could not find a better solution.

void* Fiber::main()
{
    _state = FiberRunning;
    _exceptionCaught = false;

    try {
        run();
    }
    catch (StopFiberException&)
    {
        _state = FiberStopped;
    }
    catch (std::exception& e)
    {
        _exception = e;
        _exceptionCaught = true;
        _state = FiberStopped;
    }
    catch (...)
    {
        _exception = std::exception("win32 exception");
        _exceptionCaught = true;
        _state = FiberStopped;
    }

    _state = FiberStopped;
    return _previousFiber;
}
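One thing to keep in mind with the code above is that copying the caught object into a std::exception member keeps only its std::exception part, so the caller can no longer catch the original derived type. A possible alternative, which I have not used in the Fiber class, is the C++11 std::exception_ptr. The sketch below shows the idea in portable code, with no fibers involved: Boundary, runBody, rethrowIfFailed and demo are hypothetical names, where runBody stands in for Fiber::main and rethrowIfFailed for the check at the end of Fiber::resume.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the two sides of a fiber switch.
class Boundary
{
public:
    // Plays the role of Fiber::main: nothing may escape from here.
    template <typename Body>
    void runBody(Body body) {
        _pending = nullptr;
        try {
            body();
        }
        catch (...) {
            // Captures the exception with its full dynamic type.
            _pending = std::current_exception();
        }
    }

    bool failed() const { return static_cast<bool>(_pending); }

    // Plays the role of the check in Fiber::resume: rethrow on the
    // caller's side of the boundary.
    void rethrowIfFailed() {
        if (_pending) {
            std::exception_ptr e = _pending;
            _pending = nullptr;
            std::rethrow_exception(e);
        }
    }

private:
    std::exception_ptr _pending;
};

// Throws a derived exception inside the body and catches it, with its
// original type and message, on the other side.
std::string demo()
{
    Boundary b;
    b.runBody([] { throw std::out_of_range("index out of range"); });
    assert(b.failed());
    try {
        b.rethrowIfFailed();
    }
    catch (const std::out_of_range& e) {
        return e.what();
    }
    return "";
}
```

This only covers C++ exceptions; structured (Win32) exceptions caught by catch (...) depend on compiler settings and would still need separate treatment.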

Finally, when we are done with a fiber object, we can delete it, releasing all the resources it had allocated and the memory of its stack.

Fiber::~Fiber()
{
    if (_state == FiberRunning) {
        _previousFiber = (PFIBER_START_ROUTINE)::GetCurrentFiber();
        _state = FiberStopPending;
        ::SwitchToFiber(_fiber);
    }
    ::DeleteFiber(_fiber);
}

Next…

The Fiber class is the main building block in the implementation of coroutines. What is left to do is just to write classes that inherit from Fiber and implement the actual coroutine code in the virtual function run.
In the next posts we’ll see how to use the Fiber class to implement iterator blocks that work like in C#, implementing the IEnumerable and IEnumerator interfaces, and how to use this to reimplement many Linq operators.

If you are curious, you can look at the sources of this small project. A first code drop (still quite a preliminary version) is available in this source repository on Google Code.
Let me know your comments and suggestions, if you are so kind to have a look.