Skip to content

Instantly share code, notes, and snippets.

@rejunity
Last active October 28, 2019 15:29
Show Gist options
  • Save rejunity/c29123220b00675ab760c350bf9b1e9d to your computer and use it in GitHub Desktop.
Save rejunity/c29123220b00675ab760c350bf9b1e9d to your computer and use it in GitHub Desktop.
Python script to detect Thread Group Shared Memory bank conflicts
# Script to detect bank conflicts in Thread Group Shared Memory
# Supports optional multicast (default: on), a configurable wave size (default: 32) and bank count (default: 32)
# bank_conflicts(ptrs, wavesize=32, banks=32, multicast=True):
# Use:
# np.sum(bank_conflicts([thread.x*8 for thread in thread_group(8,8)]))
# np.sum(bank_conflicts([thread.id*2 for thread in thread_group(64)]))
# Conflicts per wave by bank number:
# bank_conflicts([thread.x*8 for thread in thread_group(8,8)])
# bank_conflicts([thread.id*2 for thread in thread_group(64)])
# How it works:
# 1) thread_group() returns a list of objects with `x`, `y`, `z` and `id` attributes mimicking SV_GroupThreadID and SV_GroupIndex.
# 2) From that list one can simulate access pointers; for example, [thread.id*2 for thread ... ] strides through memory, accessing every 2nd element.
# 3) bank_conflicts() takes the list of access pointers, splits it into waves and returns the conflict count per bank ID (a list of such lists, one per wave, when there is more than one wave).
# 4) Finally, use np.sum() to get the total number of conflicts
import numpy as np
def bank_conflicts(ptrs, wavesize=32, banks=32, multicast=True):
    """Count shared-memory bank conflicts for a sequence of access pointers.

    Each pointer is an element index and falls into bank ``p % banks``.
    Pointers are split, in order, into waves of ``wavesize`` consecutive
    entries. Conflicts are counted independently within each wave: a bank
    that must serve N serialized accesses contributes N - 1 conflicts.

    Args:
        ptrs: Iterable of integer element indices, one per thread, in thread
            order (e.g. built from ``thread_group()``).
        wavesize: Number of threads per wave.
        banks: Number of memory banks.
        multicast: If True, accesses within a wave to the *same* address are
            merged, so only distinct addresses in a bank serialize. If False,
            every access beyond the first to a bank counts as a conflict.

    Returns:
        If ``ptrs`` has at most ``wavesize`` entries, a list of ``banks``
        conflict counts (numpy float64 values) indexed by bank. Otherwise, a
        list of such lists, one per wave. ``np.sum`` of either form gives the
        total number of conflicts.
    """
    def chunk(seq, n):  # split the group workload into consecutive waves
        for i in range(0, len(seq), n):
            yield seq[i:i + n]

    def wave_bank_conflicts(wave):
        # Group every address touched in this wave by the bank it maps to.
        touched = [[] for _ in range(banks)]
        for p in wave:
            touched[p % banks].append(p)
        C = np.zeros(banks)  # per-bank conflict counters
        for bank, addrs in enumerate(touched):
            # Compare against *all* accesses to the bank, not only the most
            # recent one. Otherwise a repeated address (e.g. 0, 32, 0) is
            # counted twice under multicast. Using sets also avoids a -1
            # sentinel that hid conflicts for negative pointers.
            n = len(set(addrs)) if multicast else len(addrs)
            C[bank] = max(n - 1, 0)
        return list(C)

    ptrs = list(ptrs)  # accept any iterable (generators, numpy arrays, ...)
    if len(ptrs) <= wavesize:
        return wave_bank_conflicts(ptrs)
    return [wave_bank_conflicts(wave) for wave in chunk(ptrs, wavesize)]
def thread_group(x, y=1, z=1):
    """Return the threads of an ``x`` by ``y`` by ``z`` thread group.

    Each returned object exposes ``x``, ``y`` and ``z`` (its position inside
    the group) and ``id`` (its flattened index ``x + y*X + z*X*Y``, where X
    and Y are the group's x and y sizes). Threads are listed with ``x``
    varying fastest, then ``y``, then ``z``, so a thread's list position
    equals its ``id``. ``str()`` and ``repr()`` of a thread show ``[x, y, z]``.
    """
    class TID(object):
        def __init__(self, x, y, z, flat_index):
            self.x, self.y, self.z = x, y, z
            self.id = flat_index

        def __str__(self):
            return str([self.x, self.y, self.z])

        __repr__ = __str__

    plane = x * y  # threads per z-slice
    return [
        TID(col, row, layer, col + row * x + layer * plane)
        for layer in range(z)
        for row in range(y)
        for col in range(x)
    ]
if __name__ == "__main__":
    # Demo from the header: an 8x8 group where each thread accesses element
    # x*8. The original bare expression ran at import time and its result
    # was discarded; guard it and print the total so running the script
    # produces visible output.
    print(np.sum(bank_conflicts([thread.x * 8 for thread in thread_group(8, 8)])))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment