-
-
Save rbehrends/2783e808550b1f11ae19 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import locks | |
| # Simple barrier implementation that keeps sync and wait separate. | |
| # | |
| # initBarrier() and deinitBarrier() initialize and destroy the | |
| # underlying OS structures. | |
| # | |
| # startBarrier() begins the barrier process. The `nthreads` argument | |
| # denotes the number of threads participating. Only one thread must call | |
| # startBarrier() on the same barrier. | |
| # | |
| # syncBarrier() is used to notify the barrier that the current thread | |
| # has finished its contribution to whatever the threads are | |
| # collaborating on. | |
| # | |
| # waitBarrier() blocks until all participating threads have called | |
| # syncBarrier(). This is separate from syncBarrier() so that threads can | |
| # continue to do work while waiting for other threads to finish their | |
| # contributions. | |
| # | |
| # endBarrier() blocks until all threads have called waitBarrier(). Only | |
| # one thread must call it at a time; if the thread also has to call | |
| # waitBarrier(), it must call waitBarrier() before endBarrier(). After | |
| # endBarrier() is complete, startBarrier() can be called again. | |
# Shared state for one barrier. Every field is read and written only
# while `lock` is held.
type
  TBarrier* = object
    counter: int      # threads that still have to call syncBarrier()
    waiting: int      # threads that still have to finish waitBarrier()
    waitingEnd: bool  # true while a thread is blocked in endBarrier()
    lock: TLock       # protects all of the fields above
    cond: TCond       # signalled once `counter` has reached zero
    condEnd: TCond    # signalled once `waiting` has reached zero
proc initBarrier*(): TBarrier =
  # Creates an idle barrier: the lock and both condition variables are
  # initialized and both thread counts are zero, which is the state
  # startBarrier() expects to find.
  initLock(result.lock)
  initCond(result.cond)
  initCond(result.condEnd)
  result.waiting = 0
  result.counter = 0
proc deinitBarrier*(bar: var TBarrier) =
  # Releases the lock and the two condition variables owned by `bar`.
  # The counters are left untouched.
  bar.lock.deinitLock()
  bar.cond.deinitCond()
  bar.condEnd.deinitCond()
proc startBarrier*(bar: var TBarrier, nthreads: int) =
  # Arms the barrier for `nthreads` participants.
  #
  # The barrier may only be armed while it is idle, i.e. while both
  # counts are zero. Arming a barrier that is still in use leaves it
  # unchanged and trips the assertion below once the lock is released.
  acquire(bar.lock)
  let error = bar.counter != 0 or bar.waiting != 0
  if not error:
    bar.counter = nthreads
    bar.waiting = nthreads
  release(bar.lock)
  assert(not error)
proc syncBarrier*(bar: var TBarrier) =
  # Records that the calling thread has finished its contribution. The
  # thread that brings the count of outstanding contributions down to
  # zero signals `cond` to start releasing threads in waitBarrier().
  #
  # Possible non-portable optimization: decrement the counter with CAS
  # or an atomic dec and take the lock only when it hits zero. If that
  # is done, the decrement inside the critical region has to use the
  # same mechanism.
  bar.lock.acquire()
  bar.counter -= 1
  let allSynced = bar.counter == 0
  if allSynced:
    bar.cond.signal()
  bar.lock.release()
proc waitBarrier*(bar: var TBarrier) =
  # Blocks the calling thread until every participant has called
  # syncBarrier(), i.e. until bar.counter has dropped to zero, and then
  # records that this thread is done waiting by decrementing bar.waiting.
  #
  # This can be optimized somewhat by using an atomic decrement on
  # bar.waiting if bar.waiting would not become zero; note that this may
  # also require a memory barrier to ensure that bar.counter and
  # bar.waiting are in sync when testing if bar.counter == 0.
  #
  # An alternative approach is to store bar.counter and bar.waiting in
  # separate halfwords of a word that can be read and updated
  # atomically.
  acquire(bar.lock)
  # Re-check the predicate after every wakeup: a wakeup that arrives
  # while bar.counter is still non-zero must not let the thread through.
  # A thread that arrives after the counter reached zero skips the wait.
  while bar.counter != 0:
    wait(bar.cond, bar.lock)
  dec bar.waiting
  # We have to be careful here, because Nimrod condition variables lack
  # a broadcast facility, so every thread that gets past the barrier
  # passes the signal on as long as at least one thread is still due.
  # A signal sent while nobody is blocked on `cond` is simply dropped.
  if bar.waiting > 0:
    signal(bar.cond)
  elif bar.waitingEnd:
    # Last thread through: wake the thread blocked in endBarrier(). This
    # has to happen whether or not this thread itself had to block,
    # otherwise endBarrier() would never be woken.
    signal(bar.condEnd)
  release(bar.lock)
proc endBarrier*(bar: var TBarrier) =
  # Blocks until every participant has completed waitBarrier(), i.e.
  # until both bar.counter and bar.waiting are zero. Afterwards the
  # barrier is idle and may be armed again with startBarrier().
  #
  # Only one thread may be inside endBarrier() at a time. A second,
  # concurrent call trips the assertion; when assertions are compiled
  # out it returns immediately without touching the barrier.
  acquire(bar.lock)
  let error = bar.waitingEnd
  if error:
    release(bar.lock)
    assert(not error)
    # Must not fall through: the lock is no longer held here.
    return
  # Setting the flag is only visible to other threads once wait() gives
  # up the lock, so it is harmless when no waiting turns out to be needed.
  bar.waitingEnd = true
  # Loop on the predicate so that a wakeup arriving before the barrier
  # is actually finished does not end the wait early.
  while bar.counter != 0 or bar.waiting != 0:
    wait(bar.condEnd, bar.lock)
  bar.waitingEnd = false
  release(bar.lock)
when isMainModule:
  # Self-test: n threads run through one full barrier round and report
  # each stage they reach; thread 1 additionally closes the round.
  const n = 32

  proc report(msg: string, t: int) =
    # Prints a progress message tagged with the thread number.
    echo msg, " ", t

  type MyThread = TThread[int]

  var threads: array[1..n, MyThread]
  var mybarrier = initBarrier()

  proc testthread(i: int) {.thread.} =
    report "start", i
    syncBarrier(mybarrier)
    report "synced", i
    waitBarrier(mybarrier)
    report "done", i
    if i == 1:
      # Only a single thread may call endBarrier(), and it does so after
      # its own waitBarrier() call.
      endBarrier(mybarrier)
      report "end", i

  # The barrier is armed before any worker exists, so no thread can
  # reach syncBarrier() on an unarmed barrier.
  startBarrier(mybarrier, n)
  for i in 1..n:
    createThread(threads[i], testthread, i)
  for i in 1..n:
    joinThread(threads[i])
    report "complete", i
  # All workers have been joined, so nothing uses the barrier any more;
  # release its lock and condition variables.
  deinitBarrier(mybarrier)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment