北海道苫小牧市出身の初老PGが書くブログ

永遠のプログラマを夢見る、苫小牧市出身のおじさんのちらしの裏

Strassen の行列積を求めるアルゴリズム

アルゴリズムイントロダクションで、行列積を求める奇妙なアルゴリズムを見つけたので rust で実装した。簡単のため、行列の表現には ndarray を用いている。

#[macro_use]
extern crate ndarray;
use ndarray::prelude::*;
type Matrix = Array2<i32>;
type MatrixView<'a> = ArrayView2<'a, i32>;
type MatrixViewMut<'a> = ArrayViewMut2<'a, i32>;

fn strassen_multiply_inner (a: MatrixView, b: MatrixView, mut c: MatrixViewMut) {
    let n = a.rows();

    if n == 1 {
        c[[0, 0]] += a[[0, 0]] * b[[0, 0]];
        return;
    }

    assert_eq!(n & 1, 0);
    let m = n / 2;

    let a11 = a.slice(s![..m, ..m]);
    let a12 = a.slice(s![..m, m..]);
    let a21 = a.slice(s![m.., ..m]);
    let a22 = a.slice(s![m.., m..]);
    let b11 = b.slice(s![..m, ..m]);
    let b12 = b.slice(s![..m, m..]);
    let b21 = b.slice(s![m.., ..m]);
    let b22 = b.slice(s![m.., m..]);

    let s1 = &b12 - &b22;
    let s2 = &a11 + &a12;
    let s3 = &a21 + &a22;
    let s4 = &b21 - &b11;
    let s5 = &a11 + &a22;
    let s6 = &b11 + &b22;
    let s7 = &a12 - &a22;
    let s8 = &b21 + &b22;
    let s9 = &a11 - &a21;
    let s10 = &b11 + &b12;

    let mut p1 = Matrix::zeros((m, m));
    strassen_multiply_inner(a11, MatrixView::from(&s1), MatrixViewMut::from(&mut p1));

    let mut p2 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s2), b22, MatrixViewMut::from(&mut p2));

    let mut p3 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s3), b11, MatrixViewMut::from(&mut p3));

    let mut p4 = Matrix::zeros((m, m));
    strassen_multiply_inner(a22, MatrixView::from(&s4), MatrixViewMut::from(&mut p4));

    let mut p5 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s5), MatrixView::from(&s6), MatrixViewMut::from(&mut p5));

    let mut p6 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s7), MatrixView::from(&s8), MatrixViewMut::from(&mut p6));

    let mut p7 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s9), MatrixView::from(&s10), MatrixViewMut::from(&mut p7));

    {
        let mut c11 = c.slice_mut(s![..m, ..m]);
        c11 += &(&p5 + &p4 - &p2 + &p6);
    }
    {
        let mut c12 = c.slice_mut(s![..m, m..]);
        c12 += &(&p1 + &p2);
    }
    {
        let mut c21 = c.slice_mut(s![m.., ..m]);
        c21 += &(&p3 + &p4);
    }
    {
        let mut c22 = c.slice_mut(s![m.., m..]);
        c22 += &(&p5 + &p1 - &p3 - &p7);
    }
}

fn strassen_multiply (x: &Matrix, y: &Matrix) -> Matrix {
    assert_eq!(x.rows(), x.cols());
    assert_eq!(y.rows(), y.cols());
    assert_eq!(x.cols(), y.rows());

    let mut result = Matrix::zeros((x.rows(), y.cols()));

    let x_view = MatrixView::from(x);
    let y_view = MatrixView::from(y);
    let result_view = MatrixViewMut::from(&mut result);

    strassen_multiply_inner(x_view, y_view, result_view);

    result
}

書籍にも自明ではないと書かれている通り、なんとも煩雑で意味不明な計算である。直感的にはキャッシュをうまく使うことで、本来8回再帰しなければならないところを7回に済ませていることで、計算量が n^ {log _ 2 8} のところを n^ {log _ 2 7} で済ませるということになるだろう。

このアルゴリズムは本に書いてあった手順を愚直に実装したのだが、嫌な点がある。 Matrix::zeros や行列の和 &b12 - &b22 などはメモリを確保して行列を作って返す関数なので、この実装だと再帰する度にメモリを確保してしまう。

