Skip to content

Instantly share code, notes, and snippets.

@fjarri
Created January 9, 2015 00:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fjarri/584bd83817a20c81ee67 to your computer and use it in GitHub Desktop.
Save fjarri/584bd83817a20c81ee67 to your computer and use it in GitHub Desktop.
import numpy
import reikna.cluda.dtypes as dtypes
import reikna.cluda.functions as functions
from reikna.algorithms import PureParallel
from reikna.cluda import ocl_api
from reikna.core import Annotation, Type, Parameter
# Build a per-element OpenCL filter with reikna, run it on random data and
# compare the result against the same formula evaluated with NumPy:
#     output = conj(otf) / (ls_alpha * (|fft_lpl|**2 + |otf|**2))
api = ocl_api()
thr = api.Thread.create()

size = (100, 100)
dtype = numpy.complex64
# Real scalar type matching ``dtype`` (float32 for complex64). Using float64
# here would make the generated ``mul``/``div`` helpers operate on ``double``
# (requiring fp64 support on the device) even though every input and the
# output are single precision.
ftype = dtypes.real_for(dtype)
ls_alpha = 0.5

lplfilter = PureParallel(
    [
        Parameter('output', Annotation(Type(dtype, shape=size), 'o')),
        Parameter('otfi', Annotation(Type(dtype, shape=size), 'i')),
        Parameter('fft_lpli', Annotation(Type(dtype, shape=size), 'i')),
    ],
    """
    ${otfi.ctype} otf = ${otfi.load_same};
    ${fft_lpli.ctype} fft_lpl = ${fft_lpli.load_same};
    ${real_ctype} divf = ${mul}(${norm}(fft_lpl) + ${norm}(otf), ${lsa});
    ${output.store_same}(${div}(${conj}(otf), divf));
    """,
    render_kwds=dict(
        mul=functions.mul(ftype, ftype, out_dtype=ftype),
        div=functions.div(dtype, ftype, out_dtype=dtype),
        norm=functions.norm(dtype),
        conj=functions.conj(dtype),
        # C type name and a typed literal for ``ftype``, so the kernel does not
        # hard-code ``float`` or emit an untyped (double) ``0.5`` literal.
        real_ctype=dtypes.ctype(ftype),
        lsa=dtypes.c_constant(ls_alpha, ftype)))
lplfilter_c = lplfilter.compile(thr)

# Random complex test inputs, cast to the kernel's dtype.
otfi = (numpy.random.normal(size=size)
        + 1j * numpy.random.normal(size=size)).astype(dtype)
fft_lpli = (numpy.random.normal(size=size)
            + 1j * numpy.random.normal(size=size)).astype(dtype)

otfi_dev = thr.to_device(otfi)
fft_lpli_dev = thr.to_device(fft_lpli)
output_dev = thr.array(size, dtype)

lplfilter_c(output_dev, otfi_dev, fft_lpli_dev)
output = output_dev.get()

# NumPy reference for the same formula.
output_ref = otfi.conj() / (
    ls_alpha * (numpy.abs(fft_lpli) ** 2 + numpy.abs(otfi) ** 2))
print(numpy.allclose(output, output_ref))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment