from __future__ import absolute_import
from warnings import warn
warn("This module is deprecated and likely no longer maintained; a maintained version is moved to cosmo4d to minimize changes in pmesh.", DeprecationWarning)
import numpy
from abopt.vmad2 import ZERO, Engine, statement, programme, CodeSegment, Literal
from abopt.abopt2 import VectorSpace
from pmesh.pm import ParticleMesh, RealField, ComplexField
def nyquist_mask(factor, v):
    """Zero a complex transfer factor on the zero/Nyquist modes.

    The masked modes are those whose integer index is 0 or ``Nmesh // 2``
    along *every* axis. Where ``factor`` has a non-zero imaginary part on
    such a mode, it is replaced by 0. Everywhere else, and wherever
    ``factor`` is purely real, it is returned unchanged.

    Parameters
    ----------
    factor : array_like
        Transfer-function values, broadcastable against the index arrays
        in ``v.i``.
    v : slab
        Object exposing ``i`` (per-axis integer index arrays, broadcastable
        against each other) and ``Nmesh``. Presumably the pmesh slab that
        ``apply`` hands to its callback.

    Returns
    -------
    ``factor`` multiplied by the boolean keep-mask.
    """
    # Combine the per-axis tests one at a time. The index arrays have
    # different (broadcastable) shapes, so packing them into one array for
    # ``numpy.bitwise_and.reduce`` is a ragged-array error on recent NumPy.
    special = numpy.bool_(True)
    for ii, ni in zip(v.i, v.Nmesh):
        special = numpy.logical_and(special, (ii == 0) | (ii == ni // 2))
    mask = (numpy.imag(factor) == 0) | numpy.logical_not(special)
    return factor * mask
class ParticleMeshVectorSpace(VectorSpace):
    """Vector-space operations on the variables of a ParticleMeshEngine.

    Supported element types are RealField, ComplexField, scalars, and
    particle arrays whose leading dimension matches ``q``.

    Parameters
    ----------
    pm : ParticleMesh
        Provides the MPI communicator used for global reductions.
    q : numpy.ndarray
        Reference particle array; only its shape is kept.
    """

    def __init__(self, pm, q):
        self.qshape = q.shape
        self.pm = pm

    def addmul(self, a, b, c, p=1):
        """Return ``a + b * c ** p`` as a new object of the same kind as ``b``.

        Raises
        ------
        TypeError
            If ``b`` is not a RealField, ComplexField, scalar or ndarray.
        """
        if isinstance(b, RealField):
            r = b.copy()
            r[...] = a + b * c ** p
            return r
        if isinstance(b, ComplexField):
            r = b.copy()
            # Do the arithmetic on the plain complex arrays of the fields.
            if isinstance(c, ComplexField):
                c = c.plain
            if isinstance(a, ComplexField):
                a = a.plain
            r.plain[...] = a + b.plain * c ** p
            return r
        if numpy.isscalar(b):
            return a + b * c ** p
        if isinstance(b, numpy.ndarray):
            assert len(b) == self.qshape[0]
            return a + b * c ** p
        raise TypeError("type unknown")

    def dot(self, a, b):
        """Return the global (all-rank) inner product of ``a`` and ``b``.

        Raises
        ------
        TypeError
            If ``a`` and ``b`` differ in type, or the type is unsupported.
        """
        if type(a) is not type(b):
            raise TypeError("type mismatch")
        if isinstance(a, (RealField, ComplexField)):
            return a.cdot(b)
        if isinstance(a, numpy.ndarray):
            assert len(a) == len(b)
            assert len(a) == self.qshape[0]
            # ``ndarray.dot`` is a matrix product for 2-d particle arrays
            # (shape (N, ndim)), which raises unless N == ndim. Sum the
            # element-wise product instead, in f8 as ``to_scalar`` does.
            return self.pm.comm.allreduce((a * b).sum(dtype='f8'))
        raise TypeError("type unknown")
class ParticleMeshEngine(Engine):
    """abopt.vmad2 engine of differentiable ParticleMesh operators.

    Each ``@statement`` defines a forward operation. ``defvjp`` registers
    its backward rule (vector-Jacobian product, gradient variables are
    named ``_name``). ``defjvp`` registers its forward-tangent rule
    (Jacobian-vector product, tangent variables are named ``name_``).

    Parameters
    ----------
    pm : ParticleMesh
        Mesh on which all fields live.
    q : numpy.ndarray, optional
        Reference particle positions. Defaults to
        ``pm.generate_uniform_particle_grid(shift=0.0, dtype='f4')``.
    """

    def __init__(self, pm, q=None):
        self.pm = pm
        if q is None:
            q = pm.generate_uniform_particle_grid(shift=0.0, dtype='f4')
        self.q = q
        self.vs = ParticleMeshVectorSpace(self.pm, self.q)

    @programme(ain=['s'], aout=['x'])
    def get_x(engine, s, x):
        """Return a code segment computing ``x = s + q``."""
        code = CodeSegment(engine)
        code.add(x1='s', x2=Literal(engine.q), y='x')
        return code

    @statement(aout=['real'], ain=['complex'])
    def c2r(engine, real, complex):
        """Complex-to-real transform: ``real = complex.c2r()``."""
        real[...] = complex.c2r()

    @c2r.defvjp
    def _(engine, _real, _complex):
        _complex[...] = _real.c2r_vjp()

    @c2r.defjvp
    def _(engine, real_, complex_):
        # The transform is linear, so the tangent is transformed the same way.
        real_[...] = complex_.c2r()

    @statement(aout=['complex'], ain=['real'])
    def r2c(engine, complex, real):
        """Real-to-complex transform: ``complex = real.r2c()``."""
        complex[...] = real.r2c()

    @r2c.defvjp
    def _(engine, _complex, _real):
        _real[...] = _complex.r2c_vjp()

    @r2c.defjvp
    def _(engine, complex_, real_):
        complex_[...] = real_.r2c()

    @statement(aout=['complex'], ain=['complex'])
    def decompress(engine, complex):
        """No-op on the values; only the backward rule does work."""
        return

    @decompress.defvjp
    def _(engine, _complex):
        _complex.decompress_vjp(out=Ellipsis)

    @decompress.defjvp
    def _(engine, complex_):
        # NOTE(review): the forward pass is a no-op, so the tangent is left
        # unchanged; the original author was unsure this is correct.
        pass

    @staticmethod
    def _lowpass_filter(k, v, Neff):
        """Zero modes with ``|k_i| > (Neff // 2) * 2 pi / BoxSize_i`` on any axis."""
        k0s = 2 * numpy.pi / v.BoxSize
        # Combine the per-axis masks pairwise. They have different
        # broadcastable shapes, so ``numpy.bitwise_and.reduce`` over a list
        # of them is a ragged-array error on recent NumPy.
        mask = numpy.bool_(True)
        for ki, k0 in zip(k, k0s):
            mask = numpy.logical_and(mask, abs(ki) <= Neff // 2 * k0)
        return v * mask

    @statement(aout=['real'], ain=['real'])
    def lowpass(engine, real, Neff):
        """Low-pass filter ``real`` in place with ``_lowpass_filter``."""
        real.r2c(out=Ellipsis).apply(
            lambda k, v: engine._lowpass_filter(k, v, Neff),
            out=Ellipsis).c2r(out=Ellipsis)

    @lowpass.defvjp
    def _(engine, _real, Neff):
        # ``c2r_vjp()`` allocates a new field, so the filtered gradient must
        # be written back into ``_real``; otherwise it is silently dropped.
        _real[...] = _real.c2r_vjp().apply(
            lambda k, v: engine._lowpass_filter(k, v, Neff),
            out=Ellipsis).r2c_vjp()

    @lowpass.defjvp
    def _(engine, real_, Neff):
        # ``r2c()`` allocates a new field; write the filtered tangent back.
        real_[...] = real_.r2c().apply(
            lambda k, v: engine._lowpass_filter(k, v, Neff),
            out=Ellipsis).c2r()

    @statement(aout=['layout'], ain=['x'])
    def decompose(engine, layout, x):
        """Compute ``layout = pm.decompose(x)``."""
        pm = engine.pm
        layout[...] = pm.decompose(x)

    @decompose.defvjp
    def _(engine, _layout, _x):
        # The layout carries no gradient back to the positions.
        _x[...] = ZERO

    @decompose.defjvp
    def _(engine, layout_, x_):
        layout_[...] = ZERO

    @staticmethod
    def _paint_norm(pm, x):
        """Return ``Nmesh.prod() / N``, N being the particle count summed over ranks.

        This is a collective call: every rank must make it.
        """
        return 1.0 * pm.Nmesh.prod() / pm.comm.allreduce(len(x))

    @statement(aout=['mesh'], ain=['x', 'layout'])
    def paint(engine, x, mesh, layout):
        """Paint particles at ``x`` onto a new mesh, scaled by ``Nmesh.prod() / N``."""
        pm = engine.pm
        norm = engine._paint_norm(pm, x)
        mesh[...] = pm.paint(x, layout=layout, hold=False)
        # to have 1 + \delta on the mesh
        mesh[...][...] *= norm

    @paint.defvjp
    def _(engine, _x, _mesh, x, layout, _layout):
        pm = engine.pm
        _layout[...] = ZERO
        norm = engine._paint_norm(pm, x)
        _x[...], unused_mass = pm.paint_vjp(_mesh, x, layout=layout, out_mass=False)
        _x[...][...] *= norm

    @paint.defjvp
    def _(engine, x_, mesh_, x, layout, layout_):
        pm = engine.pm
        norm = engine._paint_norm(pm, x)
        if x_ is ZERO:
            x_ = None
        mesh_[...] = pm.paint_jvp(x, v_pos=x_, layout=layout)
        # The forward pass scales the painted mesh by ``norm``, so the
        # tangent must be scaled by the same constant (as the VJP is).
        mesh_[...][...] *= norm

    @statement(aout=['value'], ain=['x', 'mesh', 'layout'])
    def readout(engine, value, x, mesh, layout):
        """Read ``mesh`` out at positions ``x``."""
        value[...] = mesh.readout(x, layout=layout)

    @readout.defvjp
    def _(engine, _value, _x, _mesh, x, layout, mesh):
        _mesh[...], _x[...] = mesh.readout_vjp(x, _value, layout=layout)

    @readout.defjvp
    def _(engine, value_, x_, mesh_, x, layout, mesh, layout_):
        # pmesh expects None rather than ZERO for absent tangents.
        if mesh_ is ZERO:
            mesh_ = None
        if x_ is ZERO:
            x_ = None
        value_[...] = mesh.readout_jvp(x, v_self=mesh_, v_pos=x_, layout=layout)

    @statement(aout=['complex'], ain=['complex'])
    def transfer(engine, complex, tf):
        """Multiply ``complex`` in place by ``tf(k)``, masked by ``nyquist_mask``."""
        complex.apply(lambda k, v: nyquist_mask(tf(k), v) * v, out=Ellipsis)

    @transfer.defvjp
    def _(engine, tf, _complex):
        # The adjoint multiplies by the complex conjugate of the transfer.
        _complex.apply(lambda k, v: nyquist_mask(numpy.conj(tf(k)), v) * v, out=Ellipsis)

    @transfer.defjvp
    def _(engine, tf, complex_):
        complex_.apply(lambda k, v: nyquist_mask(tf(k), v) * v, out=Ellipsis)

    @statement(aout=['residual'], ain=['model'])
    def residual(engine, model, data, sigma, residual):
        """
        residual = (model - data) / sigma
        J = 1 / sigma
        """
        residual[...] = (model - data) / sigma

    @residual.defvjp
    def _(engine, _model, _residual, data, sigma):
        _model[...] = _residual / sigma

    @residual.defjvp
    def _(engine, model_, residual_, data, sigma):
        residual_[...] = model_ / sigma

    @statement(ain=['attribute', 'value'], aout=['attribute'])
    def assign_component(engine, attribute, value, dim):
        """Overwrite component ``dim`` of ``attribute`` (last axis) with ``value``."""
        attribute[..., dim] = value

    @assign_component.defvjp
    def _(engine, _attribute, _value, dim):
        # The forward pass overwrites component ``dim``, so the incoming
        # attribute's values there do not reach the output. Route that
        # slice to ``_value`` (copied, as it is about to be zeroed) and
        # zero it in the attribute's gradient.
        _value[...] = _attribute[..., dim].copy()
        _attribute[..., dim] = 0

    @assign_component.defjvp
    def _(engine, attribute_, value_, dim):
        attribute_[..., dim] = value_

    @statement(ain=['x'], aout=['y'])
    def assign(engine, x, y):
        """Set ``y`` to a copy of ``x``."""
        y[...] = x.copy()

    @assign.defvjp
    def _(engine, _y, _x):
        _x[...] = _y

    @assign.defjvp
    def _(engine, y_, x_, x):
        # Copy ``x`` first so the tangent keeps x's type, then fill in x_.
        y_[...] = x.copy()
        y_[...][...] = x_

    @statement(ain=['x1', 'x2'], aout=['y'])
    def add(engine, x1, x2, y):
        """Compute ``y = x1 + x2``."""
        y[...] = x1 + x2

    @add.defvjp
    def _(engine, _y, _x1, _x2):
        # NOTE(review): both gradients receive the same object; an in-place
        # backward rule applied to one of them would also affect the other.
        _x1[...] = _y
        _x2[...] = _y

    @add.defjvp
    def _(engine, y_, x1_, x2_):
        y_[...] = x1_ + x2_

    @statement(aout=['y'], ain=['x1', 'x2'])
    def multiply(engine, x1, x2, y):
        """Compute ``y = x1 * x2``."""
        y[...] = x1 * x2

    @multiply.defvjp
    def _(engine, _x1, _x2, _y, x1, x2):
        _x1[...] = _y * x2
        _x2[...] = _y * x1

    @multiply.defjvp
    def _(engine, x1_, x2_, y_, x1, x2):
        y_[...] = x1_ * x2 + x1 * x2_

    @statement(ain=['x'], aout=['y'])
    def to_scalar(engine, x, y):
        """Reduce ``x`` to the global sum of its squares.

        RealFields use ``cnorm``; arrays are summed in f8 across ranks.

        Raises
        ------
        TypeError
            For ComplexField input.
        """
        if isinstance(x, RealField):
            y[...] = x.cnorm()
        elif isinstance(x, ComplexField):
            raise TypeError("Computing the L-2 norm of complex is not a good idea, because the gradient propagation is ambiguous")
        else:
            y[...] = engine.pm.comm.allreduce((x[...] ** 2).sum(dtype='f8'))

    @to_scalar.defvjp
    def _(engine, _y, _x, x):
        # d(sum x^2)/dx = 2 x
        _x[...] = x * (2 * _y)

    @to_scalar.defjvp
    def _(engine, y_, x_, x):
        if isinstance(x, RealField):
            y_[...] = x.cdot(x_) * 2
        elif isinstance(x, ComplexField):
            raise TypeError("Computing the L-2 norm of complex is not a good idea, because the gradient propagation is ambiguous")
        else:
            y_[...] = engine.pm.comm.allreduce((x * x_).sum(dtype='f8')) * 2
def _particle_accessors(comm, q):
    """Return ``(cshape, cget, cperturb)`` for particle arrays split over ``comm``.

    Elements are addressed by a global ``(particle, axis)`` index. Each rank
    holds a contiguous range of particles, in rank order. ``cget`` and
    ``cperturb`` are collective and must be called on every rank.
    """
    cshape = comm.allreduce(q.shape[0]), q.shape[1]

    def _local_range(pos):
        # [start, end) of the global particle indices stored on this rank.
        counts = comm.allgather(pos.shape[0])
        start = sum(counts[:comm.rank])
        return start, start + counts[comm.rank]

    def cperturb(pos, ind, eps):
        pos = pos.copy()
        start, end = _local_range(pos)
        if start <= ind[0] < end:
            row = ind[0] - start
            pos[row, ind[1]] = pos[row, ind[1]] + eps
        return pos

    def cget(pos, ind):
        if pos is ZERO:
            return 0
        start, end = _local_range(pos)
        if start <= ind[0] < end:
            value = pos[ind[0] - start, ind[1]]
        else:
            value = 0
        return comm.allreduce(value)

    return cshape, cget, cperturb


def _field_accessors(field):
    """Return ``(cshape, cget, cperturb)`` for a RealField, by global index."""
    cshape = field.cshape

    def cget(real, index):
        if real is ZERO:
            return 0
        return real.cgetitem(index)

    def cperturb(real, index, eps):
        old = real.cgetitem(index)
        r1 = real.copy()
        r1.csetitem(index, old + eps)
        return r1

    return cshape, cget, cperturb


def _errorstat(stat, rtol, atol):
    """Bin the normalized discrepancy of each ``[index, g1, g2]`` row.

    Uses ``numpy.digitize`` with bins ``[-100, -50, -20, -1, 1, 20, 50, 100]``;
    bin 4 (``-1 <= sig < 1``) means the pair agrees within tolerance.
    """
    g1 = numpy.array([a[1] for a in stat])
    g2 = numpy.array([a[2] for a in stat])
    # Where a gradient is exactly zero, fall back to its spread as the scale.
    ag1 = abs(g1) + (abs(g1) == 0) * numpy.std(g1)
    ag2 = abs(g2) + (abs(g2) == 0) * numpy.std(g2)
    sig = (g1 - g2) / ((ag1 + ag2) * rtol + atol)
    bins = [-100, -50, -20, -1, 1, 20, 50, 100]
    return numpy.digitize(sig, bins)


def check_grad(code, yname, xname, init, eps, rtol, atol=1e-12, verbose=False):
    """Check the VJP and JVP of ``code`` element by element.

    A ``to_scalar`` statement reduces ``yname`` to a scalar ``y``. For every
    element of ``init[xname]``, a central step of half-width ``eps`` is taken,
    and two comparisons are made against the backward (VJP) gradient times
    ``2 * eps``:

    * the forward (JVP) gradient, within ``rtol`` / ``atol``;
    * the numerical difference ``y(x + eps) - y(x - eps)``, within
      ``10000 * rtol`` / ``atol``.

    Parameters
    ----------
    code : CodeSegment
        Code whose ``engine`` provides ``pm`` and ``q``.
    yname, xname : str
        Output and input variable names.
    init : dict
        Initial values. ``init[xname]`` must be a particle array shaped like
        ``code.engine.q`` or a RealField.
    eps : float
        Half-width of the finite-difference step.
    rtol, atol : float
        Tolerances.
    verbose : bool
        Print every comparison.

    Raises
    ------
    TypeError
        If ``init[xname]`` is not a supported kind of variable.
    AssertionError
        If any comparison falls outside tolerance. The message is a
        histogram of the ``_errorstat`` bins.
    """
    engine = code.engine
    comm = engine.pm.comm
    x_init = init[xname]

    if isinstance(x_init, numpy.ndarray) and x_init.shape == engine.q.shape:
        cshape, cget, cperturb = _particle_accessors(comm, engine.q)
    elif isinstance(x_init, RealField):
        cshape, cget, cperturb = _field_accessors(x_init)
    else:
        raise TypeError("check_grad: %s must be a particle array shaped like engine.q "
                        "or a RealField, got %s" % (xname, type(x_init)))

    code = code.copy()
    code.to_scalar(x=yname, y='y')

    y, tape = code.compute('y', init=init, return_tape=True)
    vjp = tape.get_vjp()
    jvp = tape.get_jvp()

    _x = vjp.compute('_' + xname, init={'_y' : 1.0})

    center = init[xname]
    init2 = init.copy()
    ng_bg = []  # numerical difference vs backward gradient
    fg_bg = []  # forward (JVP) gradient vs backward gradient
    for index in numpy.ndindex(*cshape):
        x1 = cperturb(center, index, eps)
        x0 = cperturb(center, index, -eps)
        # backward gradient scaled to the full central step
        analytic = cget(_x, index) * 2 * eps
        init2[xname] = x1
        y1 = code.compute('y', init2)
        init2[xname] = x0
        y0 = code.compute('y', init2)
        base = (x1 - x0)
        y_ = jvp.compute('y_', init={xname + '_': base})
        if verbose:
            print(index, base[...].max(), y, y1 - y0, y_, analytic)
        fg_bg.append([index, y_, analytic])
        ng_bg.append([index, y1 - y0, analytic])

    d1 = _errorstat(fg_bg, rtol, atol)
    d2 = _errorstat(ng_bg, rtol * 10000, atol)
    if (d1 != 4).any():
        raise AssertionError("FG_BG Bad gradients: %s " % numpy.bincount(d1))
    if (d2 != 4).any():
        raise AssertionError("NG_BG Bad gradients: %s " % numpy.bincount(d2))