Skip to content

Instantly share code, notes, and snippets.

@kice
Created December 9, 2019 19:34
Show Gist options
  • Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
// CPU reference warp for one plane of uint8_t samples.
//
// Each output pixel is a bilinear sample of srcp8 taken at a position
// displaced by the local gradient of the edge plane edgep8: horizontally by
// (left - right) * scale and vertically by (above - below) * scale, with
// scale = (depth << 8) / 65536 and edge-plane borders replicated. Results are
// clamped to [0, (1 << bits_per_sample) - 1].
//
// Specialised for PixelType = uint8_t, so strides are counted in elements,
// which equal bytes here.
// NOTE(review): the original header also said "SMAGL is 0 or 2", but nothing
// in this body depends on SMAGL.
static void warp_c(const uint8_t *srcp8, const uint8_t *edgep8, uint8_t *dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const int pixel_max = (1 << bits_per_sample) - 1;
    const int x_limit_min = 0;
    const int x_limit_max = width - 1;
    const float scale = (depth << 8) / 65536.0f;

    const uint8_t *srcp = srcp8;
    const uint8_t *edgep = edgep8;
    uint8_t *dstp = dstp8;

    for (int y = 0; y < height; y++) {
        // Keep row y + v inside the plane. For a one-row plane the upper bound
        // would fall below the lower one (undefined for std::clamp), so it is
        // raised to the lower bound.
        const float y_limit_min = -y;
        const float y_limit_max = std::max((height - y - 1) - 1e-2f, y_limit_min);
        // Highest row offset (relative to row y) that the second row tap may use.
        const int v_limit_max = height - y - 1;

        for (int x = 0; x < width; x++) {
            // Neighbouring edge values with replicated borders.
            const int above = (y == 0) ? edgep[x] : edgep[x - edge_stride];
            const int below = (y == height - 1) ? edgep[x] : edgep[x + edge_stride];
            const int left = (x == 0) ? edgep[x] : edgep[x - 1];
            const int right = (x == width - 1) ? edgep[x] : edgep[x + 1];

            const float _h = (left - right) * scale;
            const float _v = std::min(std::max((above - below) * scale, y_limit_min), y_limit_max);

            // Fractional parts folded into [0, 1) (floor semantics for negatives).
            float remainder_h = std::fmod(_h, 1.0);
            if (remainder_h < 0.0)
                remainder_h += 1.0;
            float remainder_v = std::fmod(_v, 1.0);
            if (remainder_v < 0.0)
                remainder_v += 1.0;

            int h = (int)std::floor(_h) + x;
            const int v = (int)std::floor(_v);

            // At or beyond the horizontal borders the sample is not interpolated.
            if (h >= x_limit_max || h < x_limit_min)
                remainder_h = 0;
            h = std::min(std::max(h, x_limit_min), x_limit_max);

            // Second taps, clamped so the zero-weight taps at the right and
            // bottom borders never read outside the plane.
            const int h1 = std::min(h + 1, x_limit_max);
            const int v1 = std::min(v + 1, v_limit_max);

            const uint8_t *row0 = srcp + v * src_stride;
            const uint8_t *row1 = srcp + v1 * src_stride;
            const int s00 = row0[h], s01 = row0[h1];
            const int s10 = row1[h], s11 = row1[h1];

            // Round every stage as floor(t + 0.5). Carrying three unrounded +0.5
            // offsets into a final round-to-nearest biased each output up by one
            // code value, so a zero displacement did not reproduce the source.
            const double s0 = std::floor(s00 * (1.0 - remainder_h) + s01 * remainder_h + 0.5);
            const double s1 = std::floor(s10 * (1.0 - remainder_h) + s11 * remainder_h + 0.5);
            const int val = (int)std::floor(s0 * (1.0 - remainder_v) + s1 * remainder_v + 0.5);

            dstp[x] = (uint8_t)std::min(std::max(val, 0), pixel_max);
        }
        srcp += src_stride;
        edgep += edge_stride;
        dstp += dst_stride;
    }
}
@WolframRhodium
Copy link

WolframRhodium commented Dec 9, 2019

CALL (host-side launch):

// Host-side launch of warp_cuda: one thread per output pixel on a 2D grid.
// 32 x 32 threads per block; the assert keeps the block at <= 1024 threads.
dim3 threadsPerBlock(32, 32);
assert(threadsPerBlock.z == 1 && threadsPerBlock.x * threadsPerBlock.y <= 1024);

// Ceil-divide so the grid covers the whole image; the kernel's bounds check
// discards the overhanging threads.
dim3 numBlocks(
    (width + threadsPerBlock.x - 1) / threadsPerBlock.x,
    (height + threadsPerBlock.y - 1) / threadsPerBlock.y
);

// The execution configuration is <<<grid, block>>>, so numBlocks goes first.
// With the two swapped, numBlocks becomes the block shape, which is an invalid
// launch as soon as the image needs more than 1024 blocks.
warp_cuda<<<numBlocks, threadsPerBlock>>>(...);

// A <<<>>> launch returns no status; configuration errors surface here.
const cudaError_t launch_err = cudaGetLastError();
assert(launch_err == cudaSuccess);

Kernel code:

// Device port of warp_c for uint8_t samples: warps one plane along the
// gradient of the edge plane, one thread per output pixel.
//
// Launch: 2D grid of 2D blocks. The thread with
//   x = blockIdx.x * blockDim.x + threadIdx.x,
//   y = blockIdx.y * blockDim.y + threadIdx.y
// writes dstp8[y * dst_stride + x]. Threads with x >= width or y >= height
// return immediately, so the grid may overshoot the image. No shared memory.
//
// Strides are counted in uint8_t elements (identical to bytes). All three
// pointers are __restrict__, so dstp8 must not alias srcp8 or edgep8.
__global__
static void warp_cuda(const uint8_t * __restrict__ srcp8, const uint8_t * __restrict__ edgep8, uint8_t * __restrict__ dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    if (x >= width || y >= height)
        return;

    const int pixel_max = (1 << bits_per_sample) - 1;
    const int x_limit_min = 0;
    const int x_limit_max = width - 1;
    const float scale = (depth << 8) / 65536.0f;

    // Row y of each plane (64-bit offsets so stride * y cannot overflow int).
    const uint8_t *srcp = srcp8 + (long long)src_stride * y;
    const uint8_t *edgep = edgep8 + (long long)edge_stride * y;
    uint8_t *dstp = dstp8 + (long long)dst_stride * y;

    // Neighbouring edge values with replicated borders.
    const int above = (y == 0)          ? edgep[x] : edgep[x - edge_stride];
    const int below = (y == height - 1) ? edgep[x] : edgep[x + edge_stride];
    const int left  = (x == 0)          ? edgep[x] : edgep[x - 1];
    const int right = (x == width - 1)  ? edgep[x] : edgep[x + 1];

    const float _h = (left - right) * scale;

    // Keep row y + v inside the plane. For a one-row plane the upper bound
    // would drop below the lower one, so it is raised to the lower bound.
    const float y_limit_min = -y;
    const float y_limit_max = fmaxf((height - y - 1) - 1e-2f, y_limit_min);
    const float _v = fminf(fmaxf((above - below) * scale, y_limit_min), y_limit_max);

    // Fractional parts in [0, 1), with floor semantics for negative values.
    float remainder_h = _h - floorf(_h);
    const float remainder_v = _v - floorf(_v);

    int h = __float2int_rd(_h) + x;
    const int v = __float2int_rd(_v);

    // At or beyond the horizontal borders the sample is not interpolated.
    if (h >= x_limit_max || h < x_limit_min)
        remainder_h = 0.0f;
    h = min(max(h, x_limit_min), x_limit_max);

    // Second taps, clamped so the zero-weight taps at the right and bottom
    // borders never read outside the plane.
    const int h1 = min(h + 1, x_limit_max);
    const int v1 = min(v + 1, height - 1 - y);

    const uint8_t *row0 = srcp + (long long)v * src_stride;
    const uint8_t *row1 = srcp + (long long)v1 * src_stride;
    const int s00 = row0[h], s01 = row0[h1];
    const int s10 = row1[h], s11 = row1[h1];

    // Round every stage as floor(t + 0.5). Carrying three unrounded +0.5
    // offsets into a final round-to-nearest biased each output up by one code
    // value, so a zero displacement did not reproduce the source.
    const float s0 = floorf(s00 * (1.0f - remainder_h) + s01 * remainder_h + 0.5f);
    const float s1 = floorf(s10 * (1.0f - remainder_h) + s11 * remainder_h + 0.5f);
    const int val = __float2int_rd(s0 * (1.0f - remainder_v) + s1 * remainder_v + 0.5f);

    dstp[x] = (uint8_t)min(max(val, 0), pixel_max);
}

@WolframRhodium
Copy link

WolframRhodium commented Dec 10, 2019

import torch

def rgb_warp(src, mask):
    """Vectorised PyTorch port of ``warp_c``: warp ``src`` along the gradient of ``mask``.

    Every output pixel is a bilinear sample of ``src`` taken at a position
    displaced horizontally by ``(left - right) * scale`` and vertically by
    ``(above - below) * scale``. Here ``left``/``right``/``above``/``below``
    are the neighbouring ``mask`` values (borders replicated), and ``scale``
    is warp_c's ``(depth << 8) / 65536`` with ``depth = 8``.

    Args:
        src: 4-D tensor unpacked as (N, C, H, W).
        mask: tensor with the same H and W as ``src``; its N and C must equal
            those of ``src`` or be 1 (broadcast across ``src``).
            NOTE(review): ``scale`` was derived for 8-bit edge values; confirm
            the value range the caller feeds in.

    Returns:
        Tensor of shape (N, C, H, W), clamped to ``[0, pixel_max]`` with
        ``pixel_max = 1.0``. Floating-point ``src`` keeps its dtype; other
        dtypes are computed and returned as float32.
    """
    N, C, H, W = src.shape
    device = src.device
    dtype = src.dtype if src.is_floating_point() else torch.float32

    scale = (8 * 256) / 65536  # (depth << 8) / 65536 with depth = 8
    pixel_max = 1.0

    # Cast before differencing: unsigned masks would wrap in left - right.
    mask = mask.to(device=device, dtype=dtype)

    # Created on src's device so GPU inputs do not mix devices.
    y = torch.arange(H, device=device, dtype=dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, -1)

    x_limit_min = 0
    x_limit_max = W - 1

    # Same vertical limits as warp_c. For a one-row image the upper bound
    # would drop below the lower one, so it is raised to it.
    y_limit_min = -y
    y_limit_max = (H - 1 - y) - 1e-2
    y_limit_max = torch.where(y_limit_max < y_limit_min, y_limit_min, y_limit_max)

    # Neighbouring mask values with replicated borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale
    _v = torch.where(_v < y_limit_min, y_limit_min, _v)
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)

    # Fractional parts in [0, 1); % uses floor semantics for negatives.
    remainder_h = _h % 1
    remainder_v = _v % 1

    h = torch.floor(_h) + x  # absolute source column
    v = (torch.floor(_v) + y).long()  # absolute source row

    # At or beyond the horizontal borders the sample is not interpolated.
    inside = (h >= x_limit_min) & (h < x_limit_max)
    remainder_h = torch.where(inside, remainder_h, torch.zeros_like(remainder_h))

    h = h.clamp(x_limit_min, x_limit_max).long()
    # Second taps, clamped so zero-weight taps at the borders stay in range.
    h1 = (h + 1).clamp(max=x_limit_max)
    v1 = (v + 1).clamp(max=H - 1)

    src_flat = src.reshape(N, C, H * W).to(dtype)

    def sample(rows, cols):
        # Per-(n, c) flat indices; mask dims of size 1 broadcast over src.
        idx = (rows * W + cols).flatten(start_dim=2).expand(N, C, H * W)
        return torch.gather(src_flat, 2, idx).view(N, C, H, W)

    s00 = sample(v, h)
    s01 = sample(v, h1)
    s10 = sample(v1, h)
    s11 = sample(v1, h1)

    # Float output: no +0.5 rounding offsets. Three of them would add 1.0 to
    # every sample and saturate the clamp below.
    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    return torch.clamp(s, 0.0, pixel_max)

@kice
Copy link
Author

kice commented Dec 10, 2019

def awarp(src, mask):
    """Warp ``src`` along the gradient of ``mask`` (batched PyTorch port of ``warp_c``).

    Output pixel (y, x) bilinearly samples ``src`` at a position displaced
    horizontally by ``(left - right) * scale`` and vertically by
    ``(above - below) * scale``. The four terms are neighbouring ``mask``
    values with replicated borders, and ``scale`` matches warp_c with
    ``depth = 8``. Border handling follows warp_c: no horizontal interpolation
    at or beyond the left/right limits, and the vertical displacement is
    clamped to keep the sampled rows inside the image.

    Args:
        src: 4-D tensor unpacked as (N, C, H, W).
        mask: tensor with the same H and W as ``src``; N and C must match
            ``src`` or be 1. NOTE(review): ``scale`` assumes 8-bit edge
            magnitudes; confirm the mask's value range with the caller.

    Returns:
        (N, C, H, W) tensor clamped to ``[0, 1.0]``. Floating ``src`` keeps
        its dtype; anything else is computed in float32.
    """
    N, C, H, W = src.shape
    device = src.device
    dtype = src.dtype if src.is_floating_point() else torch.float32

    scale = (8 * 256) / 65536
    pixel_max = 1.0

    # Float cast first so integer masks cannot wrap when differenced.
    mask = mask.to(device=device, dtype=dtype)

    # Index grids live on src's device (CPU-only aranges break GPU inputs).
    y = torch.arange(H, device=device, dtype=dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, -1)

    x_limit_min = 0
    x_limit_max = W - 1

    # Vertical limits as in warp_c; a one-row image collapses both to 0.
    y_limit_min = -y
    y_limit_max = (H - 1 - y) - 1e-2
    y_limit_max = torch.where(y_limit_max < y_limit_min, y_limit_min, y_limit_max)

    above = torch.cat([mask[:, :,  :1, : ], mask[:, :,   :-1, :  ]], dim=2)
    below = torch.cat([mask[:, :, 1:,  : ], mask[:, :, -1:,   :  ]], dim=2)
    left  = torch.cat([mask[:, :,  :,  :1], mask[:, :,   :,   :-1]], dim=3)
    right = torch.cat([mask[:, :,  :, 1: ], mask[:, :,   :, -1:  ]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale
    _v = torch.where(_v < y_limit_min, y_limit_min, _v)
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)

    remainder_h = _h % 1
    remainder_v = _v % 1

    # floor(_v) + y rather than floor(_v + y): adding y first can round a tiny
    # negative displacement away in float and pick the wrong row.
    h = torch.floor(_h) + x
    v = (torch.floor(_v) + y).long()

    # warp_c zeroes the horizontal fraction at/beyond the borders; so does this.
    inside = (h >= x_limit_min) & (h < x_limit_max)
    remainder_h = torch.where(inside, remainder_h, torch.zeros_like(remainder_h))

    h = h.clamp(x_limit_min, x_limit_max).long()
    h1 = (h + 1).clamp(max=x_limit_max)
    v1 = (v + 1).clamp(max=H - 1)

    src_flat = torch.flatten(src, start_dim=2).to(dtype)  # (N, C, H * W)

    def sample(rows, cols):
        # gather (not index_select) so each (n, c) slice uses its own indices.
        idx = (rows * W + cols).flatten(start_dim=2).expand(N, C, H * W)
        return torch.gather(src_flat, 2, idx).view(N, C, H, W)

    s00 = sample(v, h)
    s01 = sample(v, h1)
    s10 = sample(v1, h)
    s11 = sample(v1, h1)

    # Plain lerp for float data: the +0.5 rounding offsets pushed every
    # sample up by 1.0 and saturated the clamp.
    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    return torch.clamp(s, 0.0, pixel_max)

@WolframRhodium
Copy link

WolframRhodium commented Dec 11, 2019

Deprecated: the maintained version is at the link below.

https://gitlab.com/sr11/gist/blob/master/awarp/awarp.py

def awarp(src, mask):
    """Gather-based PyTorch port of ``warp_c``: displace ``src`` by the gradient of ``mask``.

    For each pixel, the horizontal offset is ``(left - right) * scale`` and the
    vertical offset is ``(above - below) * scale``. These are computed from
    neighbouring ``mask`` values with replicated borders, with ``scale`` as in
    warp_c at ``depth = 8``. ``src`` is then sampled bilinearly at the
    displaced position. Horizontal interpolation is disabled at or beyond the
    left/right limits, and the vertical offset is clamped so the sampled rows
    stay inside the image, both as in warp_c.

    Args:
        src: 4-D tensor unpacked as (N, C, H, W).
        mask: same H and W as ``src``; N and C equal to ``src``'s or 1.
            NOTE(review): ``scale`` assumes 8-bit edge magnitudes — confirm
            the mask's range with the caller.

    Returns:
        (N, C, H, W) tensor clamped to ``[0, 1.0]``; dtype follows ``src``
        when it is floating point, otherwise float32.
    """
    N, C, H, W = src.shape
    device = src.device
    dtype = src.dtype if src.is_floating_point() else torch.float32

    scale = (8 * 256) / 65536
    pixel_max = 1.0

    mask = mask.to(device=device, dtype=dtype)

    y = torch.arange(H, device=device, dtype=dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, -1)

    x_limit_min = 0
    x_limit_max = W - 1

    y_limit_min = -y
    y_limit_max = (H - 1 - y) - 1e-2
    # One-row images: keep the clamp interval non-empty.
    y_limit_max = torch.where(y_limit_max < y_limit_min, y_limit_min, y_limit_max)

    above = torch.cat([mask[:, :,  :1, : ], mask[:, :,   :-1, :  ]], dim=2)
    below = torch.cat([mask[:, :, 1:,  : ], mask[:, :, -1:,   :  ]], dim=2)
    left  = torch.cat([mask[:, :,  :,  :1], mask[:, :,   :,   :-1]], dim=3)
    right = torch.cat([mask[:, :,  :, 1: ], mask[:, :,   :, -1:  ]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale
    # Lower clamp, then upper clamp (the upper one must test y_limit_max).
    _v = torch.where(_v < y_limit_min, y_limit_min, _v)
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)

    remainder_h = _h % 1
    remainder_v = _v % 1

    h = torch.floor(_h) + x
    v = (torch.floor(_v) + y).long()

    inside = (h >= x_limit_min) & (h < x_limit_max)
    remainder_h = torch.where(inside, remainder_h, torch.zeros_like(remainder_h))

    h = h.clamp(x_limit_min, x_limit_max).long()
    h1 = (h + 1).clamp(max=x_limit_max)
    v1 = (v + 1).clamp(max=H - 1)

    src_flat = torch.flatten(src, start_dim=2).to(dtype)  # (N, C, H * W)

    def sample(rows, cols):
        # Flat row-major index (row * W + col), broadcast to every (n, c).
        idx = (rows * W + cols).flatten(start_dim=2).expand(N, C, H * W)
        return torch.gather(src_flat, 2, idx).view(N, C, H, W)

    s00 = sample(v, h)
    s01 = sample(v, h1)
    s10 = sample(v1, h)
    s11 = sample(v1, h1)

    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    return torch.clamp(s, 0.0, pixel_max)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment