Created
December 14, 2023 15:32
-
-
Save VOIDMalkuth/c03022f63354f8195573a1f280a5b129 to your computer and use it in GitHub Desktop.
Proposal Implementation for Paddle Slice Scatter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from paddle.base import core | |
from paddle.base.framework import Variable, default_main_program | |
from paddle.base.data_feeder import ( | |
check_dtype, | |
check_type, | |
check_variable_and_dtype, | |
convert_dtype, | |
) | |
from paddle import _C_ops | |
from paddle.framework import ( | |
LayerHelper, | |
convert_np_dtype_to_dtype_, | |
core, | |
dygraph_only, | |
in_dynamic_mode, | |
in_dynamic_or_pir_mode, | |
in_pir_mode, | |
) | |
def _is_list_or_tuple_of(index, typenames, name, op_name):
    """Check that ``index`` is a list/tuple whose items all match ``typenames``.

    Args:
        index: Object to validate.
        typenames: A single type, or an iterable of types, accepted for the
            items of ``index``.
        name: Argument name, used only in the error message.
        op_name: Operator name, used only in the error message.

    Returns:
        bool: ``True`` when every item of ``index`` is an instance of one of
        ``typenames`` (vacuously ``True`` for an empty sequence), otherwise
        ``False``.

    Raises:
        TypeError: If ``index`` itself is neither a list nor a tuple.
    """
    if not isinstance(index, (tuple, list)):
        raise TypeError(
            "The type of '{}' in {} must be tuple or list of {}, but received {}.".format(
                name,
                op_name,
                typenames,
                type(index),
            )
        )
    # Callers in this file pass a bare type such as ``int``, which is not
    # iterable, so normalise to a tuple that ``isinstance`` accepts directly.
    if isinstance(typenames, type):
        accepted = (typenames,)
    else:
        accepted = tuple(typenames)
    return all(isinstance(item, accepted) for item in index)
def _normalize_slice_arg(value, name):
    """Return ``value`` as a list of ints; a bare int becomes a one-item list.

    Raises:
        TypeError: If ``value`` is neither an int nor a list/tuple of ints.
    """
    if isinstance(value, int):
        return [value]
    if not isinstance(value, (list, tuple)) or not all(
        isinstance(item, int) for item in value
    ):
        raise TypeError(
            "The type of '{}' in {} must be tuple or list of {}, but received {}.".format(
                name,
                "slice_scatter",
                int,
                type(value),
            )
        )
    return list(value)


def slice_scatter(x, y, axes, starts, ends, steps):
    """Write ``y`` into the strided slice of ``x`` and return the result.

    The slice is described per axis by ``axes[i]``, ``starts[i]``, ``ends[i]``
    and ``steps[i]``. Each of the four arguments may be a list/tuple of ints
    or, as a convenience for the single-axis case, a bare int.

    Args:
        x: Tensor (dynamic mode) or Variable (static mode) to scatter into.
        y: Values to write; in static mode it must have the same dtype as ``x``.
        axes: Axis or axes being sliced.
        starts: Start index for each entry of ``axes``.
        ends: End index for each entry of ``axes``.
        steps: Non-zero stride for each entry of ``axes``.

    Returns:
        The output of the ``set_value`` operator.
        NOTE(review): the result is presumably a new tensor and ``x`` is left
        untouched -- confirm against the operator definition.

    Raises:
        TypeError: If a slice argument is not an int or list/tuple of ints, or
            (static mode) ``x``/``y`` fail the dtype checks.
        ValueError: If the slice arguments differ in length or a step is 0.
    """
    axes = _normalize_slice_arg(axes, "axes")
    starts = _normalize_slice_arg(starts, "starts")
    ends = _normalize_slice_arg(ends, "ends")
    steps = _normalize_slice_arg(steps, "steps")

    if not (len(axes) == len(starts) == len(ends) == len(steps)):
        raise ValueError(
            "axes, starts, ends and steps must have the same length, but "
            f"received lengths {len(axes)}, {len(starts)}, {len(ends)} "
            f"and {len(steps)}."
        )
    # Checked before the mode dispatch so dynamic and static mode fail alike.
    for step in steps:
        if step == 0:
            raise ValueError(
                f"Step should not be 0, but received step = {step}."
            )

    decrease_axes = []
    none_axes = []

    if in_dynamic_or_pir_mode():
        # NOTE(review): argument order follows the set_value_with_tensor op
        # definition (x, values, starts, ends, steps, axes, decrease_axes,
        # none_axes) -- confirm against the installed Paddle version.
        return _C_ops.set_value_with_tensor(
            x,
            y,
            starts,
            ends,
            steps,
            axes,
            decrease_axes,
            none_axes,
        )

    check_variable_and_dtype(
        x,
        'x',
        [
            'float16',
            'float32',
            'float64',
            'uint8',
            'int8',
            'int16',
            'int32',
            'int64',
            'complex64',
            'complex128',
            'bool',
        ],
        'slice_scatter',
    )
    # The value tensor must match the dtype of the tensor it is written into.
    check_variable_and_dtype(
        y,
        'y',
        [
            convert_dtype(x.dtype),
        ],
        'slice_scatter',
    )

    helper = LayerHelper("slice_scatter", **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    inputs = {'Input': x, 'ValueTensor': y}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
        'dtype': x.dtype,
    }
    helper.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': out},
        attrs=attrs,
    )
    return out
import paddle | |
if __name__ == "__main__":
    # Smoke test: build a static program that scatters ``y`` into every second
    # column (indices 1, 3, 5) along axis 1 of ``x`` and print the result.
    paddle.enable_static()
    program = paddle.static.Program()
    with paddle.static.program_guard(program):
        x = paddle.full([7, 8, 9], 9)
        y = paddle.full([3, 9], 3)
        # slice_scatter's parameters are per-axis sequences, so pass lists.
        z = slice_scatter(x, y, [1], [1], [7], [2])
    exe = paddle.static.Executor()
    out = exe.run(program, fetch_list=[z])
    paddle.disable_static()
    print(out)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# start out of place test | |
import sys | |
sys.path.append("/workspace/Paddle/test/legacy_test") | |
sys.path.append("/workspace/Paddle/test") | |
from slice_scatter import slice_scatter | |
# end out of place test | |
import unittest | |
from functools import reduce | |
import numpy as np | |
from op_test import OpTest, convert_float_to_uint16 | |
import paddle | |
from paddle.base import core | |
from paddle.base.layer_helper import LayerHelper | |
class TestSliceScatterBase(unittest.TestCase):
    """Shared fixture for slice_scatter tests.

    Subclasses override ``set_dtype``, ``set_shape``, ``set_param`` and
    ``_get_answer`` to describe one slicing scenario; ``setUp`` then builds
    the numpy reference inputs and a fresh static program.
    """

    def setUp(self):
        paddle.enable_static()
        self.set_dtype()
        self.set_param()
        self.set_shape()
        # Distinct fill values make the scattered region easy to tell apart.
        self.x_fill = 42
        self.y_fill = -3
        self.x_data = np.full(self.x_shape, self.x_fill).astype(self.x_dtype)
        self.y_data = np.full(self.y_shape, self.y_fill).astype(self.y_dtype)
        self.program = paddle.static.Program()

    def set_dtype(self):
        """Set the dtypes of the destination (x) and value (y) tensors."""
        self.x_dtype = "float32"
        self.y_dtype = "float32"

    def set_shape(self):
        """Set the shapes of the destination (x) and value (y) tensors."""
        self.x_shape = [7, 8, 9]
        self.y_shape = [3, 9]

    def set_param(self):
        """Set the single-axis slice description used by the test."""
        self.axis = 1
        self.start = 1
        self.stop = 4
        self.step = 1

    def _call_slice_scatter(self, x, y):
        # slice_scatter takes per-axis sequences (axes/starts/ends/steps), so
        # wrap the single-axis parameters in one-element lists.
        return slice_scatter(
            x, y, [self.axis], [self.start], [self.stop], [self.step]
        )

    def _get_answer(self):
        """Apply the expected scatter to ``self.x_data`` in place with numpy."""
        self.x_data[:, 1:4, :] = self.y_data
class TestSliceScatterApi(TestSliceScatterBase):
    """Run slice_scatter in static and dynamic mode and compare with numpy."""

    def _run_static(self):
        """Build and execute the static program; return the fetched outputs."""
        paddle.enable_static()
        with paddle.static.program_guard(self.program):
            dest = paddle.full(self.x_shape, self.x_fill, dtype=self.x_dtype)
            value = paddle.full(self.y_shape, self.y_fill, dtype=self.y_dtype)
            result_var = self._call_slice_scatter(dest, value)
        executor = paddle.static.Executor(paddle.CPUPlace())
        fetched = executor.run(self.program, fetch_list=[result_var])
        paddle.disable_static()
        return fetched

    def _run_dynamic(self):
        """Run eagerly and return the result as a numpy array."""
        paddle.disable_static()
        dest = paddle.full(self.x_shape, self.x_fill, dtype=self.x_dtype)
        value = paddle.full(self.y_shape, self.y_fill, dtype=self.y_dtype)
        result = self._call_slice_scatter(dest, value).numpy()
        # Restore static mode, which setUp switched on for the test.
        paddle.enable_static()
        return result

    def test_api(self):
        # Both modes run before the numpy reference is mutated in place.
        results = (
            ("static", self._run_static()),
            ("dynamic", self._run_dynamic()),
        )
        self._get_answer()
        template = (
            "\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}"
        )
        for mode, actual in results:
            self.assertTrue(
                (self.x_data == actual).all(),
                msg=template.format(mode, self.x_data, actual),
            )
# test step > 1 | |
class TestSliceScatterStep1(TestSliceScatterApi):
    """Positive stride greater than one: targets indices 1, 3, 5 on axis 1."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 1, 7, 2

    def _get_answer(self):
        # Same region as the 1:7:2 slice along axis 1.
        self.x_data[:, slice(1, 7, 2), :] = self.y_data
# test step = -1 | |
class TestSliceScatterStep2(TestSliceScatterApi):
    """Stride of -1: targets indices 4, 3, 2 on axis 1."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 4, 1, -1

    def _get_answer(self):
        # Same region as the 4:1:-1 slice along axis 1.
        self.x_data[:, slice(4, 1, -1), :] = self.y_data
# test step = -2 | |
class TestSliceScatterStep3(TestSliceScatterApi):
    """Stride of -2: targets indices 7, 5, 3 on axis 1."""

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 7, 1, -2

    def _get_answer(self):
        # Same region as the 7:1:-2 slice along axis 1.
        self.x_data[:, slice(7, 1, -2), :] = self.y_data
# test start is none | |
class TestSetValueStartIsNone(TestSliceScatterApi):
    """Start left as ``None``; the expected region is indices 0..2 on axis 1.

    NOTE(review): this relies on slice_scatter accepting ``None`` as a start
    index -- confirm the implementation supports it.
    """

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, None, 3, 1

    def _get_answer(self):
        # Same region as the 0:3 slice along axis 1.
        self.x_data[:, slice(0, 3), :] = self.y_data
# test stop is none | |
class TestSetValueStopIsNone(TestSliceScatterApi):
    """Stop left as ``None``; the expected region is indices 5..7 on axis 1.

    NOTE(review): this relies on slice_scatter accepting ``None`` as a stop
    index -- confirm the implementation supports it.
    """

    def set_shape(self):
        self.x_shape, self.y_shape = [7, 8, 9], [3, 9]

    def set_param(self):
        self.axis, self.start, self.stop, self.step = 1, 5, None, 1

    def _get_answer(self):
        # Same region as the 5:8 slice along axis 1.
        self.x_data[:, slice(5, 8), :] = self.y_data
# test different dtypes | |
def create_test_value_dtype(parent, dtype):
    """Register a subclass of ``parent`` that runs with ``dtype`` for x and y.

    The generated class is named ``<Parent>_<Dtype>`` (dtype capitalised) and
    is stored in this module's globals under that name.
    """

    class TestValueDtype(parent):
        def set_dtype(self):
            self.x_dtype = self.y_dtype = dtype

    generated_name = f"{parent.__name__}_{dtype.capitalize()}"
    TestValueDtype.__name__ = generated_name
    globals()[generated_name] = TestValueDtype
def create_slice_scatter_test_value_dtype(dtype):
    """Register a ``dtype`` variant of each basic slice_scatter test case."""
    parents = (
        TestSliceScatterApi,
        TestSliceScatterStep1,
        TestSliceScatterStep2,
        TestSliceScatterStep3,
    )
    for parent in parents:
        create_test_value_dtype(parent, dtype)
# Generate the dtype-specific variants of the basic test cases.
for _value_dtype in (
    "bool",
    "int32",
    "int64",
    "float32",
    "float64",
    "float16",
    "complex64",
    "complex128",
):
    create_slice_scatter_test_value_dtype(_value_dtype)
# Keep the module namespace free of the loop variable.
del _value_dtype
# Test error | |
class TestError(TestSliceScatterApi):
    """Check that invalid inputs to slice_scatter raise the expected errors."""

    def _value_type_error(self):
        # ``y`` is a plain Python list instead of a tensor.
        error_type = ValueError if paddle.in_dynamic_mode() else TypeError
        with self.assertRaises(error_type):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = [1]
            self._call_slice_scatter(x, y)

    def _dtype_error(self):
        # ``y`` has a different dtype from ``x``.
        error_type = ValueError if paddle.in_dynamic_mode() else TypeError
        with self.assertRaises(error_type):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="int64")
            self._call_slice_scatter(x, y)

    def _step_error(self):
        # A step of zero is rejected.
        with self.assertRaises(ValueError):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, [1], [1], [4], [0])

    def _shape_mismatch_static(self):
        # The slice 1:5 selects 4 entries on axis 1 but ``y`` provides 3; in
        # static mode the error is expected when the program is executed.
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, [1], [1], [5], [1])
        exe = paddle.static.Executor(paddle.CPUPlace())
        with self.assertRaises(ValueError):
            exe.run(program)

    def _shape_mismatch(self):
        # Same mismatch as above, expected to fail at call time in dynamic mode.
        with self.assertRaises(ValueError):
            x = paddle.full([7, 8, 9], 42, dtype="float32")
            y = paddle.full([3, 9], -3, dtype="float32")
            slice_scatter(x, y, [1], [1], [5], [1])

    def test_error(self):
        # Static-graph checks first, then the same checks in dynamic mode.
        paddle.enable_static()
        with paddle.static.program_guard(self.program):
            self._value_type_error()
            self._dtype_error()
            self._step_error()
            self._shape_mismatch_static()
        paddle.disable_static()
        self._value_type_error()
        self._dtype_error()
        self._step_error()
        self._shape_mismatch()
if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment