Skip to content

Instantly share code, notes, and snippets.

@VOIDMalkuth
Created December 14, 2023 15:32
Show Gist options
  • Save VOIDMalkuth/c03022f63354f8195573a1f280a5b129 to your computer and use it in GitHub Desktop.
Save VOIDMalkuth/c03022f63354f8195573a1f280a5b129 to your computer and use it in GitHub Desktop.
Proposal Implementation for Paddle Slice Scatter
from paddle.base import core
from paddle.base.framework import Variable, default_main_program
from paddle.base.data_feeder import (
check_dtype,
check_type,
check_variable_and_dtype,
convert_dtype,
)
from paddle import _C_ops
from paddle.framework import (
LayerHelper,
convert_np_dtype_to_dtype_,
core,
dygraph_only,
in_dynamic_mode,
in_dynamic_or_pir_mode,
in_pir_mode,
)
def _is_list_or_tuple_of(index, typenames, name, op_name)
if not isinstance(index, (tuple, list)):
raise TypeError(
"The type of '{}' in {} must be tuple or list of {}, but received {}.".format(
name,
op_name,
typenames,
type(index)
)
)
for s in index:
valid = False
for t in typenames:
if isinstance(s, typename):
return False
return True
def slice_scatter(x, y, axes, starts, ends, steps):
_is_list_or_tuple_of(axes, int, "axes", "slice_scatter")
_is_list_or_tuple_of(starts, int, "starts", "slice_scatter")
_is_list_or_tuple_of(ends, int, "ends", "slice_scatter")
_is_list_or_tuple_of(steps, int, "steps", "slice_scatter")
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': [],
'none_axes': [],
}
out = paddle.clone(x)
if in_dynamic_or_pir_mode():
return _C_ops.legacy.set_value(out, y, None, None, None, "starts", start, "ends", stop, "steps", step, "axes", axis)
else:
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'uint8',
'int8',
'int16',
'int32',
'int64',
'complex64',
'complex128',
'bool',
],
'slice_scatter',
)
check_variable_and_dtype(
y,
'y',
[
convert_dtype(x.dtype),
],
'slice_scatter',
)
check_type(axis, 'axis', (int,), 'slice_scatter')
check_type(start, 'start', (int,), 'slice_scatter')
check_type(stop, 'stop', (int,), 'slice_scatter')
check_type(step, 'step', (int,), 'slice_scatter')
if step == 0:
raise ValueError(
f"Step should not be 0, but received step = {step}."
)
helper = LayerHelper("slice_scatter", **locals())
inputs = {'Input': x, 'ValueTensor': y}
attrs = {
'axes': [axis],
'starts': [start],
'ends': [stop],
'steps': [step],
'dtype': x.dtype
}
helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': out},
attrs=attrs,
)
return out
import paddle
if __name__ == "__main__":
    # Static-graph smoke run: write a [3, 9] tensor of 3s into indices
    # 1, 3, 5 (slice 1:7:2) along axis 1 of a [7, 8, 9] tensor of 9s.
    paddle.enable_static()
    demo_program = paddle.static.Program()
    with paddle.static.program_guard(demo_program):
        base = paddle.full([7, 8, 9], 9)
        source = paddle.full([3, 9], 3)
        scattered = slice_scatter(base, source, 1, 1, 7, 2)
    executor = paddle.static.Executor()
    fetched = executor.run(demo_program, fetch_list=[scattered])
    paddle.disable_static()
    print(fetched)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# start out of place test
import sys
sys.path.append("/workspace/Paddle/test/legacy_test")
sys.path.append("/workspace/Paddle/test")
from slice_scatter import slice_scatter
# end out of place test
import unittest
from functools import reduce
import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.base import core
from paddle.base.layer_helper import LayerHelper
class TestSliceScatterBase(unittest.TestCase):
    """Shared fixture for slice_scatter tests.

    Subclasses customise a case through the ``set_dtype``, ``set_param`` and
    ``set_shape`` hooks, and may override ``_get_answer``.
    """

    def setUp(self):
        paddle.enable_static()
        self.set_dtype()
        self.set_param()
        self.set_shape()
        self.x_fill = 42
        self.y_fill = -3
        self.x_data = np.full(self.x_shape, self.x_fill).astype(self.x_dtype)
        self.y_data = np.full(self.y_shape, self.y_fill).astype(self.y_dtype)
        self.program = paddle.static.Program()

    def set_dtype(self):
        self.x_dtype = "float32"
        self.y_dtype = "float32"

    def set_shape(self):
        self.x_shape = [7, 8, 9]
        self.y_shape = [3, 9]

    def set_param(self):
        self.axis = 1
        self.start = 1
        self.stop = 4
        self.step = 1

    def _call_slice_scatter(self, x, y):
        return slice_scatter(x, y, self.axis, self.start, self.stop, self.step)

    def _get_answer(self):
        """Write the expected result into ``self.x_data`` using NumPy.

        The index is derived from ``axis``/``start``/``stop``/``step`` so it
        always agrees with the parameters passed to slice_scatter.
        """
        index = [slice(None)] * self.x_data.ndim
        index[self.axis] = slice(self.start, self.stop, self.step)
        self.x_data[tuple(index)] = self.y_data
class TestSliceScatterApi(TestSliceScatterBase):
    """Compare slice_scatter with NumPy in static and dynamic mode."""

    def _run_static(self):
        """Build and run the op in ``self.program``; return the output array."""
        paddle.enable_static()
        try:
            with paddle.static.program_guard(self.program):
                x = paddle.full(self.x_shape, self.x_fill, dtype=self.x_dtype)
                y = paddle.full(self.y_shape, self.y_fill, dtype=self.y_dtype)
                z = self._call_slice_scatter(x, y)
            exe = paddle.static.Executor(paddle.CPUPlace())
            # exe.run returns one array per fetch target; unwrap the single
            # result so the comparison below also checks the output shape.
            (out,) = exe.run(self.program, fetch_list=[z])
        finally:
            paddle.disable_static()
        return out

    def _run_dynamic(self):
        """Run the op eagerly and return the output as a NumPy array."""
        paddle.disable_static()
        try:
            x = paddle.full(self.x_shape, self.x_fill, dtype=self.x_dtype)
            y = paddle.full(self.y_shape, self.y_fill, dtype=self.y_dtype)
            return self._call_slice_scatter(x, y).numpy()
        finally:
            paddle.enable_static()

    def test_api(self):
        static_out = self._run_static()
        dynamic_out = self._run_dynamic()
        self._get_answer()
        error_msg = (
            "\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}"
        )
        # np.array_equal also requires equal shapes, unlike a broadcast ==.
        self.assertTrue(
            np.array_equal(self.x_data, static_out),
            msg=error_msg.format("static", self.x_data, static_out),
        )
        self.assertTrue(
            np.array_equal(self.x_data, dynamic_out),
            msg=error_msg.format("dynamic", self.x_data, dynamic_out),
        )
# test step > 1
class TestSliceScatterStep1(TestSliceScatterApi):
    """Positive stride of 2 along axis 1: ``x[:, 1:7:2, :] = y``."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 1, 7, 2

    def _get_answer(self):
        selected = slice(1, 7, 2)
        self.x_data[:, selected, :] = self.y_data
# test step = -1
class TestSliceScatterStep2(TestSliceScatterApi):
    """Reverse stride of 1 along axis 1: ``x[:, 4:1:-1, :] = y``."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 4, 1, -1

    def _get_answer(self):
        selected = slice(4, 1, -1)
        self.x_data[:, selected, :] = self.y_data
# test step = -2
class TestSliceScatterStep3(TestSliceScatterApi):
    """Reverse stride of 2 along axis 1: ``x[:, 7:1:-2, :] = y``."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 7, 1, -2

    def _get_answer(self):
        selected = slice(7, 1, -2)
        self.x_data[:, selected, :] = self.y_data
# test start is none
class TestSetValueStartIsNone(TestSliceScatterApi):
    """Omitted start along axis 1; expected to behave like ``x[:, 0:3, :]``."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, None, 3, 1

    def _get_answer(self):
        selected = slice(0, 3)
        self.x_data[:, selected, :] = self.y_data
