@vfdev-5
Last active February 9, 2023 11:43
Vectorized pytorch interpolate uint8

RGBA Image resizing with a vectorized algorithm

Horizontal pass vectorized algorithm on RGBA data

Input data is stored as

input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]

Weights are float values computed for each output pixel and rescaled to int16:

weights[i] = [w[i, 0], w[i, 1], ..., w[i, K - 1]]

We want to compute the output as follows:

output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]

where

oR[i] = r[xmin[i]] * w[i, 0] + r[xmin[i] + 1] * w[i, 1] + ... + r[xmin[i] + K - 1] * w[i, K - 1]
oG[i] = g[xmin[i]] * w[i, 0] + g[xmin[i] + 1] * w[i, 1] + ... + g[xmin[i] + K - 1] * w[i, K - 1]
oB[i] = b[xmin[i]] * w[i, 0] + b[xmin[i] + 1] * w[i, 1] + ... + b[xmin[i] + K - 1] * w[i, K - 1]

Output computation with integers

oR[i] = r[xmin[i]] * w[i, 0] + r[xmin[i] + 1] * w[i, 1] + ... + r[xmin[i] + K - 1] * w[i, K - 1]

where r is uint8 and w is float.

Here is a way to perform the computations in integer arithmetic with minimal precision loss.

  1. Rescale float weights into int16
  • find max float weight to estimate weights_precision
unsigned int weights_precision = 0;
for (weights_precision = 0; weights_precision < 22; weights_precision += 1) {
      int next_value = (int) (0.5 + w_max * (1 << (weights_precision + 1)));
      if (next_value >= (1 << 15))
            break;
}
  • transform float value into int16 value:
w_i16[i] = (int16) (sign(w_f32[i]) * 0.5 + w_f32[i] * (1 << weights_precision));
  2. Compute the output value using an int dtype (a compilable sketch of both steps follows this list):
uint8* dst = ...
uint8* src = ...
int16* wts = ...
int output = 1 << (weights_precision - 1);

output += src[0] * wts[0];
output += src[1] * wts[1];
...
output += src[K - 1] * wts[K - 1];
output = (output >> weights_precision);

dst[i] = (uint8) clamp(output, 0, 255);
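
A minimal, compilable scalar sketch of both steps (compute_weights_precision and weighted_sum_u8 are illustrative names, not the functions used in PyTorch; the float weight computation itself is assumed to happen elsewhere):

#include <algorithm>
#include <cstdint>
#include <vector>

// Estimate weights_precision from the maximal float weight (same loop as above).
static unsigned int compute_weights_precision(double w_max) {
    unsigned int weights_precision = 0;
    for (weights_precision = 0; weights_precision < 22; weights_precision += 1) {
        int next_value = (int) (0.5 + w_max * (1 << (weights_precision + 1)));
        if (next_value >= (1 << 15))
            break;
    }
    return weights_precision;
}

// One output value (one channel of one output pixel) computed in integer arithmetic.
static uint8_t weighted_sum_u8(const uint8_t* src, const int16_t* wts, int K,
                               unsigned int weights_precision) {
    int output = 1 << (weights_precision - 1);  // rounding offset
    for (int j = 0; j < K; j++)
        output += (int) src[j] * (int) wts[j];
    output = output >> weights_precision;
    return (uint8_t) std::clamp(output, 0, 255);
}

int main() {
    std::vector<double> w_f32 = {0.25, 0.5, 0.25};
    double w_max = *std::max_element(w_f32.begin(), w_f32.end());
    unsigned int weights_precision = compute_weights_precision(w_max);  // 15 here

    std::vector<int16_t> w_i16(w_f32.size());
    for (size_t i = 0; i < w_f32.size(); i++)
        w_i16[i] = (int16_t) ((w_f32[i] < 0 ? -0.5 : 0.5) + w_f32[i] * (1 << weights_precision));

    uint8_t src[3] = {10, 20, 30};
    // 0.25 * 10 + 0.5 * 20 + 0.25 * 30 = 20
    return weighted_sum_u8(src, w_i16.data(), 3, weights_precision) == 20 ? 0 : 1;
}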

Vectorized version

As the data format is RGBA with R, G, B, A stored as uint8, we can encode the 4 values of one pixel as a single uint32 value.

Working register, AVX2 = 32 uint8 places

reg = [0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0]

We can split K (the size of the weight vector for a given output index) as a sum: K = n * 4 + m * 2 + k. We load and process 4 weight values in a loop ("block 4"), then 2 weight values in another loop ("block 2"), and finally 1 weight value in the last loop ("block 1"), as sketched below.
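
A small sketch of how this split can drive the loop structure (the printf calls stand in for the vectorized block bodies described in the steps below):

#include <cstdio>

// Illustration of splitting K = n * 4 + m * 2 + k into the "block 4",
// "block 2" and "block 1" loops; printf stands in for the vectorized bodies.
int main() {
    int K = 7;  // e.g. 7 = 1 * 4 + 1 * 2 + 1
    int j = 0;
    for (; j + 4 <= K; j += 4)
        std::printf("block 4: weights %d..%d\n", j, j + 3);
    for (; j + 2 <= K; j += 2)
        std::printf("block 2: weights %d..%d\n", j, j + 1);
    for (; j < K; j += 1)
        std::printf("block 1: weight %d\n", j);
    return 0;
}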

  1. As we are doing computations in an integer dtype, we add the rounding offset (= 1 << (weights_precision - 1)) to the accumulator:
reg = [
      0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0
]
  2. Load weights. For "block 4" we load 4 int16 values as two pairs, (w0, w1) and (w2, w3). Each int16 value w_i is then represented in the register by two uint8 values, wl_i (low byte) and wh_i (high byte); a broadcast sketch follows these layouts:
w01 = [
      wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 | ... | ... | wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1
]

For example,

w01 = [
      183 45 0 64 183 45 0 64 | 183 45 0 64 183 45 0 64 | 183 45 0 64 183 45 0 64 | 183 45 0 64 183 45 0 64
]
w23 = [
      wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 | ... | ... | wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3
]

On the next iteration of "block 4" we load the next pairs of weights, (w4, w5) as w45 and (w6, w7) as w67.

In case of "block 2" we will load 2 int16 values (w0, w1):

w01 = [
      wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 | ... | ... | wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1
]

And in case of "block 1" we will load only 1 int16 value w0:

w0 = [
      wl_0 wh_0 0 0 wl_0 wh_0 0 0 | ... | ... | wl_0 wh_0 0 0 wl_0 wh_0 0 0
]
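
One way such a broadcast can be expressed with AVX2 intrinsics (an illustrative sketch, not necessarily the exact code of the implementation); the example weights reproduce the 183 45 0 64 pattern above:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Broadcast one pair of int16 weights (w0, w1) to every 32-bit slot of a
// 256-bit register; on a little-endian machine this gives the byte pattern
// wl_0 wh_0 wl_1 wh_1 repeated 8 times, as described above.
int main() {
    int16_t w[2] = {11703, 16384};  // 11703 = 0x2DB7 -> bytes 183 45, 16384 = 0x4000 -> bytes 0 64
    int32_t packed;
    std::memcpy(&packed, w, sizeof(packed));
    __m256i w01 = _mm256_set1_epi32(packed);

    uint8_t bytes[32];
    _mm256_storeu_si256((__m256i*) bytes, w01);
    for (int i = 0; i < 32; i++)
        std::printf("%d ", bytes[i]);  // 183 45 0 64 183 45 0 64 ...
    std::printf("\n");
    return 0;
}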
  3. Load source data. Each RGBA pixel occupies 4 uint8 values, so half of a 256-bit register (= 16 uint8 places) can be filled with 4 pixels. To fill all 32 uint8 places (= 256 bits) we load 4 pixels from each of two lines, e.g. r0-r3 and rr0-rr3, where ri is a red value from line0 and rri is a red value from line1 (a load sketch follows these layouts).

Thus, we can process 2 lines in parallel. The number of loaded pixels determines the block option. For "block 4" we load pixels 0-3:

data = [
      r0 g0 b0 a0 r1 g1 b1 a1 | r2 g2 b2 a2 r3 g3 b3 a3 | rr0 gg0 bb0 aa0 rr1 gg1 bb1 aa1 | rr2 gg2 bb2 aa2 rr3 gg3 bb3 aa3
]

For example,

data = [
      0 1 2 255 3 4 5 255 | 6 7 8 255 9 10 11 255 | 27 28 29 255 30 31 32 255 | 33 34 35 255 36 37 38 255
]

In case of "block 2", we load

data = [
      r0 g0 b0 a0 r1 g1 b1 a1 | 0 0 0 0 0 0 0 0 | rr0 gg0 bb0 aa0 rr1 gg1 bb1 aa1 | 0 0 0 0 0 0 0 0
]

and in case of "block 1", we load

data = [
      r0 g0 b0 a0 0 0 0 0 | 0 0 0 0 0 0 0 0 | rr0 gg0 bb0 aa0 0 0 0 0 | 0 0 0 0 0 0 0 0
]
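
A possible AVX2 sketch of the "block 4" load, combining 4 pixels from each of the two lines into one 256-bit register (load_4px_two_lines is a hypothetical helper name):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// "Block 4" source load: 4 RGBA pixels (16 bytes) from each of two lines are
// combined into one 256-bit register, low 128-bit lane = line0, high lane = line1.
static inline __m256i load_4px_two_lines(const uint8_t* line0, const uint8_t* line1) {
    __m128i lo = _mm_loadu_si128((const __m128i*) line0);  // r0 g0 b0 a0 ... a3
    __m128i hi = _mm_loadu_si128((const __m128i*) line1);  // rr0 gg0 bb0 aa0 ... aa3
    return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1);
}

int main() {
    uint8_t line0[16] = {0, 1, 2, 255, 3, 4, 5, 255, 6, 7, 8, 255, 9, 10, 11, 255};
    uint8_t line1[16] = {27, 28, 29, 255, 30, 31, 32, 255, 33, 34, 35, 255, 36, 37, 38, 255};
    __m256i data = load_4px_two_lines(line0, line1);

    uint8_t check[32];
    _mm256_storeu_si256((__m256i*) check, data);
    std::printf("%d %d\n", check[0], check[16]);  // 0 (r0 of line0) and 27 (rr0 of line1)
    return 0;
}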
  4. As weights are loaded only 2 values at a time, we have to split and shuffle the source data such that we can correctly compute r0 * w0 + r1 * w1 and r2 * w2 + r3 * w3 (a shuffle sketch follows these layouts). For "block 4" we obtain:
data_01 = [
      r0 0 r1 0 g0 0 g1 0 | b0 0 b1 0 a0 0 a1 0 | rr0 0 rr1 0 gg0 0 gg1 0 | bb0 0 bb1 0 aa0 0 aa1 0
]
data_23 = [
      r2 0 r3 0 g2 0 g3 0 | b2 0 b3 0 a2 0 a3 0 | rr2 0 rr3 0 gg2 0 gg3 0 | bb2 0 bb3 0 aa2 0 aa3 0
]

For "block 2" we will have

data_01 = [
      r0 0 r1 0 g0 0 g1 0 | b0 0 b1 0 a0 0 a1 0 | rr0 0 rr1 0 gg0 0 gg1 0 | bb0 0 bb1 0 aa0 0 aa1 0
]

and for "block 1" we will have

data_0 = [
      r0 0 0 0 g0 0 0 0 | b0 0 0 0 a0 0 0 0 | rr0 0 0 0 gg0 0 0 0 | bb0 0 0 0 aa0 0 0 0
]
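
One way this split/shuffle can be expressed with _mm256_shuffle_epi8 (the mask below reproduces the data_01 layout; the data_23 mask is analogous; it is an illustration, not necessarily the exact mask used in the implementation):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Shuffle pixels 0-1 of each 128-bit lane into the (p 0 q 0) pattern so that a
// 16-bit multiply-add can later compute p * w0 + q * w1 per channel.
// Indices are per 128-bit lane; -1 produces a zero byte.
int main() {
    uint8_t px[32] = {0, 1, 2, 255, 3, 4, 5, 255, 6, 7, 8, 255, 9, 10, 11, 255,
                      27, 28, 29, 255, 30, 31, 32, 255, 33, 34, 35, 255, 36, 37, 38, 255};
    __m256i data = _mm256_loadu_si256((const __m256i*) px);

    const __m256i mask_01 = _mm256_set_epi8(
        -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
        -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
    __m256i data_01 = _mm256_shuffle_epi8(data, mask_01);

    uint8_t out[32];
    _mm256_storeu_si256((__m256i*) out, data_01);
    // r0 0 r1 0 g0 0 g1 0 | b0 0 b1 0 a0 0 a1 0 | rr0 0 rr1 0 gg0 0 gg1 0 | bb0 0 bb1 0 aa0 0 aa1 0
    for (int i = 0; i < 32; i++)
        std::printf("%d ", out[i]);
    std::printf("\n");
    return 0;
}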
  5. Multiply and add the source data and weights using 32-bit integer precision. 32-bit integer precision means each output value takes 4 placeholders (a b c d). A multiply-add sketch follows the examples below.
# w01 = [
#       wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 | ... | ... | wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1
# ]

out01 = data_01 * w01

out01 = [
      (r0 0) * (wl_0 wh_0) + (r1 0) * (wl_1 wh_1), (g0 0) * (wl_0 wh_0) + (g1 0) * (wl_1 wh_1)     |
      (b0 0) * (wl_0 wh_0) + (b1 0) * (wl_1 wh_1), (a0 0) * (wl_0 wh_0) + (a1 0) * (wl_1 wh_1)     |
      (rr0 0) * (wl_0 wh_0) + (rr1 0) * (wl_1 wh_1), (gg0 0) * (wl_0 wh_0) + (gg1 0) * (wl_1 wh_1) |
      (bb0 0) * (wl_0 wh_0) + (bb1 0) * (wl_1 wh_1), (aa0 0) * (wl_0 wh_0) + (aa1 0) * (wl_1 wh_1)
]

where (pi 0) * (wl_j wh_j) + (pk 0) * (wl_n wh_n) = (out_0, out_1, out_2, out_3).

out23 = data_23 * w23

out23 = [
      (r2 0) * (wl_2 wh_2) + (r3 0) * (wl_3 wh_3), (g2 0) * (wl_2 wh_2) + (g3 0) * (wl_3 wh_3)     |
      (b2 0) * (wl_2 wh_2) + (b3 0) * (wl_3 wh_3), (a2 0) * (wl_2 wh_2) + (a3 0) * (wl_3 wh_3)     |
      (rr2 0) * (wl_2 wh_2) + (rr3 0) * (wl_3 wh_3), (gg2 0) * (wl_2 wh_2) + (gg3 0) * (wl_3 wh_3) |
      (bb2 0) * (wl_2 wh_2) + (bb3 0) * (wl_3 wh_3), (aa2 0) * (wl_2 wh_2) + (aa3 0) * (wl_3 wh_3)
]

For "block 1" we will have

out0 = [
      (r0 0) * (wl_0 wh_0), (g0 0) * (wl_0 wh_0)   |
      (b0 0) * (wl_0 wh_0), (a0 0) * (wl_0 wh_0)   |
      (rr0 0) * (wl_0 wh_0), (gg0 0) * (wl_0 wh_0) |
      (bb0 0) * (wl_0 wh_0), (aa0 0) * (wl_0 wh_0)
]

Here each element like (r0 0) * (wl_0 wh_0) represents an int32 and takes 4 placeholders.

The output is accumulated with the results from previous iterations.
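
The (pixel 0) byte pairs and the (wl wh) weight pairs are exactly what a 16-bit multiply-add with 32-bit accumulation consumes; a sketch using _mm256_madd_epi16 with illustrative values (assuming weights_precision = 15):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

// data_01 holds (p 0 q 0) byte pairs, i.e. the int16 values p and q; w01 holds
// int16 pairs (w0, w1). _mm256_madd_epi16 multiplies the 16-bit lanes and adds
// adjacent products, giving p * w0 + q * w1 as one int32 per channel.
// Accumulation across blocks is a plain 32-bit add.
int main() {
    const unsigned int weights_precision = 15;
    // Already shuffled data: r0=10, r1=30, g0=20, g1=40, b0=50, b1=70, a0=a1=255 (both lines equal here).
    uint8_t shuffled[32] = {10, 0, 30, 0, 20, 0, 40, 0, 50, 0, 70, 0, 255, 0, 255, 0,
                            10, 0, 30, 0, 20, 0, 40, 0, 50, 0, 70, 0, 255, 0, 255, 0};
    __m256i data_01 = _mm256_loadu_si256((const __m256i*) shuffled);

    int16_t w[2] = {16384, 16384};  // w0 = w1 = 0.5 at weights_precision = 15
    int32_t packed;
    std::memcpy(&packed, w, sizeof(packed));
    __m256i w01 = _mm256_set1_epi32(packed);

    __m256i acc = _mm256_set1_epi32(1 << (weights_precision - 1));  // rounding offset
    acc = _mm256_add_epi32(acc, _mm256_madd_epi16(data_01, w01));   // out01 (+ accumulate)

    int32_t out[8];
    _mm256_storeu_si256((__m256i*) out, acc);
    for (int i = 0; i < 8; i++)
        std::printf("%d ", out[i] >> weights_precision);  // 20 30 60 255 20 30 60 255
    std::printf("\n");
    return 0;
}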

  6. Add the registers out01 and out23 together in case of "block 4":
out1234 = [

      (r0 0) * (wl_0 wh_0) + (r1 0) * (wl_1 wh_1) + (r2 0) * (wl_2 wh_2) + (r3 0) * (wl_3 wh_3),
      (g0 0) * (wl_0 wh_0) + (g1 0) * (wl_1 wh_1) + (g2 0) * (wl_2 wh_2) + (g3 0) * (wl_3 wh_3) |

      (b0 0) * (wl_0 wh_0) + (b1 0) * (wl_1 wh_1) + (b2 0) * (wl_2 wh_2) + (b3 0) * (wl_3 wh_3),
      (a0 0) * (wl_0 wh_0) + (a1 0) * (wl_1 wh_1) + (a2 0) * (wl_2 wh_2) + (a3 0) * (wl_3 wh_3) |

      (rr0 0) * (wl_0 wh_0) + (rr1 0) * (wl_1 wh_1) + (rr2 0) * (wl_2 wh_2) + (rr3 0) * (wl_3 wh_3),
      (gg0 0) * (wl_0 wh_0) + (gg1 0) * (wl_1 wh_1) + (gg2 0) * (wl_2 wh_2) + (gg3 0) * (wl_3 wh_3) |

      (bb0 0) * (wl_0 wh_0) + (bb1 0) * (wl_1 wh_1) + (bb2 0) * (wl_2 wh_2) + (bb3 0) * (wl_3 wh_3),
      (aa0 0) * (wl_0 wh_0) + (aa1 0) * (wl_1 wh_1) + (aa2 0) * (wl_2 wh_2) + (aa3 0) * (wl_3 wh_3)
]
  7. Shift back the output integer values (output = (output >> weights_precision)):
out01 = out01 >> weights_precision
# or
out1234 = out1234 >> weights_precision
  8. Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation:
(a a a a b b b b | c c c c d d d d) -> (a a b b c c d d | 0 0 0 0 0 0 0 0)
  9. Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation:
(a a b b c c d d) -> (a b c d 0 0 0 0)
  10. Write the output into a single uint32 (steps 7-10 are sketched below):
(a b c d) -> x_uint32
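
Steps 7-10 for a single output pixel can be sketched with 128-bit intrinsics as follows (the actual code keeps two lines in one 256-bit register; this only shows the per-pixel narrowing idea):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// For one output pixel: shift the int32 accumulators back, narrow 32 -> 16 bits
// with signed saturation, 16 -> 8 bits with unsigned saturation (which also
// clamps to [0, 255]), then read the 4 bytes out as a single uint32.
int main() {
    // Accumulated R, G, B, A values, still scaled by 1 << weights_precision (= 15 here).
    __m128i acc = _mm_set_epi32(255 << 15, 60 << 15, 30 << 15, 20 << 15);  // A, B, G, R

    __m128i i32 = _mm_srai_epi32(acc, 15);                     // step 7: >> weights_precision
    __m128i i16 = _mm_packs_epi32(i32, _mm_setzero_si128());   // step 8: 4 x int32 -> 4 x int16
    __m128i u8  = _mm_packus_epi16(i16, _mm_setzero_si128());  // step 9: 4 x int16 -> 4 x uint8

    uint32_t pixel = (uint32_t) _mm_cvtsi128_si32(u8);         // step 10: R G B A as one uint32
    std::printf("0x%08x\n", (unsigned) pixel);                 // bytes are 20, 30, 60, 255
    return 0;
}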

Vertical pass vectorized algorithm on RGBA data

Input data is stored as

input = [
      r[0, 0], g[0, 0], b[0, 0], a[0, 0], r[0, 1], g[0, 1], b[0, 1], a[0, 1], r[0, 2], g[0, 2], b[0, 2], a[0, 2], ...
      r[1, 0], g[1, 0], b[1, 0], a[1, 0], r[1, 1], g[1, 1], b[1, 1], a[1, 1], r[1, 2], g[1, 2], b[1, 2], a[1, 2], ...
      ...
      r[n - 1, 0], g[n - 1, 0], b[n - 1, 0], a[n - 1, 0], r[n - 1, 1], g[n - 1, 1], b[n - 1, 1], a[n - 1, 1], r[n - 1, 2], g[n - 1, 2], b[n - 1, 2], a[n - 1, 2], ...
      ...
]

Weights are float values computed for each output pixel and rescaled to int16:

weights[i] = [w[i, 0], w[i, 1], ..., w[i, K - 1]]

We want to compute the output as follows:

output = [
      oR[0, 0], oG[0, 0], oB[0, 0], oA[0, 0], oR[0, 1], oG[0, 1], oB[0, 1], oA[0, 1], ...
]

where

oR[j, i] = r[ymin[j], i] * w[j, 0] + r[ymin[j] + 1, i] * w[j, 1] + ... + r[ymin[j] + K - 1, i] * w[j, K - 1]
oG[j, i] = g[ymin[j], i] * w[j, 0] + g[ymin[j] + 1, i] * w[j, 1] + ... + g[ymin[j] + K - 1, i] * w[j, K - 1]
oB[j, i] = b[ymin[j], i] * w[j, 0] + b[ymin[j] + 1, i] * w[j, 1] + ... + b[ymin[j] + K - 1, i] * w[j, K - 1]

Vectorized version

As the data format is RGBA with R, G, B, A stored as uint8, we can encode the 4 values of one pixel as a single uint32 value.

Working accumulating register, AVX2 = 32 uint8 places

reg = [0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0]

We can split K (the size of the weight vector for a given output index) as a sum: K = m * 2 + k. We load and process 2 weight values in a loop ("block 2") and finally process 1 weight value in the last loop ("block 1").

  1. As we are doing computations in an integer dtype, we add the rounding offset (= 1 << (weights_precision - 1)) to the accumulator:
reg = [
      0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0 | 0 128 0 0 0 128 0 0
]
  2. Load weights. For "block 2" we load 2 int16 values (w0, w1). Each int16 value w_i is then represented in the register by two uint8 values, wl_i (low byte) and wh_i (high byte):
w01 = [
      wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 | ... | ... | wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1
]

And in case of "block 1" we will load only 1 int16 value w0:

w0 = [
      wl_0 wh_0 0 0 wl_0 wh_0 0 0 | ... | ... | wl_0 wh_0 0 0 wl_0 wh_0 0 0
]
  3. Load source data. Each RGBA pixel occupies 4 uint8 values, so a full 256-bit register (= 32 uint8 places) can be filled with 8 pixels from one line, e.g. r0-r7 from line0 and rr0-rr7 from line1, where ri is a red value from line0 and rri is a red value from line1.

For the vertical pass we need to combine values from different lines.

line0 = [
      r0 g0 b0 a0 r1 g1 b1 a1 | r2 g2 b2 a2 r3 g3 b3 a3 | r4 g4 b4 a4 r5 g5 b5 a5 | r6 g6 b6 a6 r7 g7 b7 a7
]

line1 = [
      rr0 gg0 bb0 aa0 rr1 gg1 bb1 aa1 | rr2 gg2 bb2 aa2 rr3 gg3 bb3 aa3 | rr4 gg4 bb4 aa4 rr5 gg5 bb5 aa5 | rr6 gg6 bb6 aa6 rr7 gg7 bb7 aa7
]

We process 8 pixels within each line in parallel, and two lines contribute to each output: r0 * w0 + rr0 * w1. When fewer than 8 pixels remain, we can process 2 pixels within each line in parallel and finally just 1 pixel.
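
A scalar view of one vertical output value, to make the row-wise accumulation explicit (vertical_sum_u8 is an illustrative name; stride is the distance in bytes between two rows):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// One vertical-pass output value: the same column of K consecutive input rows
// is combined with the int16 weights of output row j. stride is the distance
// in bytes between two rows (4 * width for packed RGBA).
static uint8_t vertical_sum_u8(const uint8_t* src_col, int stride, const int16_t* wts,
                               int K, unsigned int weights_precision) {
    int output = 1 << (weights_precision - 1);  // rounding offset
    for (int k = 0; k < K; k++)
        output += (int) src_col[k * stride] * (int) wts[k];
    return (uint8_t) std::clamp(output >> weights_precision, 0, 255);
}

int main() {
    const unsigned int weights_precision = 15;
    int16_t wts[2] = {16384, 16384};                 // two rows, each weighted 0.5
    uint8_t column[8] = {10, 0, 0, 0, 30, 0, 0, 0};  // r values of two rows, stride = 4
    std::printf("%d\n", vertical_sum_u8(column, 4, wts, 2, weights_precision));  // 20
    return 0;
}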

  4. We loaded the weights 2 values at a time as (wl_0 wh_0 wl_1 wh_1), thus we have to split and shuffle the source data such that we can correctly compute r0 * w0 + rr0 * w1 and g0 * w0 + gg0 * w1 (see the sketch after these layouts).
data_01_ll = [
      r0 0 rr0 0 g0 0 gg0 0 | b0 0 bb0 0 a0 0 aa0 0 | r1 0 rr1 0 g1 0 gg1 0 | b1 0 bb1 0 a1 0 aa1 0
]
data_01_lh = [
      r2 0 rr2 0 g2 0 gg2 0 | b2 0 bb2 0 a2 0 aa2 0 | r3 0 rr3 0 g3 0 gg3 0 | b3 0 bb3 0 a3 0 aa3 0
]
data_01_hl = [
      r4 0 rr4 0 g4 0 gg4 0 | b4 0 bb4 0 a4 0 aa4 0 | r5 0 rr5 0 g5 0 gg5 0 | b5 0 bb5 0 a5 0 aa5 0
]
data_01_hh = [
      r6 0 rr6 0 g6 0 gg6 0 | b6 0 bb6 0 a6 0 aa6 0 | r7 0 rr7 0 g7 0 gg7 0 | b7 0 bb7 0 a7 0 aa7 0
]

For "block 1" we will have

data_0_ll = [
      r0 0 0 0 g0 0 0 0 | b0 0 0 0 a0 0 0 0 | r1 0 0 0 g1 0 0 0 | b1 0 0 0 a1 0 0 0
]
data_0_lh = [
      r2 0 0 0 g2 0 0 0 | b2 0 0 0 a2 0 0 0 | r3 0 0 0 g3 0 0 0 | b3 0 0 0 a3 0 0 0
]
data_0_hl = [
      r4 0 0 0 g4 0 0 0 | b4 0 0 0 a4 0 0 0 | r5 0 0 0 g5 0 0 0 | b5 0 0 0 a5 0 0 0
]
data_0_hh = [
      r6 0 0 0 g6 0 0 0 | b6 0 0 0 a6 0 0 0 | r7 0 0 0 g7 0 0 0 | b7 0 0 0 a7 0 0 0
]
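
The interleaving of two lines into this pattern can be sketched with unpack instructions, shown here on 128-bit registers for two pixels (the 256-bit version is analogous but operates per 128-bit lane):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Vertical-pass data preparation for two pixels: interleave the bytes of line0
// and line1, then zero-extend to 16 bits. This yields the r0 0 rr0 0 g0 0 gg0 0 ...
// pattern that a 16-bit multiply-add can consume together with (w0, w1) pairs.
int main() {
    uint8_t line0[8] = {0, 1, 2, 255, 3, 4, 5, 255};        // pixels 0, 1 of line0
    uint8_t line1[8] = {27, 28, 29, 255, 30, 31, 32, 255};  // pixels 0, 1 of line1

    __m128i l0 = _mm_loadl_epi64((const __m128i*) line0);
    __m128i l1 = _mm_loadl_epi64((const __m128i*) line1);

    __m128i inter = _mm_unpacklo_epi8(l0, l1);                      // r0 rr0 g0 gg0 ... a1 aa1
    __m128i pix0  = _mm_unpacklo_epi8(inter, _mm_setzero_si128());  // r0 0 rr0 0 g0 0 gg0 0 ...
    __m128i pix1  = _mm_unpackhi_epi8(inter, _mm_setzero_si128());  // r1 0 rr1 0 g1 0 gg1 0 ...
    (void) pix1;

    uint8_t out[16];
    _mm_storeu_si128((__m128i*) out, pix0);
    for (int i = 0; i < 16; i++)
        std::printf("%d ", out[i]);  // 0 0 27 0 1 0 28 0 2 0 29 0 255 0 255 0
    std::printf("\n");
    return 0;
}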
  5. Multiply and add the source data and weights using 32-bit integer precision. 32-bit integer precision means each output value takes 4 placeholders (a b c d).
# w01 = [
#       wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 | ... | ... | wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1
# ]

out01_ll = data_01_ll * w01

out01_ll = [
      (r0 0) * (wl_0 wh_0) + (rr0 0) * (wl_1 wh_1), (g0 0) * (wl_0 wh_0) + (gg0 0) * (wl_1 wh_1) |
      (b0 0) * (wl_0 wh_0) + (bb0 0) * (wl_1 wh_1), (a0 0) * (wl_0 wh_0) + (aa0 0) * (wl_1 wh_1) |
      (r1 0) * (wl_0 wh_0) + (rr1 0) * (wl_1 wh_1), (g1 0) * (wl_0 wh_0) + (gg1 0) * (wl_1 wh_1) |
      (b1 0) * (wl_0 wh_0) + (bb1 0) * (wl_1 wh_1), (a1 0) * (wl_0 wh_0) + (aa1 0) * (wl_1 wh_1)
]

where (pi 0) * (wl_j wh_j) + (ppi 0) * (wl_n wh_n) = (out_0, out_1, out_2, out_3).

out01_lh = data_01_lh * w01

out01_lh = [
      (r2 0) * (wl_0 wh_0) + (rr2 0) * (wl_1 wh_1), (g2 0) * (wl_0 wh_0) + (gg2 0) * (wl_1 wh_1) |
      (b2 0) * (wl_0 wh_0) + (bb2 0) * (wl_1 wh_1), (a2 0) * (wl_0 wh_0) + (aa2 0) * (wl_1 wh_1) |
      (r3 0) * (wl_0 wh_0) + (rr3 0) * (wl_1 wh_1), (g3 0) * (wl_0 wh_0) + (gg3 0) * (wl_1 wh_1) |
      (b3 0) * (wl_0 wh_0) + (bb3 0) * (wl_1 wh_1), (a3 0) * (wl_0 wh_0) + (aa3 0) * (wl_1 wh_1)
]


out01_hl = ...
out01_hh = ...

For "block 1" we will have

out0_ll = [
      (r0 0) * (wl_0 wh_0), (g0 0) * (wl_0 wh_0) |
      (b0 0) * (wl_0 wh_0), (a0 0) * (wl_0 wh_0) |
      (r1 0) * (wl_0 wh_0), (g1 0) * (wl_0 wh_0) |
      (b1 0) * (wl_0 wh_0), (a1 0) * (wl_0 wh_0)
]

out0_lh = ...
out0_hl = ...
out0_hh = ...

Here each element like (r0 0) * (wl_0 wh_0) represents an int32 and takes 4 placeholders.

  6. Shift back the output integer values (output = (output >> weights_precision)):
out01_ll = out01_ll >> weights_precision
out01_lh = out01_lh >> weights_precision
out01_hl = out01_hl >> weights_precision
out01_hh = out01_hh >> weights_precision
  7. Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation:
(a a a a b b b b | c c c c d d d d) -> (a' a' b' b' c' c' d' d')

(out01_ll, out01_lh) -> out_01_l
(out01_hl, out01_hh) -> out_01_h
  8. Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation:
(a a b b | c c d d) -> (a' b' c' d')

(out01_l, out01_h) -> out_01
  9. Write the output into a single uint32:
(a b c d) -> x_uint32

WIP on Vectorized interpolation

  • Install Pillow-SIMD
pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd

Run benchmarks: nightly vs PR

wget https://raw.githubusercontent.com/pytorch/vision/main/torchvision/transforms/functional_tensor.py -O torchvision_functional_tensor.py
python -u run_bench_interp.py "output/$(date "+%Y%m%d-%H%M%S")-pr.pkl" --tag=PR

Output consistency with master pytorch

# On pytorch-nightly
python verif_interp2.py verif_expected --is_ref=True

# On PR
python verif_interp2.py verif_expected --is_ref=False

Some results

08/02/2023


PIL version:  9.0.0.post1
[--------------------------------------------------------------------- Resize --------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+gite6bdca1) PR  |   torchvision resize
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |    38.074 (+-0.541)    |        145.393 (+-1.645)        |   368.952 (+-1.475)
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |        112.422 (+-0.668)        |    74.104 (+-0.135)
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True     |   112.459 (+-0.690)    |        496.323 (+-3.041)        |   1560.163 (+-5.492)
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False    |                        |        363.923 (+-1.618)        |   186.801 (+-0.461)
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |   184.901 (+-1.414)    |        890.993 (+-2.717)        |  2949.707 (+-95.723)
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |        647.951 (+-4.457)        |   318.293 (+-0.674)
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=True    |   139.299 (+-0.729)    |        329.859 (+-1.688)        |   1242.137 (+-4.737)
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=False   |                        |        307.541 (+-2.314)        |   908.356 (+-1.292)
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |                        |         67.892 (+-0.188)        |   473.954 (+-0.875)
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |         34.854 (+-0.124)        |    87.333 (+-0.212)
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=True     |                        |        188.218 (+-1.294)        |   2064.724 (+-7.161)
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=False    |                        |         55.389 (+-0.175)        |   238.161 (+-0.517)
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |                        |        316.895 (+-1.609)        |  3929.540 (+-11.386)
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |         73.027 (+-0.227)        |   424.204 (+-1.261)
      4 torch.uint8 channels_last bilinear 270 -> 224 aa=True    |                        |        166.030 (+-0.901)        |   1489.629 (+-5.092)
      4 torch.uint8 channels_last bilinear 270 -> 224 aa=False   |                        |        143.489 (+-0.757)        |   992.293 (+-1.604)
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True    |    37.804 (+-0.178)    |        145.445 (+-0.874)        |   355.802 (+-1.438)
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False   |                        |        112.183 (+-0.704)        |   203.691 (+-0.742)
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True    |   112.137 (+-0.763)    |        496.563 (+-11.028)       |  1549.939 (+-10.290)
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False   |                        |        364.179 (+-2.418)        |   678.691 (+-2.422)
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |   184.557 (+-1.122)    |        891.174 (+-3.050)        |  2930.927 (+-11.987)
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |        647.634 (+-1.752)        |  1287.492 (+-877.768)
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=True   |   139.091 (+-1.009)    |        329.487 (+-1.858)        |   818.238 (+-1.593)
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=False  |                        |        308.697 (+-2.485)        |   1209.505 (+-4.367)
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=True    |                        |         87.350 (+-0.238)        |   460.749 (+-1.384)
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=False   |                        |         53.891 (+-0.200)        |   252.033 (+-0.734)
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=True    |                        |        257.106 (+-1.468)        |   2052.175 (+-6.119)
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=False   |                        |        124.054 (+-0.658)        |  909.929 (+-272.601)
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |                        |        442.343 (+-2.452)        |  3904.617 (+-11.139)
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |        199.340 (+-1.232)        |  4785.501 (+-106.702)
      4 torch.uint8 channels_first bilinear 270 -> 224 aa=True   |                        |        285.454 (+-3.306)        |   1073.443 (+-6.514)
      4 torch.uint8 channels_first bilinear 270 -> 224 aa=False  |                        |        264.808 (+-3.390)        |   1157.429 (+-4.480)

Times are in microseconds (us).

07/02/2023

Num threads: 1

PIL version:  9.0.0.post1
[-------------------------------------------------------------------- Resize --------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+gite6bdca1) PR  |   torchvision resize
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |    38.945 (+-0.222)    |        130.363 (+-0.612)        |   364.956 (+-2.782)
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |        108.715 (+-0.305)        |    72.821 (+-0.245)
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True     |   112.800 (+-0.394)    |        439.170 (+-0.637)        |   1596.292 (+-2.404)
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False    |                        |        360.557 (+-0.442)        |   185.144 (+-0.231)
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |   186.025 (+-0.873)    |        781.784 (+-4.723)        |   2941.418 (+-3.263)
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |        643.826 (+-2.035)        |   316.546 (+-0.371)
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=True    |   139.784 (+-0.302)    |        319.836 (+-1.226)        |   1238.219 (+-1.816)
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=False   |                        |        297.607 (+-3.890)        |   908.446 (+-1.849)
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |                        |         52.814 (+-0.490)        |   470.149 (+-7.399)
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |         31.900 (+-0.115)        |    86.144 (+-0.203)
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=True     |                        |        131.809 (+-1.086)        |   2099.700 (+-3.938)
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=False    |                        |         52.489 (+-0.074)        |   236.924 (+-0.330)
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |                        |        207.632 (+-1.031)        |   3934.327 (+-4.734)
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |         69.291 (+-0.169)        |   422.172 (+-0.420)
      4 torch.uint8 channels_last bilinear 270 -> 224 aa=True    |                        |        149.362 (+-0.545)        |   1484.460 (+-3.488)
      4 torch.uint8 channels_last bilinear 270 -> 224 aa=False   |                        |        127.503 (+-0.296)        |   992.280 (+-1.658)
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True    |    38.934 (+-0.066)    |        130.096 (+-0.442)        |   352.259 (+-0.315)
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False   |                        |        108.973 (+-0.755)        |   201.381 (+-0.268)
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True    |   112.429 (+-0.337)    |        439.524 (+-2.111)        |   1582.000 (+-2.005)
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False   |                        |        360.747 (+-0.596)        |   707.031 (+-4.563)
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |   186.654 (+-0.514)    |        781.276 (+-2.859)        |   2929.456 (+-7.766)
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |        643.654 (+-2.010)        |  1436.077 (+-45.418)
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=True   |   140.697 (+-0.760)    |        318.995 (+-1.630)        |   814.211 (+-2.972)
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=False  |                        |        295.188 (+-1.540)        |   1208.328 (+-1.880)
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=True    |                        |         71.006 (+-0.246)        |   456.845 (+-1.242)
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=False   |                        |         50.860 (+-0.104)        |   249.063 (+-0.477)
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=True    |                        |        199.646 (+-0.859)        |   2091.296 (+-2.892)
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=False   |                        |        120.245 (+-0.589)        |   950.757 (+-17.017)
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |                        |        330.196 (+-1.010)        |   3908.203 (+-4.160)
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |        194.640 (+-0.230)        |   4760.218 (+-9.555)
      4 torch.uint8 channels_first bilinear 270 -> 224 aa=True   |                        |        266.951 (+-1.631)        |   1069.409 (+-4.428)
      4 torch.uint8 channels_first bilinear 270 -> 224 aa=False  |                        |        243.640 (+-0.885)        |   1163.805 (+-1.581)

Times are in microseconds (us).

02/02/2023

  • Removed pointer from upsample_avx_bilinear
  • Avoid copy if num_channels=4 and channels_last
Num threads: 1

PIL version:  9.0.0.post1
[----------------------------------------------------------------------------- Resize ----------------------------------------------------------------------------]
                                                                |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |          38.6          |               395.9             |                   360.7
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |               363.4             |                    68.2
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |         112.5          |              1530.6             |                  1555.5
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |              1369.7             |                   179.9
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |         186.0          |              2652.6             |                  2935.8
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |              2507.9             |                   309.7
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |                        |                57.1             |                   466.0
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |                37.6             |                    81.1
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |                        |               131.6             |                  2093.5
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |                58.0             |                   231.4
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |                        |               204.7             |                  3926.6
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |                74.7             |                   418.0
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |          38.7          |               397.9             |                   348.7
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               361.9             |                   197.9
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |         112.2          |              1448.7             |                  1540.7
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |              1388.4             |                  1493.0
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |         186.0          |              2633.9             |                  2923.4
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              2585.5             |                  1271.8
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |                        |               208.1             |                   453.3
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               188.8             |                   245.1
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |                        |               748.0             |                  2043.2
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |               673.7             |                   864.4
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |                        |              1362.0             |                  3897.1
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              1230.4             |                  1837.6

Times are in microseconds (us).
  • recoded unpack rgb method
Num threads: 1

PIL version:  9.0.0.post1
[----------------------------------------------------------------------------- Resize ----------------------------------------------------------------------------]
                                                                |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |          38.9          |              132.6              |                   360.5
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |              111.6              |                    68.2
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |         113.7          |              443.0              |                  1554.3
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |              362.6              |                   180.1
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |         187.5          |              784.7              |                  2904.1
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |              645.3              |                   309.0
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |                        |               55.1              |                   464.7
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |               33.5              |                    80.9
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |                        |              135.8              |                  2065.8
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |               55.1              |                   231.5
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |                        |              209.8              |                  3873.7
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |               71.3              |                   411.1
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |          39.2          |              132.5              |                   348.6
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |              111.9              |                   199.4
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |         112.7          |              439.8              |                  1542.1
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |              362.2              |                  1569.3
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |         185.4          |              779.8              |                  2888.7
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              645.0              |                  1440.4
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |                        |               73.9              |                   453.4
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               53.5              |                   245.7
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |                        |              200.3              |                  2041.5
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |              122.8              |                   933.2
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |                        |              331.4              |                  3852.1
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              197.1              |                  2046.7

Times are in microseconds (us).

01/02/2023

  • Use tensor to allocate memory
  • Compute weights once
cd /tmp/pth/interpolate_vec_uint8/ && python -u check_interp.py
Torch version: 2.0.0a0+git7f72623
Torch config: PyTorch built with:
  - GCC 9.4
  - C++ Version: 201703
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - Build settings: BUILD_TYPE=Release, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=0, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=0, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 1

PIL version:  9.0.0.post1
[----------------------------------------------------------------------------- Resize ----------------------------------------------------------------------------]
                                                                |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |          38.6          |               346.7             |                   361.0
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |               327.8             |                    67.5
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |         112.1          |              1321.2             |                  1553.2
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |              1248.8             |                   179.6
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |         184.9          |              2429.3             |                  2910.2
      3 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |              2306.4             |                   309.6
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |                        |               208.1             |                   466.6
      4 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |               189.6             |                    80.7
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |                        |               744.2             |                  2053.8
      4 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |               674.7             |                   231.1
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True    |                        |              1359.0             |                  3886.2
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False   |                        |              1230.8             |                   412.0
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |          38.3          |               346.6             |                   349.3
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               328.0             |                   196.9
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |         112.3          |              1321.6             |                  1538.2
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |              1249.3             |                  1515.5
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |         185.0          |              2435.4             |                  2887.2
      3 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              2312.7             |                  2556.5
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |                        |               209.3             |                   453.2
      4 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               190.6             |                   244.8
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |                        |               745.8             |                  2278.6
      4 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |               730.4             |                  1360.3
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True   |                        |              1480.8             |                  4110.3
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False  |                        |              1311.5             |                  3714.8

Times are in microseconds (us).

30/01/2023 (Repro current results)

cd /tmp/pth/interpolate_vec_uint8/ && python -u check_interp.py
Torch version: 2.0.0a0+git7f72623Torch config: PyTorch built with:  - GCC 9.4  - C++ Version: 201703  - OpenMP 201511 (a.k.a. OpenMP 4.5)  - CPU capability usage: AVX2  - Build settings: BUILD_TYPE=Release, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=0, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=0, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 1

PIL version:  9.0.0.post1
[----------------------------------------------------------------------------- Resize ----------------------------------------------------------------------------]
                                                                |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True    |          41.3          |               395.4             |                   377.4
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False   |                        |               368.4             |                    67.8
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=True    |         112.2          |              1456.1             |                  1557.2
      3 torch.uint8 channels_last bilinear 520 -> 32 aa=False   |                        |              1372.7             |                   180.1
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=True   |          38.5          |               372.0             |                   349.0
      3 torch.uint8 channels_first bilinear 256 -> 32 aa=False  |                        |               356.8             |                   196.9
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=True   |         112.8          |              1449.1             |                  1543.7
      3 torch.uint8 channels_first bilinear 520 -> 32 aa=False  |                        |              1379.2             |                  1306.1

Times are in microseconds (us).
Num threads: 1

PIL version:  9.0.0.post1
[----------------------------------------------------------------------------- Resize -----------------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=True    |         148.4          |              628.7              |                  1269.7
      3 torch.uint8 channels_last bilinear 270 -> 224 aa=False   |                        |              608.8              |                   917.8
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=True   |         149.7          |              598.4              |                   772.5
      3 torch.uint8 channels_first bilinear 270 -> 224 aa=False  |                        |              569.8              |                  1300.6

Times are in microseconds (us).
Num threads: 1

PIL version:  9.4.0
check_interp.py:93: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  expected_pil = torch.from_numpy(np.asarray(output_pil_img)).clone().permute(2, 0, 1).contiguous()
[------------------------------------------------------------------------- Resize ------------------------------------------------------------------------]
                                                              |  Pillow (9.4.0)  |  torch (2.0.0a0+git7f72623) PR  |  torch (2.0.0a0+git7f72623) PR (float)
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True  |      223.3       |              382.2              |                  361.8

Times are in microseconds (us).

uint8 -> float -> resize -> uint8

Num threads: 1
[--------- Downsampling: torch.Size([3, 438, 906]) -> (320, 196) ---------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git943acd4
1 threads: ----------------------------------------------------------------
      channels_first contiguous  |       345.0       |         2530.7

Times are in microseconds (us).

[--------- Downsampling: torch.Size([3, 438, 906]) -> (460, 220) ---------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git943acd4
1 threads: ----------------------------------------------------------------
      channels_first contiguous  |       412.8       |         2947.4

Times are in microseconds (us).

[---------- Downsampling: torch.Size([3, 438, 906]) -> (120, 96) ---------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git943acd4
1 threads: ----------------------------------------------------------------
      channels_first contiguous  |       214.0       |         2124.6

Times are in microseconds (us).

[--------- Downsampling: torch.Size([3, 438, 906]) -> (1200, 196) --------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git943acd4
1 threads: ----------------------------------------------------------------
      channels_first contiguous  |       911.3       |         7560.4

Times are in microseconds (us).

[--------- Downsampling: torch.Size([3, 438, 906]) -> (120, 1200) --------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git943acd4
1 threads: ----------------------------------------------------------------
      channels_first contiguous  |       291.0       |         2700.7

Times are in microseconds (us).

30/11/2022 - fallback uint8 implementation

Num threads: 1
[---------------------------------- Downsampling: torch.Size([3, 438, 906]) -> (320, 196) ----------------------------------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git7a3055e, using uint8  |  1.14.0a0+git7a3055e, using float
1 threads: ------------------------------------------------------------------------------------------------------------------
      channels_first contiguous  |       348.8       |               3315.0               |               2578.3

Times are in microseconds (us).

[---------------------------------- Downsampling: torch.Size([3, 438, 906]) -> (460, 220) ----------------------------------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git7a3055e, using uint8  |  1.14.0a0+git7a3055e, using float
1 threads: ------------------------------------------------------------------------------------------------------------------
      channels_first contiguous  |       412.5       |               4231.5               |               3004.9

Times are in microseconds (us).

[----------------------------------- Downsampling: torch.Size([3, 438, 906]) -> (120, 96) ----------------------------------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git7a3055e, using uint8  |  1.14.0a0+git7a3055e, using float
1 threads: ------------------------------------------------------------------------------------------------------------------
      channels_first contiguous  |       216.4       |               1818.1               |               2286.3

Times are in microseconds (us).

[---------------------------------- Downsampling: torch.Size([3, 438, 906]) -> (1200, 196) ---------------------------------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git7a3055e, using uint8  |  1.14.0a0+git7a3055e, using float
1 threads: ------------------------------------------------------------------------------------------------------------------
      channels_first contiguous  |       907.3       |               9095.5               |               5861.1

Times are in microseconds (us).

[---------------------------------- Downsampling: torch.Size([3, 438, 906]) -> (120, 1200) ---------------------------------]
                                 |  PIL 9.0.0.post1  |  1.14.0a0+git7a3055e, using uint8  |  1.14.0a0+git7a3055e, using float
1 threads: ------------------------------------------------------------------------------------------------------------------
      channels_first contiguous  |       298.1       |               2753.7               |               2865.6

Times are in microseconds (us).
PIL version:  9.0.0.post1
[-------------------- Resize measurements ---------------------]
                                |  Pillow image  |  torch tensor
1 threads: -----------------------------------------------------
      1200 -> 256, torch.uint8  |      57.0      |     295.1

Times are in microseconds (us).
import pickle
from typing import List
from pathlib import Path
import unittest.mock

import numpy as np
import PIL.Image
import torch
import torch.utils.benchmark as benchmark
import fire

from torchvision_functional_tensor import resize


def pth_downsample_i8(img, mode, size, aa=True):
    align_corners = False
    if mode == "nearest":
        align_corners = None
    out = torch.nn.functional.interpolate(
        img, size=size,
        mode=mode,
        align_corners=align_corners,
        antialias=aa,
    )
    return out


def torchvision_resize(img, mode, size, aa=True):
    return resize(img, size=size, interpolation=mode, antialias=aa)


if not hasattr(PIL.Image, "Resampling"):
    resampling_map = {
        "bilinear": PIL.Image.BILINEAR,
        "nearest": PIL.Image.NEAREST,
        "bicubic": PIL.Image.BICUBIC,
    }
else:
    resampling_map = {
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "nearest": PIL.Image.Resampling.NEAREST,
        "bicubic": PIL.Image.Resampling.BICUBIC,
    }


# Patch for torch.utils.benchmark's table rendering: show median (+- std) per cell.
def patched_as_column_strings(self):
    concrete_results = [r for r in self._results if r is not None]
    env = f"({concrete_results[0].env})" if self._render_env else ""
    env = env.ljust(self._env_str_len + 4)
    output = [" " + env + concrete_results[0].as_row_name]
    for m, col in zip(self._results, self._columns or ()):
        if m is None:
            output.append(col.num_to_str(None, 1, None))
        else:
            if len(m.times) == 1:
                spread = 0
            else:
                spread = float(torch.tensor(m.times, dtype=torch.float64).std(unbiased=len(m.times) > 1))
                if col._trim_significant_figures:
                    spread = benchmark.utils.common.trim_sigfig(spread, m.significant_figures)
            output.append(f"{m.median / self._time_scale:>3.3f} (+-{spread / self._time_scale:>3.3f})")
    return output


# Build the input tensor, check consistency vs PIL for 3-channel uint8 inputs,
# then time PIL, torch interpolate and torchvision resize.
def run_benchmark(c, dtype, size, osize, aa, mode, mf="channels_first", min_run_time=10, tag=""):
    results = []
    torch.manual_seed(12)
    if dtype == torch.bool:
        tensor = torch.randint(0, 2, size=(c, size, size), dtype=dtype)
    elif dtype == torch.complex64:
        real = torch.randint(0, 256, size=(c, size, size), dtype=torch.float32)
        imag = torch.randint(0, 256, size=(c, size, size), dtype=torch.float32)
        tensor = torch.complex(real, imag)
    elif dtype == torch.int8:
        tensor = torch.randint(-127, 127, size=(c, size, size), dtype=dtype)
    else:
        tensor = torch.randint(0, 256, size=(c, size, size), dtype=dtype)

    expected_pil = None
    pil_img = None
    if dtype == torch.uint8 and c == 3 and aa:
        np_array = tensor.clone().permute(1, 2, 0).contiguous().numpy()
        pil_img = PIL.Image.fromarray(np_array)
        output_pil_img = pil_img.resize((osize, osize), resample=resampling_map[mode])
        expected_pil = torch.from_numpy(np.asarray(output_pil_img)).clone().permute(2, 0, 1).contiguous()

    memory_format = torch.channels_last if mf == "channels_last" else torch.contiguous_format
    tensor = tensor[None, ...].contiguous(memory_format=memory_format)

    output = pth_downsample_i8(tensor, mode=mode, size=(osize, osize), aa=aa)
    output = output[0, ...]

    if expected_pil is not None:
        abs_diff = torch.abs(expected_pil.float() - output.float())
        mae = torch.mean(abs_diff)
        max_abs_err = torch.max(abs_diff)
        if mode == "bilinear":
            assert mae.item() < 1.0, mae.item()
            assert max_abs_err.item() < 1.0 + 1e-5, max_abs_err.item()
        else:
            raise RuntimeError(f"Unsupported mode: {mode}")

    # PIL
    if pil_img is not None:
        results.append(
            benchmark.Timer(
                # pil_img = pil_img.resize((osize, osize), resample=resampling_map[mode])
                stmt=f"data.resize(({osize}, {osize}), resample=resample_val)",
                globals={
                    "data": pil_img,
                    "resample_val": resampling_map[mode],
                },
                num_threads=torch.get_num_threads(),
                label="Resize",
                sub_label=f"{c} {dtype} {mf} {mode} {size} -> {osize} aa={aa}",
                description=f"Pillow ({PIL.__version__})",
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Tensor interp
    results.append(
        benchmark.Timer(
            # output = pth_downsample_i8(tensor, mode=mode, size=(osize, osize), aa=aa)
            stmt=f"fn(data, mode='{mode}', size=({osize}, {osize}), aa={aa})",
            globals={
                "data": tensor,
                "fn": pth_downsample_i8,
            },
            num_threads=torch.get_num_threads(),
            label="Resize",
            sub_label=f"{c} {dtype} {mf} {mode} {size} -> {osize} aa={aa}",
            description=f"torch ({torch.__version__}) {tag}",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Torchvision resize
    results.append(
        benchmark.Timer(
            # output = torchvision_resize(tensor, mode=mode, size=(osize, osize), aa=aa)
            stmt=f"fn(data, mode='{mode}', size=({osize}, {osize}), aa={aa})",
            globals={
                "data": tensor,
                "fn": torchvision_resize,
            },
            num_threads=torch.get_num_threads(),
            label="Resize",
            sub_label=f"{c} {dtype} {mf} {mode} {size} -> {osize} aa={aa}",
            description="torchvision resize",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    return results


def main(
    output_filepath: str,
    min_run_time: int = 10,
    tag: str = "",
    display: bool = True,
):
    output_filepath = Path(output_filepath)
    test_results = []
    for mf in ["channels_last", "channels_first"]:
        for c, dtype in [
            (3, torch.uint8),
            (4, torch.uint8),
        ]:
            for size in [256, 520, 712, 270]:
                if size == 270:
                    osize_aa_mode_list = [
                        (224, True, "bilinear"),
                        (224, False, "bilinear"),
                    ]
                else:
                    osize_aa_mode_list = [
                        (32, True, "bilinear"),
                        (32, False, "bilinear"),
                    ]
                for osize, aa, mode in osize_aa_mode_list:
                    test_results += run_benchmark(
                        c=c, dtype=dtype, size=size,
                        osize=osize, aa=aa, mode=mode, mf=mf,
                        min_run_time=min_run_time, tag=tag,
                    )

    with open(output_filepath, "wb") as handler:
        output = {
            "torch_version": torch.__version__,
            "torch_config": torch.__config__.show(),
            "num_threads": torch.get_num_threads(),
            "pil_version": PIL.__version__,
            "test_results": test_results,
        }
        pickle.dump(output, handler)

    if display:
        with unittest.mock.patch(
            "torch.utils.benchmark.utils.compare._Row.as_column_strings", patched_as_column_strings
        ):
            compare = benchmark.Compare(test_results)
            compare.print()


if __name__ == "__main__":
    torch.set_num_threads(1)
    from datetime import datetime

    print(f"Timestamp: {datetime.now().strftime('%Y%m%d-%H%M%S')}")
    print(f"Torch version: {torch.__version__}")
    print(f"Torch config: {torch.__config__.show()}")
    print(f"Num threads: {torch.get_num_threads()}")
    print("")
    print("PIL version: ", PIL.__version__)

    fire.Fire(main)