これを何とかする方法を探していたら、 stackoverflow で C 言語のプログラムが紹介されている のを見つけた。直感的には出力と同じサイズの行列のメモリを確保しておけばいい気はしつつ、再帰しながら使うことを考えると空き領域は計画的に使う必要がある。

fn strassen_multiply_inner2 (a: MatrixView, b: MatrixView, mut c: MatrixViewMut, mut d: MatrixViewMut) {
    assert_eq!(a.rows(), a.cols());
    assert_eq!(b.rows(), b.cols());
    assert_eq!(a.cols(), b.rows());

    let n = a.rows();

    if n == 1 {
        c[[0, 0]] = a[[0, 0]] * b[[0, 0]];
        return;
    }

    assert_eq!(n & 1, 0);
    let m = n / 2;

    let a11 = a.slice(s![..m, ..m]);
    let a12 = a.slice(s![..m, m..]);
    let a21 = a.slice(s![m.., ..m]);
    let a22 = a.slice(s![m.., m..]);
    let b11 = b.slice(s![..m, ..m]);
    let b12 = b.slice(s![..m, m..]);
    let b21 = b.slice(s![m.., ..m]);
    let b22 = b.slice(s![m.., m..]);

    let (c1, c2) = c.split_at(Axis(0), m);
    let (mut c11, mut c12) = c1.split_at(Axis(1), m);
    let (mut c21, mut c22) = c2.split_at(Axis(1), m);

    let (d1, d2) = d.split_at(Axis(0), m);
    let (mut d11, mut d12) = d1.split_at(Axis(1), m);
    let (mut d21, mut d22) = d2.split_at(Axis(1), m);

    d11.assign(&a12);
    d11 -= &a22;
    d12.assign(&b21);
    d12 += &b22;
    strassen_multiply_inner2(d11.view(), d12.view(), c11.view_mut(), d21.view_mut());

    d11.assign(&a21);
    d11 -= &a11;
    d12.assign(&b11);
    d12 += &b12;
    strassen_multiply_inner2(d11.view(), d12.view(), c22.view_mut(), d21.view_mut());

    d11.assign(&a11);
    d11 += &a12;
    strassen_multiply_inner2(d11.view(), b22.view(), c12.view_mut(), d12.view_mut());

    c11 -= &c12;
    d11.assign(&b21);
    d11 -= &b11;
    strassen_multiply_inner2(a22.view(), d11.view(), c21.view_mut(), d12.view_mut());

    c11 += &c21;
    d11.assign(&b12);
    d11 -= &b22;
    strassen_multiply_inner2(a11.view(), d11.view(), d12.view_mut(), d21.view_mut());

    c12 += &d12;
    c22 += &d12;
    d11.assign(&a21);
    d11 += &a22;
    strassen_multiply_inner2(d11.view(), b11.view(), d12.view_mut(), d21.view_mut());

    c21 += &d12;
    c22 -= &d12;
    d11.assign(&a11);
    d11 += &a22;
    d12.assign(&b11);
    d12 += &b22;
    strassen_multiply_inner2(d11.view(), d12.view(), d21.view_mut(), d22.view_mut());

    c11 += &d21;
    c22 += &d21;
}

fn strassen_multiply2 (x: &Matrix, y: &Matrix) -> Matrix {
    assert_eq!(x.rows(), x.cols());
    assert_eq!(y.rows(), y.cols());
    assert_eq!(x.cols(), y.rows());

    let mut result = Matrix::zeros((x.rows(), y.cols()));
    let mut tmp = Matrix::zeros((x.rows(), y.cols()));

    strassen_multiply_inner2(x.view(), y.view(), result.view_mut(), tmp.view_mut());

    result
}

d を一時領域として使うのだが、再帰の度に d11 d12 d21 d22 の4つの領域を順に使っている。最大で必要な一時領域の数は 3 つ(最後に現れる strassen_multiply_inner2(&MatrixView::from(&d11), &MatrixView::from(&d12), &mut d21, &mut d22);再帰)なので、これでうまく実装できる。しかし、これまた最初の数式から違っていて、なんで行列の積が求められているのか意味不明である。 そして、実装の際に所有権と MatrixView(|Mut) の使い方でハマったので、 MatrixView(|Mut)::from&mut をそこらにつけて誤魔化している。もしかすると、もっと簡潔な記述法はあるかもしれない。 viewview_mut で良さそうだったので書き換えておいた。