import pfft
import numpy
from numpy.testing import assert_array_equal, assert_almost_equal
from runtests.mpi import MPITest
from mpi4py import MPI
def test_world():
    """ProcMesh and ProcMesh.split treat ``comm=None`` as MPI.COMM_WORLD.

    Runs directly on the world communicator (no MPITest decorator), so it
    is exercised with whatever number of ranks the suite was launched with.
    """
    world = MPI.COMM_WORLD

    # An explicit world communicator is stored on the mesh as-is.
    procmesh = pfft.ProcMesh(np=[world.size,], comm=world)
    assert procmesh.comm == world

    # comm=None must produce a mesh bound to the world communicator.
    procmesh = pfft.ProcMesh(np=[world.size,], comm=None)
    assert procmesh.comm == world

    # split() with comm=None must agree with an explicit split of COMM_WORLD.
    assert_array_equal(pfft.ProcMesh.split(2, None), pfft.ProcMesh.split(2, world))
    assert_array_equal(pfft.ProcMesh.split(1, None), pfft.ProcMesh.split(1, world))
@MPITest(3)
def test_edges(comm):
    """Partition edges of a 4x4 C2C transform split over 3 ranks.

    With PFFT_TRANSPOSED_OUT the split axis is axis 0 on input and axis 1
    on output; the remaining axis stays whole on every rank.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    part = pfft.Partition(pfft.Type.PFFT_C2C, [4, 4], mesh,
                          pfft.Flags.PFFT_TRANSPOSED_OUT)

    split = [0, 2, 4, 4]  # 4 points over 3 ranks -> local extents 2, 2, 0
    whole = [0, 4]

    assert_array_equal(part.i_edges[0], split)
    assert_array_equal(part.i_edges[1], whole)
    assert_array_equal(part.o_edges[1], split)
    assert_array_equal(part.o_edges[0], whole)
@MPITest(1)
def test_edges_padded(comm):
    """Partition edges of a padded, transposed 16x8 R2C transform on one rank.

    The padding is not visible in the edges: the input keeps its 8 logical
    real columns, and output axis 1 holds 8 // 2 + 1 = 5 complex values.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    flags = pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_PADDED_R2C
    part = pfft.Partition(pfft.Type.PFFT_R2C, [16, 8], mesh, flags)

    checks = (
        (part.i_edges[0], [0, 16]),
        (part.i_edges[1], [0, 8]),
        (part.o_edges[0], [0, 16]),
        (part.o_edges[1], [0, 5]),
    )
    for actual, expected in checks:
        assert_array_equal(actual, expected)
@MPITest(3)
def test_nino(comm):
    """Global input (ni) and output (no) sizes both equal the full 4x8 mesh.

    Neither the 3-rank decomposition nor PFFT_TRANSPOSED_OUT changes them.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    part = pfft.Partition(pfft.Type.PFFT_C2C, [4, 8], mesh,
                          pfft.Flags.PFFT_TRANSPOSED_OUT)
    full_shape = [4, 8]
    assert_array_equal(part.ni, full_shape)
    assert_array_equal(part.no, full_shape)
@MPITest(1)
def test_transposed(comm):
    """Views of a 4x8 C2C buffer with PFFT_TRANSPOSED_OUT on one rank.

    The input view is C-contiguous complex128 (row stride 8 * 16 bytes).
    The output view keeps the logical 4x8 shape, but its strides
    (16, 64) show it is stored transposed in memory.
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    part = pfft.Partition(pfft.Type.PFFT_C2C, [4, 8], mesh,
                          pfft.Flags.PFFT_TRANSPOSED_OUT)
    buf = pfft.LocalBuffer(part)
    out_view = buf.view_output()
    in_view = buf.view_input()

    complex128 = numpy.dtype('complex128')
    assert_array_equal(in_view.shape, (4, 8))
    assert_array_equal(in_view.strides, (128, 16))
    assert_array_equal(out_view.shape, (4, 8))
    assert_array_equal(out_view.strides, (16, 64))
    assert out_view.dtype == complex128
    assert in_view.dtype == complex128
@MPITest(1)
def test_padded(comm):
    """Views of a padded, transposed 4x8 R2C buffer on one rank.

    The float64 input rows are padded to 10 values (80-byte row stride).
    The complex128 output holds 8 // 2 + 1 = 5 columns, and its strides
    (16, 64) show it is stored transposed.
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    flags = pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_PADDED_R2C
    part = pfft.Partition(pfft.Type.PFFT_R2C, [4, 8], mesh, flags)
    buf = pfft.LocalBuffer(part)
    real_view = buf.view_input()
    cplx_view = buf.view_output()

    assert_array_equal(real_view.shape, (4, 8))
    assert_array_equal(real_view.strides, (80, 8))
    assert_array_equal(cplx_view.shape, (4, 5))
    assert_array_equal(cplx_view.strides, (16, 64))
    assert real_view.dtype == numpy.dtype('float64')
    assert cplx_view.dtype == numpy.dtype('complex128')
@MPITest(1)
def test_correct_single(comm):
    """Check that a forward 2x2 C2C transform on one rank matches numpy.fft.fftn.

    The result is compared with assert_almost_equal, as test_correct_multi
    does. PFFT's plan and numpy.fft are separate FFT code paths whose
    floating-point rounding is not guaranteed to be bit-identical, so exact
    equality would make the test fragile.
    """
    procmesh = pfft.ProcMesh(np=[1], comm=comm)
    partition = pfft.Partition(pfft.Type.PFFT_C2C, [2, 2],
                               procmesh, flags=pfft.Flags.PFFT_ESTIMATE)
    buffer1 = pfft.LocalBuffer(partition)
    buffer2 = pfft.LocalBuffer(partition)
    plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, buffer1, buffer2)

    buffer1.view_input()[:] = numpy.arange(4).reshape(2, 2)
    correct = numpy.fft.fftn(buffer1.view_input())
    plan.execute(buffer1, buffer2)
    assert_almost_equal(correct, buffer2.view_output())
@MPITest(1)
def test_raw(comm):
    """Check that the raw view of a transposed 8x8 R2C buffer has 2 * alloc_local items."""
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    part = pfft.Partition(pfft.Type.PFFT_R2C, [8, 8], mesh, flags=flags)
    buf = pfft.LocalBuffer(part)
    # NOTE(review): presumably alloc_local counts complex elements while the
    # raw view is real-valued, hence the factor of 2 -- confirm.
    expected_size = 2 * part.alloc_local
    assert buf.view_raw().size == expected_size
@MPITest(1)
def test_reuse_local_buffer(comm):
    """Check that a LocalBuffer built with ``base=`` reuses its base buffer's memory.

    Two partitions of the same 8x8 R2C problem are built, one with and one
    without PFFT_TRANSPOSED_OUT. buffer2 is laid over buffer1's memory
    (same address); buffer3 is a separate allocation.
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    estimate = pfft.Flags.PFFT_ESTIMATE
    part_t = pfft.Partition(pfft.Type.PFFT_R2C, [8, 8], mesh,
                            flags=estimate | pfft.Flags.PFFT_TRANSPOSED_OUT)
    part_n = pfft.Partition(pfft.Type.PFFT_R2C, [8, 8], mesh,
                            flags=estimate)

    buffer1 = pfft.LocalBuffer(part_t)
    buffer2 = pfft.LocalBuffer(part_n, base=buffer1)
    buffer3 = pfft.LocalBuffer(part_t)

    assert buffer1 is not buffer2
    assert buffer1.address == buffer2.address

    # `in` is expected to be symmetric: buffers sharing memory contain each
    # other, while the independent allocation is contained by neither.
    for inner, outer in ((buffer1, buffer2), (buffer2, buffer1)):
        assert inner in outer
    for inner, outer in ((buffer1, buffer3), (buffer3, buffer1),
                         (buffer2, buffer3), (buffer3, buffer2)):
        assert inner not in outer
@MPITest(1)
def test_transpose_1d_decom(comm):
    """Byte strides of a 4-D C2C buffer on a 1-D process mesh with TRANSPOSED_OUT."""
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    shape = (1, 2, 3, 4)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    part = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)
    buf = pfft.LocalBuffer(part)

    # Input: C-contiguous complex128 (16 bytes per element).
    in_view = buf.view_input()
    assert_array_equal(in_view.strides, [384, 192, 64, 16])
    # Output: only the leading strides differ from the input layout.
    out_view = buf.view_output()
    assert_array_equal(out_view.strides, [192, 192, 64, 16])
@MPITest(1)
def test_transpose_2d_decom(comm):
    """Byte strides of a 4-D C2C buffer on a 2-D process mesh with TRANSPOSED_OUT."""
    mesh = pfft.ProcMesh(np=[1, 1], comm=comm)
    shape = (1, 2, 3, 4)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    part = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)
    buf = pfft.LocalBuffer(part)

    # Input: C-contiguous complex128 (16 bytes per element).
    in_view = buf.view_input()
    assert_array_equal(in_view.strides, [384, 192, 64, 16])
    # Output: only the stride of axis 0 differs from the input layout.
    out_view = buf.view_output()
    assert_array_equal(out_view.strides, [64, 192, 64, 16])
@MPITest(1)
def test_transpose_3d_decom(comm):
    """Byte strides of a 5-D C2C buffer on a 3-D process mesh with TRANSPOSED_OUT."""
    mesh = pfft.ProcMesh(np=[1, 1, 1], comm=comm)
    shape = (1, 2, 3, 4, 5)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    part = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)
    buf = pfft.LocalBuffer(part)

    # FIXME: check with @mpip whether these expected strides are correct.
    in_view = buf.view_input()
    assert_array_equal(in_view.strides, [1920, 960, 320, 80, 16])
    out_view = buf.view_output()
    assert_array_equal(out_view.strides, [80, 960, 320, 80, 16])
@MPITest((1, 4))
def test_correct_multi(comm):
    """Check that a distributed forward C2C FFT of a 2x3 array matches numpy.fft.fftn.

    Each rank transforms its local input slab and writes its output slab
    into an otherwise-zero global array. The slabs are then summed across
    ranks with allreduce before comparison with the serial reference.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    shape = (2, 3)
    data = numpy.arange(numpy.prod(shape), dtype='complex128').reshape(shape)
    expected = numpy.fft.fftn(data)
    gathered = numpy.zeros_like(data)

    part = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh,
                          flags=pfft.Flags.PFFT_ESTIMATE)
    src = pfft.LocalBuffer(part)
    dst = pfft.LocalBuffer(part)
    plan = pfft.Plan(part, pfft.Direction.PFFT_FORWARD, src, dst)

    src.view_input()[:] = data[part.local_i_slice]
    plan.execute(src, dst)
    gathered[part.local_o_slice] = dst.view_output()

    assert_almost_equal(expected, comm.allreduce(gathered))
@MPITest(commsize=1)
def test_leak(comm):
    """Repeatedly build ProcMesh / Partition / LocalBuffer objects.

    This is a smoke test with no assertions. Each of the 1024 iterations
    allocates a 128**3 C2C buffer and takes an input view of it.
    NOTE(review): presumably it relies on the process not running out of
    memory to reveal leaks -- confirm with the author.
    """
    # The loop index is unused; `_` avoids the original's rebinding of the
    # loop variable `i` to the view inside the body.
    for _ in range(1024):
        procmesh = pfft.ProcMesh(np=[1, 1], comm=comm)
        partition = pfft.Partition(pfft.Type.PFFT_C2C,
                                   [128, 128, 128], procmesh,
                                   pfft.Flags.PFFT_TRANSPOSED_OUT)
        buffer = pfft.LocalBuffer(partition)
        # Exercise view creation as well as allocation.
        buffer.view_input()
@MPITest([4])
def test_2d_on_2d_c2c(comm):
    """Check a forward C2C FFT of an 8x8 array on a 2x2 process mesh against numpy.

    Output slabs from all ranks are summed with allreduce into one global
    array and compared with numpy.fft.fftn.
    """
    mesh = pfft.ProcMesh(np=[2, 2], comm=comm)
    shape = (8, 8)
    data = numpy.arange(numpy.prod(shape), dtype='complex128').reshape(shape)
    expected = numpy.fft.fftn(data.copy())
    gathered = numpy.zeros_like(expected)

    # PFFT_DESTROY_INPUT was left commented out in favour of PRESERVE_INPUT.
    flags = (pfft.Flags.PFFT_ESTIMATE
             | pfft.Flags.PFFT_TRANSPOSED_OUT
             | pfft.Flags.PFFT_PRESERVE_INPUT)
    part = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)
    src = pfft.LocalBuffer(part)
    dst = pfft.LocalBuffer(part)
    plan = pfft.Plan(part, pfft.Direction.PFFT_FORWARD, src, dst)

    src.view_input()[:] = data[part.local_i_slice]
    plan.execute(src, dst)
    gathered[part.local_o_slice] = dst.view_output()

    assert_almost_equal(expected, comm.allreduce(gathered))
@MPITest([1, 4])
def test_2d_on_2d_r2c(comm):
    """Check a forward R2C FFT of an 8x8 real array against numpy.fft.rfftn.

    Runs on a 1x1 mesh for a single rank and on a 2x2 mesh for four ranks.
    Output slabs are summed across ranks with allreduce before comparison.
    """
    grid = [1, 1] if comm.size == 1 else [2, 2]
    mesh = pfft.ProcMesh(np=grid, comm=comm)
    shape = (8, 8)
    data = numpy.arange(numpy.prod(shape), dtype='f8').reshape(shape)
    expected = numpy.fft.rfftn(data.copy())
    gathered = numpy.zeros_like(expected)

    # TODO: adding pfft.Flags.PFFT_PADDED_R2C here does not work yet.
    flags = (pfft.Flags.PFFT_ESTIMATE
             | pfft.Flags.PFFT_TRANSPOSED_OUT
             | pfft.Flags.PFFT_DESTROY_INPUT)
    part = pfft.Partition(pfft.Type.PFFT_R2C, shape, mesh, flags=flags)
    src = pfft.LocalBuffer(part)
    dst = pfft.LocalBuffer(part)
    plan = pfft.Plan(part, pfft.Direction.PFFT_FORWARD, src, dst)

    src.view_input()[:] = data[part.local_i_slice]
    plan.execute(src, dst)
    gathered[part.local_o_slice] = dst.view_output()

    assert_almost_equal(expected, comm.allreduce(gathered))