NumPy で畳込み層(torch.nn.functional.conv2d 互換)を書いたメモです.解説を書く元気が残ってないので,興味がある人はがんばって解読してください.適当なパラメータで PyTorch と答えが allclose で一致することを確認してますが,網羅的にはテストしていないです.PyTorch の方がずっと速いので,これ自体は実用的ではないですが,NumPy 使ってスクラッチからニューラルネットや DL フレームワーク書きたい人には取っかかりとして手っ取り早いのではないかと思います.

# Copyright 2021 Seiya Tokui
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy


def conv2d(x, w, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """PyTorch-compatible simple implementation of conv2d in pure NumPy.

    NOTE: This code prioritizes simplicity, sacrificing the performance.
    It is much faster to use matmul instead of einsum (at least with NumPy 1.19.1).

    MIT License.

    """
    sY, sX = stride if isinstance(stride, (list, tuple)) else (stride, stride)
    pY, pX = padding if isinstance(padding, (list, tuple)) else (padding, padding)
    dY, dX = dilation if isinstance(dilation, (list, tuple)) else (dilation, dilation)
    N, iC, iH, iW = x.shape
    oC, iCg, kH, kW = w.shape
    pY_ex = (sY - (iH + pY * 2 - (kH - 1) * dY) % sY) % sY
    pX_ex = (sX - (iW + pX * 2 - (kW - 1) * dX) % sX) % sX
    oH = (iH + pY * 2 + pY_ex - (kH - 1) * dY) // sY
    oW = (iW + pX * 2 + pX_ex - (kW - 1) * dX) // sX

    x = numpy.pad(x, ((0, 0), (0, 0), (pY, pY + pY_ex), (pX, pX + pX_ex)))
    sN, sC, sH, sW = x.strides
    col = numpy.lib.stride_tricks.as_strided(
        x, shape=(N, groups, iCg, oH, oW, kH, kW),
        strides=(sN, sC * iCg, sC, sH * sY, sW * sX, sH * dY, sW * dX),
    )
    w = w.reshape(groups, oC // groups, iCg, kH, kW)
    y = numpy.einsum('ngihwkl,goikl->ngohw', col, w).reshape(N, oC, oH, oW)
    if bias is not None:
        y += bias[:, None, None]
    return y

やってることは im2col による実装と同じです.Chainer のときは col 相当のものをループ書いて作ってましたが,ループなし・コピーなしで書けるじゃん,というのが今回の気づきです.

  • as_strided で移動窓を作っています (cf. sliding_window_view).
  • stridedilation はどちらも strides で対応できます.
  • padding はおそらくどうしようもない.
  • einsumreshape/transpose がんばれば matmul になりますがシンプルさを優先しました.大抵の場合は matmul の方が断然速いです(Chainer の grouped conv 実装は matmul を使っていて,ちゃんと速い).einsum 爆速になってほしい.
    • 追記: optimize=True のことを忘れてたので試してみましたが,特に速度は変わりませんでした.残念.
  • backward も似た方法で書けるはず(einsum の微分は einsum で,as_strided は出力を as_strided で作って add.at).
  • as_stridedcupy のものに置き換えれば CuPy で動きます.ちゃんと比べてませんが,こっちは実装の割にそこそこ速いみたいです.適当に P100 のマシンで比べると,autotune してない PyTorch と比べて 1.5〜2 倍くらいの時間で動いてそうです.cuDNN 叩くよりずっと簡単なのでいいですね.