Skip to content

Instantly share code, notes, and snippets.

@podgorskiy
Created April 6, 2020 22:42
Show Gist options
  • Save podgorskiy/17d0820e7e282cbfb40371e6b2185b16 to your computer and use it in GitHub Desktop.
Save podgorskiy/17d0820e7e282cbfb40371e6b2185b16 to your computer and use it in GitHub Desktop.
# Copyright 2020 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
def block_process(x, f, block_size=512, padding=32):
    """Apply ``f`` to overlapping tiles of ``x`` and blend the results.

    ``x`` is cut into tiles of at most ``block_size`` x ``block_size`` along
    its last two dimensions, with neighbouring tiles overlapping by at least
    ``padding`` pixels.  ``f`` is called once per tile and every tensor it
    returns is accumulated into a full-size canvas using a weight mask that
    ramps linearly from ``1 / padding`` to ``1`` over the first and last
    ``padding`` rows/columns of the tile.  The canvases are finally divided
    by the accumulated weights, so overlapping regions are a weighted
    average of the tiles that cover them.

    Args:
        x: Input tensor; the last two dimensions are treated as
            (height, width).
        f: Callable taking a tile of ``x`` and returning an indexable
            sequence (tuple or list) of tensors.  Each returned tensor must
            have the same last two dimensions as the tile it was computed
            from; leading dimensions may differ between outputs.
        block_size: Maximum tile edge length in pixels.
        padding: Width in pixels of the overlap / blending ramp.

    Returns:
        A list with one tensor per output of ``f``, each with the leading
        dimensions of that output and spatial size (height, width).  The
        tensors live on the device of the outputs of ``f``.

    Raises:
        ValueError: If ``block_size`` is not positive or ``padding`` is not
            in ``[0, block_size)``.
    """
    if block_size <= 0 or not 0 <= padding < block_size:
        raise ValueError(
            "expected block_size > 0 and 0 <= padding < block_size, got "
            f"block_size={block_size}, padding={padding}"
        )

    height, width = x.shape[-2], x.shape[-1]
    stride = block_size - padding

    tile_h, starts_y = _tile_starts(height, block_size, stride)
    tile_w, starts_x = _tile_starts(width, block_size, stride)

    outputs = None
    counts = None
    mask = None

    # Same visiting order as before: columns in the outer loop, rows inside.
    for offset_x in starts_x:
        for offset_y in starts_y:
            rows = slice(offset_y, offset_y + tile_h)
            cols = slice(offset_x, offset_x + tile_w)
            res = f(x[..., rows, cols])

            if outputs is None:
                # Allocate lazily: shapes, dtypes and device are only known
                # once f has produced its first result.
                first = res[0]
                outputs = [
                    torch.zeros(*r.shape[:-2], height, width,
                                dtype=r.dtype, device=r.device)
                    for r in res
                ]
                # The mask and the weight sum are purely spatial, so they
                # broadcast over any leading (batch/channel) dimensions.
                ramp_y = _edge_ramp(tile_h, padding, first.dtype, first.device)
                ramp_x = _edge_ramp(tile_w, padding, first.dtype, first.device)
                mask = ramp_y[:, None] * ramp_x[None, :]
                counts = torch.zeros(height, width,
                                     dtype=first.dtype, device=first.device)

            counts[rows, cols] += mask
            for canvas, r in zip(outputs, res):
                canvas[..., rows, cols] += r * mask

    # Every pixel is covered by at least one tile and the mask is strictly
    # positive, so counts has no zeros.
    for canvas in outputs:
        canvas /= counts
    return outputs


def _tile_starts(extent, block_size, stride):
    """Return (tile length, list of distinct tile start offsets) for one axis."""
    # An axis shorter than block_size is covered by a single, smaller tile.
    tile = min(block_size, extent)
    last = extent - tile
    # ceil(last / stride) + 1 tiles; the final one is clamped so that it ends
    # exactly at the border, which never duplicates an earlier start.
    count = -(-last // stride) + 1
    return tile, [min(k * stride, last) for k in range(count)]


def _edge_ramp(length, padding, dtype, device):
    """Return a 1-D weight profile that ramps up/down symmetrically at both ends."""
    profile = torch.ones(length, dtype=dtype, device=device)
    n = min(padding, length)
    if n > 0:
        ramp = torch.arange(1, n + 1, dtype=dtype, device=device) / padding
        profile[:n] *= ramp
        # Mirror image of the leading ramp: the last element gets 1 / padding.
        profile[length - n:] *= ramp.flip(0)
    return profile
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment