Skip to content

Instantly share code, notes, and snippets.

@ALEXMORF
Created March 14, 2020 21:53
Show Gist options
  • Save ALEXMORF/11af1b1ca059f8028c945c900ef16611 to your computer and use it in GitHub Desktop.
Save ALEXMORF/11af1b1ca059f8028c945c900ef16611 to your computer and use it in GitHub Desktop.
A ray-traversal compute shader (HLSL) that triggers a validation error in DXC.
// Root signature used by the CS entry point: a descriptor table holding UAV u0,
// root UAVs u1 and u2, root SRVs t0 and t1, and ten 32-bit root constants at b0.
#define RS "DescriptorTable(UAV(u0)), UAV(u1), UAV(u2), SRV(t0), SRV(t1), RootConstants(num32BitConstants=10, b0)"
// Re-declares float4x4 with the row_major modifier.
// NOTE(review): presumably meant to make every later float4x4 row-major;
// confirm the compiler honors this self-referential typedef.
typedef row_major float4x4 float4x4;
// Wrapper around a single 4x4 float matrix (used as instance::InvT).
struct mat4
{
float4x4 Data; // the wrapped matrix
};
// One triangle as stored in PrimBuffer.
struct prim
{
float3 Verts[3]; // the three vertices handed to RayIntersectTri
int ID; // identifier copied into ray_result::PrimID when this triangle is the closest hit
};
// BVH node layout shared by the top-level (TLASBuffer) and bottom-level
// (BLASBuffer) trees. The traversal code treats the node stored directly
// after an internal node (index + 1) as its first child.
struct bvh_node
{
float3 Min; // lower corner of the node's bounding box
float3 Max; // upper corner of the node's bounding box
uint Offset; // prim offset if leaf, secondChildOffset if internal
// (for a top-level leaf the traversal uses it as an index into InstanceBuffer)
//NOTE(chen): data organization, low bits first:
// u16 PrimCount; // 0 -> internal node, read with GetPrimCount()
// u8 Axis; // read with GetAxis()
// u8 Pad;
uint Data;
};
// One instance referenced by a top-level BVH leaf.
struct instance
{
mat4 InvT; // applied to the incoming ray origin/direction before bottom-level traversal
// NOTE(review): presumably the inverse of the instance transform; confirm with the CPU side
int NodeOffset; // added to node indices when reading BLASBuffer
int PrimOffset; // added to primitive indices when reading PrimBuffer
};
// Extracts the primitive count packed into the low 16 bits of bvh_node::Data.
// A result of 0 marks an internal node.
uint GetPrimCount(uint Data)
{
    const uint PrimCountMask = 0x0000ffffu;
    return Data & PrimCountMask;
}
// Extracts the 8-bit axis field stored in bits 16..23 of bvh_node::Data.
uint GetAxis(uint Data)
{
    const uint AxisMask = 0x00ff0000u;
    return (Data & AxisMask) >> 16;
}
// Sentinel distance meaning "no hit" (10e31 == 1e32). It is the initial
// ray_result::ClosestT and the value RayIntersectBox returns on a miss.
#define T_MAX 10e31
// Closest hit found so far while traversing the scene.
struct ray_result
{
float ClosestT; // ray parameter of the closest hit; T_MAX while nothing has been hit
int PrimID; // prim::ID of the closest triangle; -1 while nothing has been hit
float2 UV; // the (u, v) pair RayIntersectTri returned for the closest hit
};
// Builds the "nothing hit yet" state: distance T_MAX, primitive id -1 and
// zeroed UV.
ray_result InitRayResult()
{
    ray_result Miss;
    Miss.UV = float2(0.0, 0.0);
    Miss.PrimID = -1;
    Miss.ClosestT = T_MAX;
    return Miss;
}
// Slab test of a ray against an axis-aligned box.
//   Ro    - ray origin
//   InvRd - component-wise reciprocal of the ray direction
//   Min   - lower corner of the box
//   Max   - upper corner of the box
// Returns:
//   T_MAX       when the slabs do not overlap or the box is entirely behind the ray,
//   0.0         when the entry and exit distances have different signs
//               (the origin is inside the box),
//   entry dist. otherwise.
float RayIntersectBox(float3 Ro, float3 InvRd, float3 Min, float3 Max)
{
    // Ray parameters at the planes through Min and through Max, per axis.
    float3 TAtMin = (Min - Ro) * InvRd;
    float3 TAtMax = (Max - Ro) * InvRd;

    // With a positive direction component the Min plane is crossed first;
    // otherwise the roles of the two planes are swapped.
    float NearX = (InvRd.x > 0.0) ? TAtMin.x : TAtMax.x;
    float FarX  = (InvRd.x > 0.0) ? TAtMax.x : TAtMin.x;
    float NearY = (InvRd.y > 0.0) ? TAtMin.y : TAtMax.y;
    float FarY  = (InvRd.y > 0.0) ? TAtMax.y : TAtMin.y;
    float NearZ = (InvRd.z > 0.0) ? TAtMin.z : TAtMax.z;
    float FarZ  = (InvRd.z > 0.0) ? TAtMax.z : TAtMin.z;

    // The ray is inside the box between the latest entry and the earliest exit.
    float TEnter = max(max(NearX, NearY), NearZ);
    float TExit  = min(min(FarX, FarY), FarZ);

    bool SlabsOverlap = TEnter <= TExit;
    bool NotBehindRay = TEnter >= 0.0 || TExit >= 0.0;
    if (!(SlabsOverlap && NotBehindRay))
    {
        return T_MAX;
    }

    // Differing signs mean the origin sits between entry and exit, i.e. inside.
    if (sign(TEnter) != sign(TExit))
    {
        return 0.0;
    }
    return TEnter;
}
// Ray/triangle intersection, following
// http://iquilezles.org/www/articles/intersectors/intersectors.htm
//   ro, rd     - ray origin and direction
//   v0, v1, v2 - triangle vertices
// Returns float3(t, u, v): the ray parameter of the hit and the two
// barycentric-style coordinates. t is forced to -1.0 when (u, v) lies outside
// the triangle. There is no guard for a ray parallel to the triangle plane;
// the division then yields non-finite values.
float3 RayIntersectTri(float3 ro, float3 rd,
                       float3 v0, float3 v1, float3 v2)
{
    float3 Edge1 = v1 - v0;
    float3 Edge2 = v2 - v0;
    float3 ToOrigin = ro - v0;

    float3 Normal = cross( Edge1, Edge2 );
    float3 Q = cross( ToOrigin, rd );
    float InvDet = 1.0/dot( rd, Normal );

    float U = InvDet*dot( -Q, Edge2 );
    float V = InvDet*dot( Q, Edge1 );
    float T = InvDet*dot( -Normal, ToOrigin );

    bool OutsideTriangle = U<0.0 || U>1.0 || V<0.0 || (U+V)>1.0;
    if (OutsideTriangle)
    {
        return float3( -1.0, U, V );
    }
    return float3( T, U, V );
}
// Number of traversal-stack entries reserved per thread.
#define STACK_SIZE 32
// Number of threads sharing StackShared; matches numthreads(8, 4, 1) on CS (8*4*1 = 32).
#define WG_SIZE 32
// Traversal stacks of all threads in the group, interleaved: a thread starts
// at index SV_GroupIndex and moves up or down one stack entry by adding or
// subtracting WG_SIZE.
groupshared uint StackShared[STACK_SIZE*WG_SIZE];
// Traverses one instance's bottom-level BVH and returns the updated closest hit.
//   Ro, Rd, InvRd - ray origin, direction and component-wise reciprocal
//                   direction, already transformed by the caller
//   RayState      - closest hit found so far; returned updated
//   StackIndex    - index in StackShared of the caller's current stack top;
//                   this function only uses the entries above it
//                   (StackIndex + k*WG_SIZE, k >= 1)
//   BLASBuffer    - bottom-level nodes, read at NodeOffset + node index
//   PrimBuffer    - triangles, read at PrimOffset + Node.Offset + i
// Node 0 (relative to NodeOffset) is the root. For an internal node the first
// child is the next node in memory and the second child is Node.Offset.
//
// Pushes are bounds-checked against StackShared: when both children of an
// internal node no longer fit, that subtree is skipped instead of writing
// past the end of the group-shared array.
ray_result RayTraceBottomLevel(float3 Ro, float3 Rd, float3 InvRd,
                               ray_result RayState, int StackIndex,
                               RWStructuredBuffer<bvh_node> BLASBuffer,
                               RWStructuredBuffer<prim> PrimBuffer,
                               int NodeOffset, int PrimOffset)
{
    const int StackLimit = STACK_SIZE*WG_SIZE;

    int StackBase = StackIndex;
    int Cursor = StackIndex+WG_SIZE;
    if (Cursor < 0 || Cursor >= StackLimit)
    {
        // No room for even the root entry.
        return RayState;
    }

    StackShared[Cursor] = 0;
    while (Cursor > StackBase)
    {
        // Pop the next node to visit.
        uint NodeIndex = StackShared[Cursor];
        Cursor -= WG_SIZE;

        bvh_node Node = BLASBuffer[NodeOffset+NodeIndex];
        uint PrimCount = GetPrimCount(Node.Data);
        float BoxT = RayIntersectBox(Ro, InvRd, Node.Min, Node.Max);
        if (BoxT < RayState.ClosestT)
        {
            if (PrimCount == 0) // internal node
            {
                // Only push when both children fit on the stack.
                if (Cursor + 2*WG_SIZE < StackLimit)
                {
                    // The child pushed last is popped first: with a positive
                    // direction along the split axis the first child
                    // (NodeIndex + 1) is visited before the second.
                    uint Axis = GetAxis(Node.Data);
                    bool FirstChildFirst = Rd[Axis] > 0.0;
                    uint VisitFirst = FirstChildFirst ? (NodeIndex + 1) : Node.Offset;
                    uint VisitSecond = FirstChildFirst ? Node.Offset : (NodeIndex + 1);

                    Cursor += WG_SIZE;
                    StackShared[Cursor] = VisitSecond;
                    Cursor += WG_SIZE;
                    StackShared[Cursor] = VisitFirst;
                }
            }
            else // primitive leaf
            {
                for (uint PrimI = 0; PrimI < PrimCount; ++PrimI)
                {
                    prim P = PrimBuffer[PrimOffset+Node.Offset+PrimI];
                    float3 Res = RayIntersectTri(Ro, Rd,
                                                 P.Verts[0],
                                                 P.Verts[1],
                                                 P.Verts[2]);
                    float T = Res.x;
                    // Keep the hit only if it is in front of the origin and
                    // closer than anything found so far.
                    if (T > 0.0 && T < RayState.ClosestT)
                    {
                        RayState.ClosestT = T;
                        RayState.PrimID = P.ID;
                        RayState.UV = Res.yz;
                    }
                }
            }
        }
    }
    return RayState;
}
// Traces a ray through the top-level BVH and, for every instance leaf whose
// box is hit closer than the current closest hit, through that instance's
// bottom-level BVH. Returns the closest hit (or the InitRayResult() state).
//   LocalID        - SV_GroupIndex of the calling thread; selects its
//                    interleaved stack inside StackShared
//   Ro, Rd, InvRd  - ray origin, direction and reciprocal direction
//   TLASBuffer     - top-level nodes; node 0 is the root
//   InstanceBuffer - instances, indexed by a top-level leaf's Offset
//   BLASBuffer, PrimBuffer - forwarded to RayTraceBottomLevel
//
// The stack is empty once Cursor drops below 0 (one entry under LocalID).
// The upper bound is the size of StackShared: Cursor is an interleaved index
// that advances by WG_SIZE per entry, so comparing it with STACK_SIZE alone
// would end the traversal right after the root's children were pushed.
// Pushes are bounds-checked; when both children no longer fit, that subtree
// is skipped instead of writing past the end of the group-shared array.
ray_result RayTraceTwoLevel(uint LocalID, float3 Ro, float3 Rd, float3 InvRd,
                            StructuredBuffer<bvh_node> TLASBuffer,
                            StructuredBuffer<instance> InstanceBuffer,
                            RWStructuredBuffer<bvh_node> BLASBuffer,
                            RWStructuredBuffer<prim> PrimBuffer)
{
    const int StackLimit = STACK_SIZE*WG_SIZE;

    ray_result RayState = InitRayResult();
    int Cursor = LocalID;
    StackShared[Cursor] = 0;
    while (Cursor >= 0 && Cursor < StackLimit)
    {
        // Pop the next node to visit.
        uint NodeIndex = StackShared[Cursor];
        Cursor -= WG_SIZE;

        bvh_node Node = TLASBuffer[NodeIndex];
        uint PrimCount = GetPrimCount(Node.Data);
        float BoxT = RayIntersectBox(Ro, InvRd, Node.Min, Node.Max);
        if (BoxT < RayState.ClosestT)
        {
            if (PrimCount == 0) // internal node
            {
                // Only push when both children fit on the stack.
                if (Cursor + 2*WG_SIZE < StackLimit)
                {
                    // The child pushed last is popped first: with a positive
                    // direction along the split axis the first child
                    // (NodeIndex + 1) is visited before the second.
                    uint Axis = GetAxis(Node.Data);
                    bool FirstChildFirst = Rd[Axis] > 0.0;
                    uint VisitFirst = FirstChildFirst ? (NodeIndex + 1) : Node.Offset;
                    uint VisitSecond = FirstChildFirst ? Node.Offset : (NodeIndex + 1);

                    Cursor += WG_SIZE;
                    StackShared[Cursor] = VisitSecond;
                    Cursor += WG_SIZE;
                    StackShared[Cursor] = VisitFirst;
                }
            }
            else // instance leaf
            {
                instance Instance = InstanceBuffer[Node.Offset];
                // Transform the ray with the instance matrix. The direction is
                // not re-normalized, so hit distances keep the parameterization
                // of the incoming ray and stay comparable across instances.
                float3 InstanceRo = mul(float4(Ro, 1.0), Instance.InvT.Data).xyz;
                float3 InstanceRd = mul(float4(Rd, 0.0), Instance.InvT.Data).xyz;
                float3 InstanceInvRd = rcp(InstanceRd);
                // The bottom-level traversal stacks its entries above Cursor,
                // reusing the slot this leaf was just popped from.
                RayState = RayTraceBottomLevel(InstanceRo, InstanceRd,
                                               InstanceInvRd,
                                               RayState, Cursor,
                                               BLASBuffer, PrimBuffer,
                                               Instance.NodeOffset,
                                               Instance.PrimOffset);
            }
        }
    }
    return RayState;
}
// Resource bindings; the registers match the entries of the RS root signature.
RWTexture2D<float4> Output: register(u0); // image written by CS
RWStructuredBuffer<bvh_node> BLASBuffer: register(u1); // bottom-level BVH nodes
RWStructuredBuffer<prim> PrimBuffer: register(u2); // triangles referenced by bottom-level leaves
StructuredBuffer<bvh_node> TLASBuffer: register(t0); // top-level BVH nodes
StructuredBuffer<instance> InstanceBuffer: register(t1); // instances referenced by top-level leaves
// Per-dispatch parameters. The struct is ten 32-bit values, matching
// RootConstants(num32BitConstants=10, b0) in RS.
struct context
{
float3 CamP; // camera position; used as the ray origin
uint Pad0; // not read by the shader; NOTE(review): presumably alignment padding
float3 CamAt; // point the camera looks at
uint Pad1; // not read by the shader; NOTE(review): presumably alignment padding
int Width; // used to normalize ThreadID.x and for the aspect ratio
int Height; // used to normalize ThreadID.y and for the aspect ratio
};
ConstantBuffer<context> Context: register(b0);
// Compute entry point: builds one camera ray per thread, traces it through
// the two-level BVH and writes the hit distance as a grey value to Output.
//   ThreadID - SV_DispatchThreadID, used as the pixel coordinate
//   LocalID  - SV_GroupIndex, selects this thread's stack in StackShared
// The group size (8*4*1 = 32) has to equal WG_SIZE because LocalID indexes
// the interleaved traversal stacks.
[RootSignature(RS)]
[numthreads(8, 4, 1)]
void CS(uint2 ThreadID: SV_DispatchThreadID,
        uint LocalID: SV_GroupIndex)
{
    // Map the pixel to [-1, 1] on both axes, negate Y so that increasing
    // ThreadID.y maps to decreasing camera-space Y, then widen X by the
    // aspect ratio.
    float X = 2.0 * float(ThreadID.x) / float(Context.Width) - 1.0;
    float Y = 2.0 * float(ThreadID.y) / float(Context.Height) - 1.0;
    Y = -Y;
    X *= float(Context.Width) / float(Context.Height);

    // Distance to the image plane for a 60 degree field of view.
    float FOV = radians(60.0);
    float Z = 1.0 / tan(0.5*FOV);

    // Camera basis looking from CamP towards CamAt with world +Y as up.
    float3 Ro = Context.CamP;
    float3 At = Context.CamAt;
    float3 CamZ = normalize(At - Ro);
    float3 CamX = normalize(cross(float3(0, 1, 0), CamZ));
    float3 CamY = normalize(cross(CamZ, CamX));

    float3 Rd = normalize(CamX * X + CamY * Y + Z * CamZ);
    float3 InvRd = rcp(Rd);

    ray_result Res = RayTraceTwoLevel(LocalID, Ro, Rd, InvRd,
                                      TLASBuffer, InstanceBuffer,
                                      BLASBuffer, PrimBuffer);

    // Visualize depth: distances of 30 units or more (including misses,
    // which report T_MAX) saturate to white. Only the rgb channels are written.
    Output[ThreadID].rgb = saturate(Res.ClosestT / 30.0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment