Created
September 17, 2018 19:33
-
-
Save Roger-luo/5e15a26a620c304bc57162e3f0d1e968 to your computer and use it in GitHub Desktop.
ATen/cpu/vec256/vec256_complex.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include "intrinsics.h" | |
#include "vec256_base.h" | |
#if defined(__AVX__) && !defined(_MSC_VER) | |
#include <sleef.h> | |
#endif | |
namespace at { | |
namespace vec256 { | |
namespace { | |
#if defined(__AVX__) && !defined(_MSC_VER) | |
using complex128 = std::complex<double>; | |
template <> class Vec256<complex128> { | |
private: | |
__m256d values; | |
public: | |
static constexpr int size = 4; | |
Vec256() {}; | |
Vec256(__m256d v) : values(v) {}; | |
Vec256(complex128 val) { | |
values = _mm256_set_pd(val.real(), val.imag(), val.real(), val.imag()); | |
} | |
operator __m256d() const { | |
return values; | |
} | |
template <int64_t mask> | |
static Vec256<complex128> blend(Vec256<complex128> a, Vec256<complex128> b) { | |
return _mm256_blend_pd(a.values, b.values, mask); | |
} | |
static Vec256<complex128> set(Vec256<complex128> a, Vec256<complex128> b, int64_t count = size) { | |
switch (count) { | |
case 0: | |
return a; | |
case 1: | |
return blend<1>(a, b); | |
case 3: | |
return blend<3>(a, b); | |
case 7: | |
return blend<7>(a, b); | |
} | |
return b; | |
} | |
static Vec256<complex128> loadu(const void* ptr, int64_t count = size) { | |
if (count == size) | |
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr)); | |
__at_align32__ double tmp_values[size]; | |
std::memcpy( | |
tmp_values, | |
reinterpret_cast<const complex128*>(ptr), | |
count * sizeof(complex128); | |
) | |
return _mm256_load_pd(tmp_values); | |
} | |
void store(void *ptr, int count = size) const { | |
if (count == size) { | |
_mm256_storeu_pd(reinterpret_cast<double *>(ptr), values); | |
} else { | |
double tmp_values[size]; | |
_mm256_storeu_pd(reinterpret_cast<double *>(tmp_values), values); | |
std::memcpy(ptr, tmp_values, count * sizeof(double)); | |
} | |
} | |
}; | |
#endif | |
}}} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment