Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created September 17, 2018 19:33
Show Gist options
  • Save Roger-luo/5e15a26a620c304bc57162e3f0d1e968 to your computer and use it in GitHub Desktop.
Save Roger-luo/5e15a26a620c304bc57162e3f0d1e968 to your computer and use it in GitHub Desktop.
ATen/cpu/vec256/vec256_complex.h
#pragma once
#include "intrinsics.h"
#include "vec256_base.h"
#if defined(__AVX__) && !defined(_MSC_VER)
#include <sleef.h>
#endif
namespace at {
namespace vec256 {
namespace {

#if defined(__AVX__) && !defined(_MSC_VER)

using complex128 = std::complex<double>;

// Vec256 specialization for std::complex<double>.
//
// One __m256d register holds four doubles, i.e. exactly two complex128
// values. std::complex<double> is layout-compatible with double[2]
// ({real, imag}), so the lanes, in memory order, are
//   [re0, im0, re1, im1]
// and the register can be loaded from / stored to a complex128 array
// directly.
template <> class Vec256<complex128> {
private:
  __m256d values;

  static_assert(sizeof(complex128) == 2 * sizeof(double),
                "complex128 must be laid out as two contiguous doubles");

public:
  // Number of complex128 elements per vector (NOT the number of doubles).
  static constexpr int size = 2;

  Vec256() {}
  Vec256(__m256d v) : values(v) {}

  // Broadcast `val` into both complex lanes.
  Vec256(complex128 val) {
    // _mm256_setr_pd takes its arguments in memory order (element 0 first),
    // which matches the {real, imag} layout of std::complex<double>.
    // (_mm256_set_pd takes them highest-lane-first and would swap re/im.)
    values = _mm256_setr_pd(val.real(), val.imag(), val.real(), val.imag());
  }

  operator __m256d() const {
    return values;
  }

  // Take complex element i from `b` when bit i of `mask` is set, otherwise
  // from `a`. Only bits 0 and 1 are meaningful (size == 2).
  template <int64_t mask>
  static Vec256<complex128> blend(Vec256<complex128> a, Vec256<complex128> b) {
    // _mm256_blend_pd works per double: complex element i occupies double
    // lanes 2i and 2i+1, so expand each mask bit into two bits (xy -> xxyy).
    constexpr int double_mask =
        ((mask & 0x1) ? 0x3 : 0x0) | ((mask & 0x2) ? 0xC : 0x0);
    return _mm256_blend_pd(a.values, b.values, double_mask);
  }

  // Return a vector whose first `count` complex elements come from `b` and
  // whose remaining elements come from `a`. count == 0 yields `a`;
  // count >= size yields `b`.
  static Vec256<complex128> set(Vec256<complex128> a, Vec256<complex128> b, int64_t count = size) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
    }
    return b;
  }

  // Load `count` complex128 values from `ptr` (no alignment requirement).
  // Lanes beyond `count` are zero-filled.
  // Precondition: 0 <= count <= size.
  static Vec256<complex128> loadu(const void* ptr, int64_t count = size) {
    if (count == size)
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
    // Two doubles per complex element; zero-initialized so unused lanes hold
    // a defined value instead of stack garbage.
    __at_align32__ double tmp_values[2 * size] = {};
    if (count > 0) {
      std::memcpy(tmp_values, ptr, count * sizeof(complex128));
    }
    return _mm256_load_pd(tmp_values);
  }

  // Store the first `count` complex128 values to `ptr` (no alignment
  // requirement). Nothing is written when count <= 0.
  // Precondition: count <= size.
  void store(void* ptr, int count = size) const {
    if (count == size) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      // Two doubles per complex element.
      double tmp_values[2 * size];
      _mm256_storeu_pd(tmp_values, values);
      std::memcpy(ptr, tmp_values, count * sizeof(complex128));
    }
  }
};

#endif

}}}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment