Instantly share code, notes, and snippets.
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save xuhdev/08151fc261e140f9441a278db903d8ad to your computer and use it in GitHub Desktop.
TAGS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
TensorIterator.h,6009 | |
namespace at {at58,1889 | |
struct DimCounter {DimCounter60,1905 | |
IntArrayRef shape;shape67,2097 | |
Range range;range68,2118 | |
DimVector values;values69,2133 | |
int64_t offset;offset70,2153 | |
struct CAFFE2_API OperandInfo {OperandInfo73,2175 | |
using StrideVector = SmallVector<int64_t, 6>;StrideVector74,2207 | |
OperandInfo() {}OperandInfo75,2255 | |
explicit OperandInfo(const Tensor& t) : tensor(t) {OperandInfo76,2274 | |
OperandInfo(const Tensor& t, Device device, ScalarType dtype)OperandInfo83,2435 | |
StrideVector stride_bytes;stride_bytes89,2649 | |
Tensor tensor;tensor94,2830 | |
Tensor original_tensor;original_tensor98,2954 | |
Device device = kCPU;device104,3307 | |
ScalarType dtype = ScalarType::Undefined;dtype105,3331 | |
bool is_type_defined() const { return dtype != ScalarType::Undefined; }is_type_defined107,3376 | |
TensorOptions options() const {options108,3450 | |
void* data = nullptr;data114,3638 | |
bool is_output = false;is_output116,3663 | |
bool is_read_write = false;is_read_write118,3690 | |
void validate() {validate120,3721 | |
enum class CommonDTypeStrategy : uint8_t {CommonDTypeStrategy129,3906 | |
NONE, // Do not compute a common dtypeNONE130,3949 | |
CHECK, // Compute and validate a common dtype but don't promote.CHECK131,3990 | |
PROMOTE_INPUTS, // Promote common dtype but only validate inputs (comparison ops have boolean PROMOTE_INPUTS132,4057 | |
PROMOTE // Promote to common dtype.PROMOTE133,4161 | |
struct CAFFE2_API TensorIterator {TensorIterator136,4203 | |
using DimMask = std::bitset<64>;DimMask137,4238 | |
using PtrVector = SmallVector<char*, 4>;PtrVector138,4273 | |
using StrideVector = SmallVector<int64_t, 6>;StrideVector139,4316 | |
TensorIterator() {}TensorIterator141,4365 | |
using loop_t = c10::function_ref<void(char** data, const int64_t* strides, int64_t size)>;loop_t153,4820 | |
using loop2d_t = c10::function_ref<void(char** data, const int64_t* strides, int64_t size0, inloop2d_t154,4913 | |
using loop_subiter_t = c10::function_ref<void(TensorIterator& subiter)>;loop_subiter_t156,5025 | |
int ndim() const { return shape_.size(); }ndim170,5709 | |
IntArrayRef shape() const { return shape_; }shape171,5754 | |
int ntensors() const { return operands_.size(); }ntensors173,5826 | |
int noutputs() const { return num_outputs_; }noutputs174,5878 | |
int ninputs() const { return ntensors() - noutputs(); }ninputs175,5926 | |
IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }strides191,6502 | |
ScalarType dtype(int arg=0) const { return operands_[arg].tensor.scalar_type(); }dtype193,6612 | |
ScalarType common_dtype() const { return common_dtype_; }common_dtype194,6696 | |
ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].dtype; }input_dtype195,6756 | |
Device device(int arg=0) const { return operands_[arg].device; }device196,6846 | |
DeviceType device_type(int arg=0) const { return device(arg).type(); }device_type197,6913 | |
int64_t element_size(int arg) const { return elementSize(dtype(arg)); }element_size198,6986 | |
const Tensor& tensor(int arg) const { return operands_[arg].tensor; }tensor202,7131 | |
Tensor& tensor(int arg) { return operands_[arg].tensor; }tensor203,7203 | |
Tensor output(int arg=0) const {output205,7264 | |
void cast_outputs() {cast_outputs210,7373 | |
Tensor input(int arg=0) const {input221,7775 | |
T scalar_value(int arg) {scalar_value245,8773 | |
StrideVector get_inner_strides() const { return get_dim_strides(0); }get_inner_strides270,9670 | |
bool should_accumulate() const { return accumulate_; }should_accumulate283,10244 | |
bool is_final_output() const { return final_output_; }is_final_output288,10463 | |
bool needs_dynamic_casting() const {needs_dynamic_casting290,10521 | |
void set_check_mem_overlap(bool check_mem_overlap) {set_check_mem_overlap294,10656 | |
void add_output(const Tensor& output) {add_output299,10779 | |
void add_output(const Tensor& input, Device device, ScalarType dtype) {add_output304,10882 | |
void add_input(const Tensor& input) {add_input309,11031 | |
void add_input(const Tensor& input, Device device, ScalarType dtype) {add_input313,11111 | |
void promote_common_dtype() {promote_common_dtype317,11239 | |
void dont_compute_common_dtype() {dont_compute_common_dtype321,11335 | |
void compute_common_dtype_only_for_inputs() {compute_common_dtype_only_for_inputs325,11433 | |
void dont_resize_outputs() {dont_resize_outputs329,11552 | |
DimVector shape_;shape_354,12121 | |
DimVector perm_;perm_355,12141 | |
NameVector names_;names_357,12185 | |
SmallVector<OperandInfo, 4> operands_;operands_359,12213 | |
int num_outputs_ = 0;num_outputs_360,12254 | |
CommonDTypeStrategy common_dtype_strategy_ = CommonDTypeStrategy::CHECK;common_dtype_strategy_361,12278 | |
ScalarType common_dtype_ = ScalarType::Undefined;common_dtype_362,12353 | |
bool has_coalesced_dimensions_ = false;has_coalesced_dimensions_363,12405 | |
bool accumulate_ = false;accumulate_364,12447 | |
bool resize_outputs_ = true;resize_outputs_365,12475 | |
bool is_reduction_ = false;is_reduction_366,12506 | |
bool allow_cpu_scalars_ = false;allow_cpu_scalars_367,12536 | |
bool promote_gpu_output_dtypes_ = false;promote_gpu_output_dtypes_368,12571 | |
bool final_output_ = true;final_output_369,12614 | |
bool check_mem_overlap_ = false;check_mem_overlap_370,12643 | |
bool have_differing_types_ = false;have_differing_types_371,12678 | |
bool all_ops_same_shape_ = false;all_ops_same_shape_372,12716 | |
struct CAFFE2_API SplitUntil32Bit {SplitUntil32Bit377,12937 | |
struct CAFFE2_API iterator {iterator378,12973 | |
iterator() {};iterator379,13004 | |
bool operator==(const iterator& other) const {operator ==385,13169 | |
bool operator!=(const iterator& other) const { return !(*this == other); }operator !=390,13421 | |
std::vector<std::unique_ptr<TensorIterator>> vec;vec393,13546 | |
SplitUntil32Bit(const TensorIterator& iter) : iter(iter) {}SplitUntil32Bit396,13606 | |
const TensorIterator& iter;iter402,13729 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment