Generator functions in C++

In the previous post we had a look at the proposal of introducing resumable functions into the C++ standard to support writing asynchronous code modeled on the C# async/await pattern.

We saw that it is already possible to experiment with the future resumable and await keywords in Visual Studio, by installing the latest November 2013 CTP. But the concept of resumable functions is not limited to asynchrony; in this post we’ll see how it can be expanded to support generator functions.

Generator functions and lazy evaluation

In several languages, like C# and Python, generator functions provide the ability to lazily produce the values of a sequence only when they are needed. In C# a generator (or iterator) is a method that contains at least one yield statement and that returns an IEnumerable<T>.

For example, the following C# code produces the sequence of Fibonacci numbers (1, 1, 2, 3, 5, 8, 13, 21, …):

IEnumerable<int> Fibonacci()
{
    int a = 0;
    int b = 1;
    while (true) {
        yield return b;
        int tmp = a + b;
        a = b;
        b = tmp;
    }
}

A generator acts in two phases. When it is called, it just sets up a resumable function, preparing for its execution, and returns some enumerator (in the case of C#, an IEnumerable<T>). But the actual execution is deferred to the moment when the values are actually enumerated and pulled from the sequence, for example with a foreach statement:

foreach (var num in Fibonacci())
{
    Console.WriteLine("{0}", num);
}

Note that the returned sequence is potentially infinite; its enumeration could go on indefinitely (if we ignore the integer overflows).

Of course there is nothing particularly special about doing the same thing in C++. While STL collections are usually eagerly evaluated (all their values are produced upfront), it is not difficult to write a collection that provides iterators that calculate their current value on the spot, on the basis of some state or heuristic.
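As a minimal sketch of such a hand-written lazy collection, the Fibonacci sequence could be modeled as follows. The names fib_range, fib_iterator and first_fibs are made up for this illustration, and the sequence is artificially bounded by a count so that end() has something to compare against:

```cpp
#include <cassert>
#include <vector>

// Hypothetical iterator (not from any library): it stores only the state
// needed to compute the next Fibonacci number when it is incremented.
class fib_iterator {
public:
    explicit fib_iterator(int remaining) : a_(0), b_(1), remaining_(remaining) {}

    int operator*() const { return b_; }

    fib_iterator& operator++() {
        int tmp = a_ + b_;   // the next value is computed on demand
        a_ = b_;
        b_ = tmp;
        --remaining_;
        return *this;
    }

    bool operator==(const fib_iterator& rhs) const { return remaining_ == rhs.remaining_; }
    bool operator!=(const fib_iterator& rhs) const { return !(*this == rhs); }

private:
    int a_;
    int b_;
    int remaining_;   // how many values are still to be produced
};

// Hypothetical collection that exposes the first `count` Fibonacci numbers.
class fib_range {
public:
    explicit fib_range(int count) : count_(count) { assert(count >= 0); }
    fib_iterator begin() const { return fib_iterator(count_); }
    fib_iterator end() const { return fib_iterator(0); }

private:
    int count_;
};

// Pulls the values one by one; nothing is computed before the loop runs.
std::vector<int> first_fibs(int count) {
    std::vector<int> out;
    for (int n : fib_range(count)) {
        out.push_back(n);
    }
    return out;
}
```

Note how the logic of the sequence is split between the constructor and operator++(); this is exactly the bookkeeping that a yield statement would hide.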

What gives a particular expressive power to generators is the ability to pause the execution each time a new value is generated, yielding control back to the caller, and then to resume the execution exactly from the point where it had suspended. A generator is therefore a special form of coroutine, limited in the sense that it may only yield back to its caller.

The yield statement hides all the complexity inherent in the suspension and resumption of the function; the developer can express the logic of the sequence plainly, without having to set up callbacks or continuations.

From resumable functions to generators (and beyond)

It would be nice to bring the expressive power of generators to our good old C++, and naturally there is already some work going on for this. In this proposal Gustafsson et al. explain how generator functions could be supported by the language as an extension of resumable functions, making it possible to write code like:

sequence<int> fibonacci() resumable
{
    int a = 0;
    int b = 1;
    while (true)
    {
        yield b;
        int tmp = a + b;
        a = b;
        b = tmp;
    }
}

Here, the proposal introduces two new concepts: the type sequence<T> and the yield keyword.

– sequence<T> is an STL-like collection that only supports iteration and only provides an input iterator.

– The yield statement suspends the execution of the function and returns one item from the sequence to the caller.

In C# terms, sequence<T> and its iterator are respectively the equivalent of an IEnumerable<T> and IEnumerator<T>. But while the C# generators are implemented with a state machine, in C++ the suspension and resumption would be implemented, as we’ll see, with stackful coroutines.

Once we had a lazily evaluated sequence<T> we could write client code to pull a sequence of values, which would be generated one at a time, and only when requested:

sequence<int> fibs = fibonacci();
for (auto it = fibs.begin(); it != fibs.end(); ++it)
{
    std::cout << *it << std::endl;
}

In C++11 we could also simplify the iteration with a range-based for loop:

sequence<int> fibs = fibonacci();
for (auto num : fibs)   // the loop variable is the element, not an iterator
{
    std::cout << num << std::endl;
}

More interestingly, we could define other resumable functions that manipulate the elements of a sequence, lazily producing another sequence. This example, taken from Gustafsson’s proposal, shows a lazy version of std::transform():

template<typename Iter>
sequence<int> lazy_tform(Iter beg, Iter end, std::function<int(int)> func) resumable
{
    for (auto iter = beg; iter != end; ++iter)
    {
        yield func(*iter);
    }
}

Moving further with this idea, we could pull another page out of the C# playbook and enrich the sequence class with a whole set of composable, deferred query operators, a la LINQ:

template <typename T>
class sequence
{
public:
    template <typename Predicate> bool all(Predicate predicate);
    [...]
    static sequence<int> range(int from, int to);
    template <typename TResult> sequence<TResult> select(std::function<TResult(T)> selector);
    sequence<T> take(int count);
    sequence<T> where(std::function<bool(T)> predicate);
};

Lazy sequences

Certainly, resumable generators would be a very interesting addition to the standard. But how would they work? We saw that the Visual Studio CTP comes with a first implementation of resumable functions built over the PPL task library, but in this case the CTP is of little help, since it does not support generator functions yet. Maybe they will be part of a future release… but why wait? We can implement them ourselves! 🙂

In the rest of this post I’ll describe a possible simple implementation of C++ lazy generators.

Let’s begin with the lazy sequence<T> class. This is an STL-like collection which only needs to support input iterators, with a begin() and an end() method.

Every instance of this class must somehow be initialized with a functor that represents the generator function that will generate the values of the sequence. We’ll see later what can be a good prototype for it.

As we said, the evaluation of this function must be deferred to the moment when the values are retrieved, one by one, via the iterator. All the logic for executing, suspending and resuming the generator will actually be implemented by the iterator class, which therefore needs to have a reference to the same functor.

So, our first cut at the sequence class could be something like this:

template<typename T>
class sequence_iterator
{
    // TO DO
};
template<typename T>
class sequence
{
public:
    typedef sequence_iterator<T> iterator;
    typedef ??? functor;

    sequence(functor func) : _func(func) { }
    iterator begin() {
        return iterator(_func);
    }
    iterator end() {
        return iterator();
    }

private:
    functor _func;
};

Step by step

The sequence<T> class should not do much more than create iterators. The interesting code is all in the sequence iterator, which is the object that has the ability to actually generate the values.

Let’s go back to our Fibonacci generator and write some code that iterates through it:

sequence<int> fibonacci() resumable
{
    int a = 0;
    int b = 1;
    while (true)
    {
        yield b;
        int tmp = a + b;
        a = b;
        b = tmp;
    }
}

auto fibs = fibonacci();
for (auto num : fibs)   // the loop variable is the element, not an iterator
{
    std::cout << num << std::endl;
}

How should this code really work? Let’s follow its execution step by step.

  1. First, we call the function fibonacci(), which returns an object of type sequence<int>. Note that at this point the execution of the function has not even started yet. We just need to return a sequence object somehow associated with the body of the generator, which will be executed later.
  2. The returned sequence is copied into the variable fibs. We need to define what it means to copy a sequence: should we allow copy operations? Should we enforce move semantics?
  3. Given the sequence fibs, we call the begin() method, which returns an iterator “pointing” to the first element of the sequence. The resumable function should start running the moment the iterator is created and execute until a first value is yielded (or until it completes, in case of empty sequences).
  4. When the end() method is called, the sequence returns an iterator that represents the fact that the generator has completed and there are no more values to enumerate.
  5. The operator == () should behave as expected, returning true if both iterators are at the same position of the same sequence, or both pointing at the end of the sequence.
  6. The operator *() will return the value generated by the last yield statement (i.e., the current value of the sequence).
  7. At each step of the iteration, when operator ++() is called, the execution of the generator function will be resumed, and will continue until either the next yield statement updates the current value or until the function returns.
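To make the order of these calls visible, here is a small sketch that records which operation a range-based for loop invokes at each step. The traced_range type and the trace_loop() function are hypothetical and exist only for this illustration; the sequence simply counts from 0 to n-1 instead of running a generator:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical tracing sequence: yields 0, 1, ..., n-1 and appends the name
// of every operation the loop performs to a log.
struct traced_range {
    int n;
    std::vector<std::string>* log;

    struct iterator {
        int pos;
        std::vector<std::string>* log;

        bool operator!=(const iterator& rhs) const {
            log->push_back("!=");
            return pos != rhs.pos;
        }
        int operator*() const {
            log->push_back("*");
            return pos;
        }
        iterator& operator++() {
            log->push_back("++");
            ++pos;
            return *this;
        }
    };

    iterator begin() const { log->push_back("begin"); return iterator{0, log}; }
    iterator end() const { log->push_back("end"); return iterator{n, log}; }
};

// Runs a range-based for loop over n elements and returns the trace.
std::vector<std::string> trace_loop(int n) {
    std::vector<std::string> log;
    traced_range seq{n, &log};
    int sum = 0;
    for (int value : seq) {
        sum += value;
    }
    assert(sum == n * (n - 1) / 2);
    return log;
}
```

For a one-element sequence the trace is begin, end, !=, *, ++, !=. In a real generator the interesting work would happen inside begin() and operator++(), which is where the coroutine is started and resumed.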

Putting it all together, we can begin to write some code for the sequence_iterator class:

template<typename T>
class sequence_iterator
{
public:
    typedef ??? functor;

    sequence_iterator(functor func) : _func(func) {
        // initializes the iterator from the generator functor and executes it
        // until it terminates or yields.
    }
    sequence_iterator() {
        // must represent the end of the sequence
    }
    bool operator == (const sequence_iterator& rhs) const {
        // true if the iterators are at the same position.
    }
    bool operator != (const sequence_iterator& rhs) const {
        return !(*this == rhs);
    }
    const T& operator * () const {
        return _currentVal;
    }
    sequence_iterator& operator ++ () {
        // resume execution
        return *this;
    }

private:
    functor _func;
    T _currentVal;
};

The behavior of the iterator is fairly straightforward, but there are a few interesting things to note. The first is that evidently a generator function does not do what it says: looking at the code of the fibonacci() function there is no statement that actually returns a sequence<T>; what the code does is simply to yield the sequence elements, one at a time.

So who creates the sequence<T> object? Clearly, the implementation of generators cannot be purely library-based. We can put in a library the code for the sequence<T> and for its iterators, and we can also put in a library the platform-dependent code that manages the suspension and resumption of generators. But it will be up to the compiler to generate the appropriate code that creates a sequence<T> object for a generator function. More on this later.

Also, we should note that there is no asynchrony or concurrency involved in this process. The function could resume in the same thread where it suspended.

Generators as coroutines

The next step is to implement the logic to seamlessly pause and resume a generator. A generator can be seen as an asymmetric coroutine, where the asymmetry lies in the fact that the control can be only yielded back to the caller, contrary to the case of symmetric coroutines that can yield control to any other coroutine at any time.

Unfortunately coroutines cannot be implemented in a platform-independent way. On Windows we can use Win32 Fibers (as I described in this very old post), while on POSIX systems we can use the makecontext()/swapcontext() API. There is also a very nice Boost library that we could leverage for this purpose.
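To give a feeling for the POSIX route, the following is a minimal sketch of an asymmetric coroutine built on getcontext(), makecontext() and swapcontext(). It is POSIX-only, and the ucontext_range class and the collect() helper are names invented for this illustration, not part of the implementation described in this post. The generator body just counts from `from` to `to`:

```cpp
#include <ucontext.h>
#include <cassert>
#include <vector>

// Hypothetical helper: produces the integers in [from, to) by running its
// loop on a separate stack and yielding each value back to the caller.
class ucontext_range {
public:
    ucontext_range(int from, int to)
        : from_(from), to_(to), current_(0), done_(false), stack_(64 * 1024)
    {
        assert(from <= to);
        getcontext(&callee_);
        callee_.uc_stack.ss_sp = stack_.data();
        callee_.uc_stack.ss_size = stack_.size();
        callee_.uc_link = &caller_;   // context resumed when entry() returns
        makecontext(&callee_, &ucontext_range::entry, 0);
        resume();                     // run up to the first yield (or the end)
    }

    // The saved contexts point into this object, so copying is disabled.
    ucontext_range(const ucontext_range&) = delete;
    ucontext_range& operator=(const ucontext_range&) = delete;

    // Switches to the generator until it yields again or finishes.
    void resume() {
        if (done_) return;
        active_ = this;
        swapcontext(&caller_, &callee_);
    }

    bool done() const { return done_; }
    int current() const { return current_; }

private:
    // Entry point of the coroutine; makecontext() cannot portably pass a
    // pointer, so the object is handed over through a static variable.
    static void entry() {
        ucontext_range* self = active_;
        for (int i = self->from_; i < self->to_; ++i) {
            self->current_ = i;
            swapcontext(&self->callee_, &self->caller_);   // "yield"
        }
        self->done_ = true;   // returning switches to uc_link (the caller)
    }

    int from_;
    int to_;
    int current_;
    bool done_;
    std::vector<char> stack_;
    ucontext_t caller_;
    ucontext_t callee_;
    static ucontext_range* active_;
};

ucontext_range* ucontext_range::active_ = nullptr;

// Drives the coroutine from the caller side and collects what it yields.
std::vector<int> collect(int from, int to) {
    std::vector<int> out;
    for (ucontext_range gen(from, to); !gen.done(); gen.resume()) {
        out.push_back(gen.current());
    }
    return out;
}
```

The shape is the same one we need for generators: the caller resumes, the coroutine runs until it yields, and control only ever goes back to the caller.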

But let’s ignore the problems of portability, for the moment, and assume that we have a reliable way to implement coroutines. How should we use them in an iterator? We can encapsulate the non-portable code in a class __resumable_func that exposes this interface:

template <typename TRet>
class __resumable_func
{
    typedef std::function<void(__resumable_func&)> TFunc;

public:
    __resumable_func(TFunc func);

    void yieldReturn(const TRet& value);
    void yieldBreak();
    void resume();

    const TRet& getCurrent() const;
    bool isEos() const;
};

The class is templatized on the type of the values produced by the generator and provides methods to yield one value (yieldReturn()), to retrieve the current value (i.e., the latest value yielded) and to resume the execution and move to the next value.

It should also provide methods to terminate the enumeration (yieldBreak()) and to tell if we have arrived at the end of the sequence (isEos()).

The function object passed to the constructor represents the generator function itself that we want to run. More precisely, it is the function that will be executed as a coroutine, and its prototype tells us that this function, in order to be able to suspend execution, needs a reference to the __resumable_func object that is running the coroutine itself.

In fact the compiler should transform the code of a generator into the (almost identical) code of a lambda that uses the __resumable_func object to yield control and emit a new value.

For example, going back again to our fibonacci() generator, we could expect the C++ compiler to transform the code we wrote:

sequence<int> fibonacci() resumable
{
    int a = 0;
    int b = 1;
    while (true)
    {
        yield b;
        int tmp = a + b;
        a = b;
        b = tmp;
    }
}

into this lambda expression:

auto __fibonacci_func([](__resumable_func<int>& resFn) {
    int a = 0;
    int b = 1;
    while (true)
    {
        resFn.yieldReturn(b);
        int tmp = a + b;
        a = b;
        b = tmp;
    }
});

where the yield statement has been transformed into a call to __resumable_func::yieldReturn().

Likewise, client code that invokes this function, like:

sequence<int> fibs = fibonacci();

should be transformed by the compiler into a call to the sequence constructor, passing this lambda as argument:

sequence<int> fibs(__fibonacci_func);

Sequence iterators

We can ignore the details of the implementation of __resumable_func<T> coroutines for the moment and, assuming that we have them working, we can now complete the implementation of the sequence_iterator class:

template <typename T>
class sequence_iterator
{
    std::unique_ptr<__resumable_func<T>> _resumableFunc;

    sequence_iterator() :
        _resumableFunc(nullptr)
    {
    }

    sequence_iterator(const std::function<void(__resumable_func<T>&)> func) :
        _resumableFunc(new __resumable_func<T>(func))
    {
    }

    sequence_iterator(const sequence_iterator& rhs) = delete;
    sequence_iterator& operator = (const sequence_iterator& rhs) = delete;
    sequence_iterator& operator = (sequence_iterator&& rhs) = delete;

public:
    sequence_iterator(sequence_iterator&& rhs) :
        _resumableFunc(std::move(rhs._resumableFunc))
    {
    }

    sequence_iterator& operator++()
    {
        _ASSERT(_resumableFunc != nullptr);
        _resumableFunc->resume();
        return *this;
    }

    bool operator==(const sequence_iterator& _Right) const
    {
        if (_resumableFunc == _Right._resumableFunc) {
            return true;
        }

        if (_resumableFunc == nullptr) {
            return _Right._resumableFunc->isEos();
        }

        if (_Right._resumableFunc == nullptr) {
            return _resumableFunc->isEos();
        }

        return (_resumableFunc->isEos() == _Right._resumableFunc->isEos());
    }

    bool operator!=(const sequence_iterator& _Right) const
    {
        return (!(*this == _Right));
    }

    const T& operator*() const
    {
        _ASSERT(_resumableFunc != nullptr);
        return (_resumableFunc->getCurrent());
    }
};

The logic here is very simple. Internally, a sequence_iterator contains a __resumable_func object, which runs the generator as a coroutine. The default constructor creates an iterator that points at the end of the sequence. Another constructor accepts as argument the generator function that we want to run and starts executing it in a coroutine; the function runs until it either yields a value or terminates, giving control back to the constructor. In this way we create an iterator that points at the beginning of the sequence.

If a value was yielded, we can call the dereference-operator to retrieve it from the __resumable_func object. If the function terminated, instead, the iterator will already point at the end of the sequence. The equality operator takes care of equating an iterator whose function has terminated to the end()-iterators created with the default constructor. Incrementing the iterator means resuming the execution of the coroutine, from the point it had suspended, giving it the opportunity to produce another value.

Note that, since the class owns the coroutine object, we disable copy constructors and assignment operators and only declare the move constructor, to pass the ownership of the coroutine.

Composable sequence operators

Almost there! We have completed our design, but there are still a few details to work out. The most interesting are related to the lifetime and copyability of sequence objects. What should happen with code like this?

sequence<int> fibs1 = fibonacci();
sequence<int> fibs2 = fibs1;
for (auto it1 : fibs1) {
    for (auto it2 : fibs2) {
        ...
    }
}

If we look at how we defined class sequence<T>, apparently there is no reason why we should prevent the copy of sequence objects. In fact, sequence<T> is an immutable class. Its only data member is the std::function object that wraps the functor we want to run.

However, even though we don’t modify this functor object, we do execute it. This object could have been constructed from a lambda expression that captured some variables, either by value or by reference. Since one of the captured variables could be a reference to the same sequence<T> object that created that iterator, we need to ensure that the sequence object will always outlive its functors, and allowing copy-semantics suddenly becomes complicated.

This brings us to LINQ and to the composability of sequences. Anyone who has worked with C# knows that what makes enumerable types truly powerful and elegant is the ability to apply chains of simple operators that transform the elements of a sequence into another sequence. LINQ to Objects is built on the concept of a data pipeline: we start with a data source which implements IEnumerable<T>, and we can compose together a number of query operators, defined as extension methods in the Enumerable class.

For example, this very, very useless query in C# generates the sequence of all square roots of odd integers between 0 and 10:

var result = Enumerable.Range(0, 10)
    .Where(n => n%2 == 1)
    .Select(n => Math.Sqrt(n));

Similarly, to make the C++ sequence<T> type really powerful we should make it composable and enrich it with a good range of LINQ-like operators to generate, filter, aggregate, group, sort and generally transform sequences.

These are just a few of the operators that we could define in the sequence<T> class:

template <typename T>
class sequence
{
public:
    [...]
    static sequence<int> range(int from, int to);
    template <typename TResult> sequence<TResult> select(std::function<TResult(T)> selector);
    sequence<T> where(std::function<bool(T)> predicate);
};

to finally be able to write the same (useless) query:

sequence<double> result = sequence<int>::range(0, 10)
    .where([](int n) { return n % 2 == 1; })
    .select<double>([](int n) { return sqrt(n); });   // TResult cannot be deduced from a lambda

Let’s try to implement select(), as an experiment. It is conceptually identical to the lazy_tform() function we saw before, but now defined in the sequence class. A very naïve implementation could be as follows:

// Projects each element of a sequence into a new form. (NOT WORKING!)
template <typename TResult>
sequence<TResult> select(std::function<TResult(T)> selector)
{
    auto func = [this, selector](__resumable_func<TResult>& rf) {
        for (T t : *this)
        {
            auto val = selector(t);
            rf.yieldReturn(val);
        }
    };
    return sequence<TResult>(func);
}

It should be now clear how it works: first we create a generator functor, in this case with a lambda expression, and then we return a new sequence constructed on this functor. The point is that the lambda needs to capture the “parent” sequence object to be able to iterate through the values of its sequence.

Unfortunately this code is very brittle. What happens when we compose more operators, using the result of one as the input of the next one in the chain? When we write:

sequence<double> result = sequence<int>::range(0, 10)
    .where([](int n) { return n % 2 == 1; })
    .select<double>([](int n) { return sqrt(n); });   // TResult cannot be deduced from a lambda

there are (at least) three temporary objects of type sequence<T> created here. Their lifetime is tied to that of the full expression, so they are destroyed as soon as the statement completes, long before the resulting sequence is enumerated.

[Figure: a chain of sequences]

The situation is like in the figure: the functor of each sequence in the chain is a lambda that has captured a pointer to the previous sequence object. The problem is in the deferred execution: nothing really happens until we enumerate the resulting sequence through its iterator, but as soon as we do so each sequence starts pulling values from its predecessor, which has already been deleted.

Temporary objects and deferred execution really do not get along nicely at all. On one hand, in order to compose sequences we have to deal with temporaries that can be captured in a closure and then deleted long before being used. On the other hand, the sequence iterators, and their underlying coroutines, should not be copied and can outlive the instance of the sequence that generated them.

We can enforce move semantics on the sequence<T> class, but then what do we capture in a generator like select() that acts on a sequence?

As often happens, a possible solution requires adding another level of indirection. We introduce a new class, sequence_impl<T>, which represents a particular application of a generator function closure:

template <typename T>
class sequence_impl
{
public:
    typedef std::function<void(__resumable_func<T>&)> functor;
    typedef sequence_iterator<T> iterator;   // used by begin()/end() and by sequence<T>

private:
    const functor _func;

    sequence_impl(const sequence_impl& rhs) = delete;
    sequence_impl(sequence_impl&& rhs) = delete;
    sequence_impl& operator = (const sequence_impl& rhs) = delete;
    sequence_impl& operator = (sequence_impl&& rhs) = delete;

public:
    sequence_impl(functor func) : _func(std::move(func)) {}

    sequence_iterator<T> begin() const
    {
        // return iterator for beginning of sequence
        return iterator(_func);
    }
    sequence_iterator<T> end() const
    {
        // return iterator for end of sequence
        return iterator();
    }
};

A sequence_impl<T> is neither copyable nor movable and only provides methods to iterate through it.

The sequence<T> class now keeps only a shared pointer to the unique instance of a sequence_impl<T> that represents that particular application of the generator function. Now we can support chained sequences by allowing move semantics on the sequence<T> class.

template <typename T>
class sequence
{
    std::shared_ptr<sequence_impl<T>> _impl;

    sequence(const sequence& rhs) = delete;
    sequence& operator = (const sequence& rhs) = delete;

public:
    typedef typename sequence_impl<T>::iterator iterator;
    typedef typename sequence_impl<T>::functor functor;

    sequence(functor func) :
        _impl(std::make_shared<sequence_impl<T>>(func))
    {
    }
    sequence(sequence&& rhs) :
        _impl(std::move(rhs._impl))
    {
    }
    sequence& operator = (sequence&& rhs) {
        _impl = std::move(rhs._impl);
        return *this;
    }

    iterator begin() const {
        return _impl->begin();
    }
    iterator end() const {
        return _impl->end();
    }
};

The diagram below illustrates the relationships between the classes involved in the implementation of lazy sequences:

[Figure: genFunc2, the classes involved in the implementation of lazy sequences]
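The reason a shared pointer solves the lifetime problem can be shown in isolation. In the following sketch, source_impl is a hypothetical stand-in for sequence_impl<T> and deferred_sum() plays the role of an operator like select(): the closure captures the shared pointer by value, so the data it needs stays alive even after the object that originally owned it is gone.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical stand-in for sequence_impl<T>: it just owns some values.
struct source_impl {
    std::vector<int> values;
};

// Builds a deferred computation. Capturing the shared_ptr by value keeps the
// source alive for as long as the returned closure exists.
std::function<int()> deferred_sum(std::shared_ptr<source_impl> impl) {
    assert(impl != nullptr);
    return [impl]() {
        int sum = 0;
        for (int v : impl->values) {
            sum += v;
        }
        return sum;
    };
}

// Creates the closure in an inner scope and runs it after the original
// owner of the source has been destroyed.
int run_after_owner_is_gone() {
    std::function<int()> later;
    {
        auto impl = std::make_shared<source_impl>();
        impl->values = {1, 2, 3, 4};
        later = deferred_sum(impl);
    }   // the local shared_ptr dies here; the closure still holds a reference
    return later();
}
```

Had the closure captured a raw pointer or a reference instead, the last call would read from a destroyed object, which is the situation we had with the naïve select().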

LINQ operators

OK, now we are really almost done. The only thing left to do, if we want, is to write a few sequence-manipulation operators, modeled on LINQ to Objects. I’ll list just a few, as examples:

// Determines whether all elements of a sequence satisfy a condition.
bool all(std::function<bool(T)> predicate)
{
    if (nullptr == predicate) {
        throw std::exception();
    }

    for (auto t : *_impl)
    {
        if (!predicate(t)) {
            return false;
        }
    }
    return true;
}

// Returns an empty sequence
static sequence<T> empty()
{
    auto fn = [](__resumable_func<T>& rf) {
        rf.yieldBreak();
    };
    return sequence<T>(fn);
}

// Generates a sequence of integral numbers within a specified range [from, to).
static sequence<int> range(int from, int to)
{
    if (to < from) {
        throw std::exception();
    }

    auto fn = [from, to](__resumable_func<int>& rf) {
        for (int i = from; i < to; i++) {
            rf.yieldReturn(i);
        }
    };
    return sequence<int>(fn);
}

// Projects each element of a sequence into a new form.
template <typename TResult>
sequence<TResult> select(std::function<TResult(T)> selector)
{
    if (nullptr == selector) {
        throw std::exception();
    }

    std::shared_ptr<sequence_impl<T>> impl = _impl;
    auto fn = [impl, selector](__resumable_func<TResult>& rf) {
        for (T t : *impl)
        {
            auto val = selector(t);
            rf.yieldReturn(val);
        }
    };
    return sequence<TResult>(fn);
}

// Returns a specified number of contiguous elements from the start of a sequence.
sequence<T> take(int count)
{
    if (count < 0) {
        throw std::exception();
    }

    std::shared_ptr<sequence_impl<T>> impl = _impl;
    auto fn = [impl, count](__resumable_func<T>& rf) {
        auto it = impl->begin();
        for (int i = 0; i < count && it != impl->end(); i++, ++it) {
            rf.yieldReturn(*it);
        }
    };
    return sequence<T>(fn);
}

// Filters a sequence of values based on a predicate.
sequence<T> where(std::function<bool(T)> predicate)
{
    if (nullptr == predicate) {
        throw std::exception();
    }

    std::shared_ptr<sequence_impl<T>> impl = _impl;
    auto fn = [impl, predicate](__resumable_func<T>& rf) {
        for (auto item : *impl)
        {
            if (predicate(item)) {
                rf.yieldReturn(item);
            }
        }
    };
    return sequence<T>(fn);
}

We could write many more, but I think these should convey the idea.
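For readers who want to experiment with operator composition without a coroutine library at hand, here is a rough, standard-C++ approximation. It deliberately swaps the technique: instead of suspending a generator, each stage is a "pull" callback that produces the next value on request. The pull_seq class, its to_vector() method and the halves_of_odds() function are all hypothetical names, and copies of a pull_seq share the same cursor, which is a simplification of this sketch.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical pull-based sequence: next(out) stores the next value in `out`
// and returns true, or returns false when the sequence is exhausted.
template <typename T>
class pull_seq {
public:
    typedef std::function<bool(T&)> next_fn;

    explicit pull_seq(next_fn next) : next_(std::move(next)) {}

    // Lazily generates the integers in [from, to).
    static pull_seq<int> range(int from, int to) {
        assert(from <= to);
        auto i = std::make_shared<int>(from);   // cursor shared by all copies
        return pull_seq<int>([i, to](int& out) -> bool {
            if (*i >= to) return false;
            out = (*i)++;
            return true;
        });
    }

    // Lazily keeps only the elements that satisfy the predicate.
    pull_seq<T> where(std::function<bool(T)> predicate) const {
        next_fn src = next_;
        return pull_seq<T>([src, predicate](T& out) -> bool {
            while (src(out)) {
                if (predicate(out)) return true;
            }
            return false;
        });
    }

    // Lazily projects each element; R must be given explicitly because it
    // cannot be deduced from a lambda argument.
    template <typename R>
    pull_seq<R> select(std::function<R(T)> selector) const {
        next_fn src = next_;
        return pull_seq<R>([src, selector](R& out) -> bool {
            T in = T();
            if (!src(in)) return false;
            out = selector(in);
            return true;
        });
    }

    // Forces the evaluation of the whole pipeline.
    std::vector<T> to_vector() const {
        std::vector<T> result;
        next_fn src = next_;
        T value = T();
        while (src(value)) {
            result.push_back(value);
        }
        return result;
    }

private:
    next_fn next_;
};

// Example pipeline: halves of the odd integers in [0, 10).
std::vector<double> halves_of_odds() {
    return pull_seq<int>::range(0, 10)
        .where([](int n) { return n % 2 == 1; })
        .select<double>([](int n) { return n / 2.0; })
        .to_vector();
}
```

Each stage captures its predecessor by value, so the temporaries created while chaining can safely disappear. What we lose with respect to generators is the ability to write each operator as a plain loop with a yield in the middle.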

Example: a prime numbers generator

As a final example, the following query lazily provides the sequence of prime numbers (smaller than INT_MAX), using a very simple, brute-force algorithm. It is definitely not the fastest generator of prime numbers, it’s maybe a little cryptic, but it’s undoubtedly quite compact!

sequence<int> primes(int max)
{
    return sequence<int>::range(2, max)
        .where([](int i) {
            return sequence<int>::range(2, (int)sqrt(i) + 1)   // divisors in [2, sqrt(i)]; with + 2 the number 2 would test itself
                .all([i](int j) { return (i % j) != 0; });
            });
}
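To check what the query is meant to compute, the same trial-division test can be written with plain loops. This is only an eager reference sketch with made-up names (is_prime_by_trial_division, primes_below); it tests every candidate divisor j with j * j <= i:

```cpp
#include <cassert>
#include <vector>

// Trial division: i is prime when no j in [2, floor(sqrt(i))] divides it.
bool is_prime_by_trial_division(int i) {
    if (i < 2) return false;
    for (int j = 2; j <= i / j; ++j) {   // same as j * j <= i, without overflow
        if (i % j == 0) return false;
    }
    return true;
}

// Eagerly collects the primes in [2, max).
std::vector<int> primes_below(int max) {
    assert(max >= 2);
    std::vector<int> out;
    for (int i = 2; i < max; ++i) {
        if (is_prime_by_trial_division(i)) {
            out.push_back(i);
        }
    }
    return out;
}
```

The lazy version produces the same values, but only as far as the caller actually enumerates them.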

Conclusion

In this article I rambled about generators in C++, describing a new sequence<T> type that models lazy enumerators and that could be implemented as an extension of resumable functions, as specified in N3858. I have described a possible implementation based on coroutines and introduced the possibility of extending the sequence class with a set of composable operators.

If you are curious and want to play with my sample implementation, you can find a copy of the sources here. Nothing too fancy, just the code that I showed in this post.

Appendix – Coroutines in Win32

Having completed my long ramble on the “platform independent” aspects of C++ generators, it’s time to go back to the point we left open: how to implement, on Windows, the coroutines that we encapsulated in the __resumable_func class?

We saw that the Visual Studio CTP comes with a first implementation of resumable functions, built over the PPL task library and using Win32 fibers as side-stacks. Even though the CTP does not support generator functions yet, my first idea was to just extend the <pplawait.h> library to implement them. However, the code there is specialized for resumable functions that suspend awaiting some task, and it turns out that we can reuse only part of that code because, even though we are still dealing with resumable functions, the logic of await and that of yield are quite different.

In the case of await, functions can be suspended (possibly multiple times) waiting for some other task to complete. This means switching to a fiber associated to the task after having set up a continuation that will be invoked after the task completes, to switch the control back to the current fiber. When the function terminates, the control goes back to the calling fiber, returning the single return value of the async resumable function.

In the case of yield, we never suspend to call external async methods. Instead, we can suspend multiple times going back to the calling fiber, each time by returning one of the values that compose the sequence. So, while the implementation of the await keyword needs to leverage the support of PPL tasks, the concept of generator functions does not imply any concurrency or multithreading and using the PPL is not necessary.

(Actually, there are ways to implement yield with await, but I could not find a simple way to use the new __await keyword without spawning new threads; maybe this would be possible with a custom PPL scheduler?)

So I chose to write the code for coroutines myself; the idea here is not very different from the one I described in a very old post (it looks like I keep rewriting the same post :-)) but now I can take advantage of the fiber-based code from the CTP’s <pplawait.h> library.

Win32 Fibers

Let’s delve into the details of the implementation. First of all, let me summarize once again the Win32 Fiber API.

Fibers were added to Windows NT to support cooperative multitasking. They can be thought of as lightweight threads that must be manually scheduled by the application. In other words, fibers are a perfect tool to implement coroutine sequencing.

When a fiber is created, with CreateFiber, it is passed a fiber-start function. The OS then assigns it a separate stack and sets up execution to begin at this fiber-start function. To schedule a fiber we need to “switch” to it manually with SwitchToFiber and once it is running, a fiber can then suspend itself only by explicitly yielding execution to another fiber, also by calling SwitchToFiber.

SwitchToFiber only works from one fiber to another, so the first thing to do is to convert the current thread into a fiber, with ConvertThreadToFiber. Finally, when we are done using fibers, we can convert the main fiber back to a normal thread with ConvertFiberToThread.

The __resumable_func class

We want to put all the logic to handle the suspension and resumption of a function in the __resumable_func<T> class, as described before.

In our case we don’t need symmetric coroutines; we just need the ability to return control to the calling fiber. So our class will contain a handle to the “caller” fiber and a handle to the fiber we want to run.

#include <functional>
#include <pplawait.h>

template <typename TRet>
class __resumable_func : __resumable_func_base
{
    typedef std::function<void(__resumable_func&)> TFunc;

    TFunc _func;
    TRet _currentValue;
    LPVOID _pFiber;
    LPVOID _pCallerFiber;
    Concurrency::details::__resumable_func_fiber_data* _pFuncData;

public:
    __resumable_func(TFunc func);
    ~__resumable_func();

    void yieldReturn(TRet value);
    void yieldBreak();
    void resume();

    const TRet& getCurrent() const { return _currentValue; }
    bool isEos() const { return _pFiber == nullptr; }

private:
    static void yield();
    static VOID CALLBACK ResumableFuncFiberProc(PVOID lpParameter);
};

The constructor stores a copy of the generator function to run, creates a new fiber object specifying ResumableFuncFiberProc as the function to execute, and immediately switches the execution to this fiber:

    __resumable_func(TFunc func) :
        _func(func),
        _currentValue(TRet()),
        _pFiber(nullptr),
        _pCallerFiber(nullptr),
        _pFuncData(nullptr)
    {
        // Convert the current thread to a fiber. This is needed because the thread needs to "be"
        // a fiber in order to be able to switch to another fiber.
        ConvertCurrentThreadToFiber();
        _pCallerFiber = GetCurrentFiber();

        // Create a new fiber (or re-use an existing one from the pool)
        _pFiber = Concurrency::details::POOL CreateFiberEx(Concurrency::details::fiberPool.commitSize,
            Concurrency::details::fiberPool.allocSize, FIBER_FLAG_FLOAT_SWITCH, &ResumableFuncFiberProc, this);
        if (!_pFiber) {
            throw std::bad_alloc();
        }

        // Switch to the newly created fiber. When this "returns" the functor has either returned,
        // or issued a 'yield' statement.
        ::SwitchToFiber(_pFiber);

        _pFuncData->suspending = false;
        _pFuncData->Release();
    }

The fiber starts from the fiber procedure, whose only task is to run the generator function in the context of the fiber:

    // Entry proc for the Resumable Function Fiber.
    static VOID CALLBACK ResumableFuncFiberProc(PVOID lpParameter)
    {
        LPVOID threadFiber;

        // This function does not formally return, due to the SwitchToFiber call at the bottom.
        // This scope block is needed for the destructors of the locals in this block to fire
        // before we do the SwitchToFiber.
        {
            Concurrency::details::__resumable_func_fiber_data funcDataOnFiberStack;
            __resumable_func* pThis = (__resumable_func*)lpParameter;

            // The callee needs to setup some more stuff after we return (which would be either on
            // yield or an ordinary return). Hence the callee needs the pointer to the func_data
            // on our stack. This is not unsafe since the callee has a refcount on this structure
            // which means the fiber will continue to live.
            pThis->_pFuncData = &funcDataOnFiberStack;

            Concurrency::details::POOL SetFiberData(&funcDataOnFiberStack);

            funcDataOnFiberStack.threadFiber = pThis->_pCallerFiber;
            funcDataOnFiberStack.resumableFuncFiber = GetCurrentFiber();

            // Finally calls the function in the context of the fiber. The execution can be
            // suspended by calling yield
            pThis->_func(*pThis);

            // Here the function has completed. We set 'returned' to true, meaning this is the
            // final 'real' return and not one of the 'yield' returns.
            funcDataOnFiberStack.returned = true;
            pThis->_pFiber = nullptr;
            threadFiber = funcDataOnFiberStack.threadFiber;
        }

        // Return to the calling fiber.
        ::SwitchToFiber(threadFiber);

        // On a normal fiber this function won't exit after this point. However, if the fiber is
        // in a fiber-pool and re-used we can get control back. So just exit this function, which
        // will cause the fiber pool to spin around and re-enter.
    }

There are two ways to suspend the execution of the generator function running in the fiber and to yield control back to the caller. The first is to yield a value, which will be stored in a data member:

    void yieldReturn(TRet value)
    {
        _currentValue = value;
        yield();
    }

The second is to immediately terminate the sequence, for example with a return statement or by reaching the end of the function. The compiler should translate a return into a call to the yieldBreak method:

    void yieldBreak()
    {
        _pFiber = nullptr;
        yield();
    }
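To picture what such a translation might produce, here is a small sketch of a generator body written directly against yieldReturn and yieldBreak. The recording_sink type and the fibonacci_below function are hypothetical; the sink only records values instead of switching fibers, so unlike the real class its yieldBreak returns to the body, which is why an explicit return follows it.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for __resumable_func<int>: it records values instead
// of switching fibers, so both methods return to the body immediately.
struct recording_sink
{
    std::vector<int> values;
    bool finished = false;

    void yieldReturn(int value) { values.push_back(value); }
    void yieldBreak() { finished = true; }
};

// A generator of the Fibonacci numbers below 'limit', written the way it
// might look once every yield and return became a call on the sink.
template <typename Sink>
void fibonacci_below(Sink& rf, int limit)
{
    int a = 0;
    int b = 1;
    while (true) {
        if (b >= limit) {
            rf.yieldBreak();   // replaces a return statement
            return;            // only reached with the non-suspending stand-in
        }
        rf.yieldReturn(b);     // replaces a yield of b
        int tmp = a + b;
        a = b;
        b = tmp;
    }
}
```

With a limit of 30 the sink ends up holding 1, 1, 2, 3, 5, 8, 13, 21. With the fiber-based class each yieldReturn would instead suspend the body until the caller asks for the next value.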

To yield control we just need to switch back to the calling fiber:

    static void yield()
    {
        _ASSERT(IsThreadAFiber());
        Concurrency::details::__resumable_func_fiber_data* funcDataOnFiberStack =
            Concurrency::details::__resumable_func_fiber_data::GetCurrentResumableFuncData();

        // Check the pointer before dereferencing it.
        _ASSERT(funcDataOnFiberStack);

        // Add-ref's the fiber. Even though there can only be one thread active in the fiber
        // context, there can be multiple threads accessing the fiber data.
        funcDataOnFiberStack->AddRef();

        funcDataOnFiberStack->verify();

        // Mark as busy suspending. We cannot run the code in the 'then' statement
        // concurrently with the await doing the setting up of the fiber.
        _ASSERT(!funcDataOnFiberStack->suspending);
        funcDataOnFiberStack->suspending = true;

        // Make note of the thread that we're being called from (Note that we'll always resume
        // on the same thread).
        funcDataOnFiberStack->awaitingThreadId = GetCurrentThreadId();

        _ASSERT(funcDataOnFiberStack->resumableFuncFiber == GetCurrentFiber());

        // Return to the calling fiber.
        ::SwitchToFiber(funcDataOnFiberStack->threadFiber);
    }

Once we have suspended, incrementing the iterator will resume the execution by calling resume, which will switch to this object’s fiber:

    void resume()
    {
        _ASSERT(IsThreadAFiber());
        _ASSERT(_pFiber != nullptr);
        _ASSERT(_pFuncData != nullptr);
        _ASSERT(!_pFuncData->suspending);
        _ASSERT(_pFuncData->awaitingThreadId == GetCurrentThreadId());

        // Switch to the fiber. When this "returns" the functor has either returned, or issued
        // a 'yield' statement.
        ::SwitchToFiber(_pFiber);

        _ASSERT(_pFuncData->returned || _pFuncData->suspending);
        _pFuncData->suspending = false;
        if (_pFuncData->returned) {
            _pFiber = nullptr;
        }
        _pFuncData->Release();
    }
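The caller side of this protocol can be sketched as a minimal input iterator layered on resume, getCurrent and isEos. Both replay_source and resumable_iterator are hypothetical names, the sketch is fixed to int for brevity, and replay_source merely replays a list rather than running a fiber; it only mimics the fact that the real constructor already runs the generator up to its first yield.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in with the same caller-facing interface as
// __resumable_func<int>; it replays a fixed list instead of running a fiber.
class replay_source
{
    std::vector<int> _values;
    std::size_t _index = 0;

public:
    explicit replay_source(std::vector<int> values) : _values(std::move(values)) {}

    // The source is positioned on the first value (or at end of sequence)
    // before resume() is ever called.
    bool isEos() const { return _index >= _values.size(); }
    const int& getCurrent() const { return _values[_index]; }
    void resume() { ++_index; }
};

// Minimal input iterator: operator++ resumes the source, operator* reads the
// current value, and a null source pointer marks the end iterator.
template <typename Source>
class resumable_iterator
{
    Source* _source;

public:
    explicit resumable_iterator(Source* source = nullptr)
        : _source(source && !source->isEos() ? source : nullptr) {}

    const int& operator*() const { return _source->getCurrent(); }

    resumable_iterator& operator++()
    {
        _source->resume();
        if (_source->isEos()) {
            _source = nullptr;   // become equal to the end iterator
        }
        return *this;
    }

    bool operator!=(const resumable_iterator& other) const
    {
        return _source != other._source;
    }
};

// Drains a source through an iterator pair, as a range-based for would.
template <typename Source>
std::vector<int> collect(Source& source)
{
    std::vector<int> out;
    for (resumable_iterator<Source> it(&source), end; it != end; ++it) {
        out.push_back(*it);
    }
    return out;
}
```

Because each increment asks the source for exactly one more value, nothing is computed ahead of what the loop actually consumes.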

The destructor just needs to convert the current fiber back to a normal thread, but only when there are no more fibers running in the thread. For this reason we need to keep a per-thread fiber count, which is incremented every time we create a __resumable_func and decremented every time we destroy it.

    ~__resumable_func()
    {
        if (_pCallerFiber != nullptr) {
            ConvertFiberBackToThread();
        }
    }

class __resumable_func_base
{
    __declspec(thread) static int ts_count;

protected:
    // Convert the thread to a fiber.
    static void ConvertCurrentThreadToFiber()
    {
        if (!IsThreadAFiber())
        {
            // Convert the thread to a fiber. Use FIBER_FLAG_FLOAT_SWITCH on x86.
            LPVOID threadFiber = ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
            if (threadFiber == NULL) {
                throw std::bad_alloc();
            }
            ts_count = 1;
        }
        else
        {
            ts_count++;
        }
    }

    // Convert the fiber back to a thread.
    static void ConvertFiberBackToThread()
    {
        if (--ts_count == 0)
        {
            if (ConvertFiberToThread() == FALSE) {
                throw std::bad_alloc();
            }
        }
    }
};
__declspec(thread) int __resumable_func_base::ts_count = 0;
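The counting logic can be tried in isolation with a small RAII sketch. Here fiber_scope, g_convertCalls and g_revertCalls are hypothetical: the two counters stand in for the real calls to ConvertThreadToFiberEx and ConvertFiberToThread, the standard thread_local keyword replaces __declspec(thread), and the count alone (rather than IsThreadAFiber) decides when to convert.

```cpp
#include <cassert>

// Hypothetical counters standing in for ConvertThreadToFiberEx and
// ConvertFiberToThread, so the nesting logic can be observed on any platform.
static int g_convertCalls = 0;
static int g_revertCalls = 0;

class fiber_scope
{
    static thread_local int ts_count;   // live scopes on the current thread

public:
    fiber_scope()
    {
        if (ts_count++ == 0) {
            ++g_convertCalls;           // outermost scope: convert thread to fiber
        }
    }

    ~fiber_scope()
    {
        if (--ts_count == 0) {
            ++g_revertCalls;            // last scope gone: convert back to thread
        }
    }

    fiber_scope(const fiber_scope&) = delete;
    fiber_scope& operator=(const fiber_scope&) = delete;

    static int live() { return ts_count; }
};
thread_local int fiber_scope::ts_count = 0;
```

Nesting two scopes on the same thread triggers a single conversion and a single reversion, the latter only when the outer scope is destroyed. This is the property that allows several generators to be alive on one thread at the same time.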

And this is all we need to have resumable generators in C++, on Windows. The complete source code can be found here.