# test stop is none
class TestSetValueStopIsNone(TestSliceScatterApi):
    """Omitted stop along axis 1; expected to behave like ``x[:, 5:8, :]``."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 5, None, 1

    def _get_answer(self):
        selected = slice(5, 8)
        self.x_data[:, selected, :] = self.y_data
# test different dtypes
def create_test_value_dtype(parent, dtype):
    """Create and register a subclass of ``parent`` that runs with ``dtype``.

    The subclass overrides ``set_dtype`` so that both ``x`` and ``y`` use
    ``dtype``. It is named ``<Parent>_<Dtype>`` and stored in this module's
    globals so unittest discovery picks it up.

    Returns:
        The generated class.
    """

    class TestValueDtype(parent):
        def set_dtype(self):
            self.x_dtype = dtype
            self.y_dtype = dtype

    cls_name = f"{parent.__name__}_{dtype.capitalize()}"
    TestValueDtype.__name__ = cls_name
    # unittest builds test ids from __qualname__; without this every
    # generated class is reported as
    # "create_test_value_dtype.<locals>.TestValueDtype".
    TestValueDtype.__qualname__ = cls_name
    globals()[cls_name] = TestValueDtype
    return TestValueDtype
def create_slice_scatter_test_value_dtype(dtype):
    """Register a ``dtype`` variant of the base and strided API test cases."""
    for parent in (
        TestSliceScatterApi,
        TestSliceScatterStep1,
        TestSliceScatterStep2,
        TestSliceScatterStep3,
    ):
        create_test_value_dtype(parent, dtype)
# Generate one set of API test classes per supported dtype, in this order.
for _dtype in (
    "bool",
    "int32",
    "int64",
    "float32",
    "float64",
    "float16",
    "complex64",
    "complex128",
):
    create_slice_scatter_test_value_dtype(_dtype)
del _dtype
# Test error
class TestError(TestSliceScatterBase):
    """Error handling of slice_scatter in static and dynamic mode.

    Derives from TestSliceScatterBase (fixture and ``_call_slice_scatter``)
    rather than TestSliceScatterApi, so the API's ``test_api`` case is not
    re-run under this class name.
    """

    def _value_type_error(self):
        """A plain Python list for ``y`` must be rejected."""
        error_type = ValueError if paddle.in_dynamic_mode() else TypeError
        with self.assertRaises(error_type):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            self._call_slice_scatter(x, [1])

    def _dtype_error(self):
        """``y`` with a dtype different from ``x`` must be rejected."""
        error_type = ValueError if paddle.in_dynamic_mode() else TypeError
        with self.assertRaises(error_type):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="int64")
            self._call_slice_scatter(x, y)

    def _step_error(self):
        """A step of 0 must raise ValueError."""
        with self.assertRaises(ValueError):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, 1, 1, 4, 0)

    def _shape_mismatch_static(self):
        """Slice 1:5 selects 4 entries on axis 1 but ``y`` provides 3."""
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, 1, 1, 5, 1)
        exe = paddle.static.Executor(paddle.CPUPlace())
        # In static mode the mismatch is only expected to surface at run time.
        with self.assertRaises(ValueError):
            exe.run(program)

    def _shape_mismatch(self):
        """Dynamic-mode counterpart of ``_shape_mismatch_static``."""
        with self.assertRaises(ValueError):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, 1, 1, 5, 1)

    def test_error(self):
        paddle.enable_static()
        with paddle.static.program_guard(self.program):
            self._value_type_error()
            self._dtype_error()
            self._step_error()
            self._shape_mismatch_static()
        paddle.disable_static()
        self._value_type_error()
        self._dtype_error()
        self._step_error()
        self._shape_mismatch()
if __name__ == "__main__":
    # Run every TestCase in this module, including the per-dtype classes
    # generated by create_slice_scatter_test_value_dtype.
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment