Pixel Pedals of Tomakomai

北海道苫小牧市出身の初老の日常

Strassen の行列積を求めるアルゴリズム

アルゴリズムイントロダクションで、行列積を求める奇妙なアルゴリズムを見つけたので rust で実装した。簡単のため、行列の表現には ndarray を用いている。

#[macro_use]
extern crate ndarray;
use ndarray::prelude::*;
type Matrix = Array2<i32>;
type MatrixView<'a> = ArrayView2<'a, i32>;
type MatrixViewMut<'a> = ArrayViewMut2<'a, i32>;

fn strassen_multiply_inner (a: MatrixView, b: MatrixView, mut c: MatrixViewMut) {
    let n = a.rows();

    if n == 1 {
        c[[0, 0]] += a[[0, 0]] * b[[0, 0]];
        return;
    }

    assert_eq!(n & 1, 0);
    let m = n / 2;

    let a11 = a.slice(s![..m, ..m]);
    let a12 = a.slice(s![..m, m..]);
    let a21 = a.slice(s![m.., ..m]);
    let a22 = a.slice(s![m.., m..]);
    let b11 = b.slice(s![..m, ..m]);
    let b12 = b.slice(s![..m, m..]);
    let b21 = b.slice(s![m.., ..m]);
    let b22 = b.slice(s![m.., m..]);

    let s1 = &b12 - &b22;
    let s2 = &a11 + &a12;
    let s3 = &a21 + &a22;
    let s4 = &b21 - &b11;
    let s5 = &a11 + &a22;
    let s6 = &b11 + &b22;
    let s7 = &a12 - &a22;
    let s8 = &b21 + &b22;
    let s9 = &a11 - &a21;
    let s10 = &b11 + &b12;

    let mut p1 = Matrix::zeros((m, m));
    strassen_multiply_inner(a11, MatrixView::from(&s1), MatrixViewMut::from(&mut p1));

    let mut p2 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s2), b22, MatrixViewMut::from(&mut p2));

    let mut p3 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s3), b11, MatrixViewMut::from(&mut p3));

    let mut p4 = Matrix::zeros((m, m));
    strassen_multiply_inner(a22, MatrixView::from(&s4), MatrixViewMut::from(&mut p4));

    let mut p5 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s5), MatrixView::from(&s6), MatrixViewMut::from(&mut p5));

    let mut p6 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s7), MatrixView::from(&s8), MatrixViewMut::from(&mut p6));

    let mut p7 = Matrix::zeros((m, m));
    strassen_multiply_inner(MatrixView::from(&s9), MatrixView::from(&s10), MatrixViewMut::from(&mut p7));

    {
        let mut c11 = c.slice_mut(s![..m, ..m]);
        c11 += &(&p5 + &p4 - &p2 + &p6);
    }
    {
        let mut c12 = c.slice_mut(s![..m, m..]);
        c12 += &(&p1 + &p2);
    }
    {
        let mut c21 = c.slice_mut(s![m.., ..m]);
        c21 += &(&p3 + &p4);
    }
    {
        let mut c22 = c.slice_mut(s![m.., m..]);
        c22 += &(&p5 + &p1 - &p3 - &p7);
    }
}

fn strassen_multiply (x: &Matrix, y: &Matrix) -> Matrix {
    assert_eq!(x.rows(), x.cols());
    assert_eq!(y.rows(), y.cols());
    assert_eq!(x.cols(), y.rows());

    let mut result = Matrix::zeros((x.rows(), y.cols()));

    let x_view = MatrixView::from(x);
    let y_view = MatrixView::from(y);
    let result_view = MatrixViewMut::from(&mut result);

    strassen_multiply_inner(x_view, y_view, result_view);

    result
}

書籍にも自明ではないと書かれている通り、なんとも煩雑で意味不明な計算である。直感的にはキャッシュをうまく使うことで、本来8回再帰しなければならないところを7回に済ませていることで、計算量が n^ {log _ 2 8} のところを n^ {log _ 2 7} で済ませるということになるだろう。

このアルゴリズムは本に書いてあった手順を愚直に実装したのだが、嫌な点がある。 Matrix::zeros や行列の和 &b12 - &b22 などはメモリを確保して行列を作って返す関数なので、この実装だと再帰する度にメモリを確保してしまう。

これを何とかする方法を探していたら、 stackoverflow で C 言語のプログラムが紹介されている のを見つけた。直感的には出力と同じサイズの行列のメモリを確保しておけばいい気はしつつ、再帰しながら使うことを考えると空き領域は計画的に使う必要がある。

fn strassen_multiply_inner2 (a: MatrixView, b: MatrixView, mut c: MatrixViewMut, mut d: MatrixViewMut) {
    assert_eq!(a.rows(), a.cols());
    assert_eq!(b.rows(), b.cols());
    assert_eq!(a.cols(), b.rows());

    let n = a.rows();

    if n == 1 {
        c[[0, 0]] = a[[0, 0]] * b[[0, 0]];
        return;
    }

    assert_eq!(n & 1, 0);
    let m = n / 2;

    let a11 = a.slice(s![..m, ..m]);
    let a12 = a.slice(s![..m, m..]);
    let a21 = a.slice(s![m.., ..m]);
    let a22 = a.slice(s![m.., m..]);
    let b11 = b.slice(s![..m, ..m]);
    let b12 = b.slice(s![..m, m..]);
    let b21 = b.slice(s![m.., ..m]);
    let b22 = b.slice(s![m.., m..]);

    let (c1, c2) = c.split_at(Axis(0), m);
    let (mut c11, mut c12) = c1.split_at(Axis(1), m);
    let (mut c21, mut c22) = c2.split_at(Axis(1), m);

    let (d1, d2) = d.split_at(Axis(0), m);
    let (mut d11, mut d12) = d1.split_at(Axis(1), m);
    let (mut d21, mut d22) = d2.split_at(Axis(1), m);

    d11.assign(&a12);
    d11 -= &a22;
    d12.assign(&b21);
    d12 += &b22;
    strassen_multiply_inner2(d11.view(), d12.view(), c11.view_mut(), d21.view_mut());

    d11.assign(&a21);
    d11 -= &a11;
    d12.assign(&b11);
    d12 += &b12;
    strassen_multiply_inner2(d11.view(), d12.view(), c22.view_mut(), d21.view_mut());

    d11.assign(&a11);
    d11 += &a12;
    strassen_multiply_inner2(d11.view(), b22.view(), c12.view_mut(), d12.view_mut());

    c11 -= &c12;
    d11.assign(&b21);
    d11 -= &b11;
    strassen_multiply_inner2(a22.view(), d11.view(), c21.view_mut(), d12.view_mut());

    c11 += &c21;
    d11.assign(&b12);
    d11 -= &b22;
    strassen_multiply_inner2(a11.view(), d11.view(), d12.view_mut(), d21.view_mut());

    c12 += &d12;
    c22 += &d12;
    d11.assign(&a21);
    d11 += &a22;
    strassen_multiply_inner2(d11.view(), b11.view(), d12.view_mut(), d21.view_mut());

    c21 += &d12;
    c22 -= &d12;
    d11.assign(&a11);
    d11 += &a22;
    d12.assign(&b11);
    d12 += &b22;
    strassen_multiply_inner2(d11.view(), d12.view(), d21.view_mut(), d22.view_mut());

    c11 += &d21;
    c22 += &d21;
}

fn strassen_multiply2 (x: &Matrix, y: &Matrix) -> Matrix {
    assert_eq!(x.rows(), x.cols());
    assert_eq!(y.rows(), y.cols());
    assert_eq!(x.cols(), y.rows());

    let mut result = Matrix::zeros((x.rows(), y.cols()));
    let mut tmp = Matrix::zeros((x.rows(), y.cols()));

    strassen_multiply_inner2(x.view(), y.view(), result.view_mut(), tmp.view_mut());

    result
}

d を一時領域として使うのだが、再帰の度に d11 d12 d21 d22 の4つの領域を順に使っている。最大で必要な一時領域の数は 3 つ(最後に現れる strassen_multiply_inner2(&MatrixView::from(&d11), &MatrixView::from(&d12), &mut d21, &mut d22);再帰)なので、これでうまく実装できる。しかし、これまた最初の数式から違っていて、なんで行列の積が求められているのか意味不明である。 そして、実装の際に所有権と MatrixView(|Mut) の使い方でハマったので、 MatrixView(|Mut)::from&mut をそこらにつけて誤魔化している。もしかすると、もっと簡潔な記述法はあるかもしれない。 viewview_mut で良さそうだったので書き換えておいた。

rustのsliceとlen()

こう書くと、

fn main() {
    let mut a = [1, 2, 3, 4, 5];
    let x = &mut a[3..a.len()];
    println!("{:?}", x);
}

こう怒られる。

error[E0502]: cannot borrow `a` as immutable because it is also borrowed as mutable
 --> src/main.rs:3:23
  |
3 |     let x = &mut a[3..a.len()];
  |                  -----^-------
  |                  |    |
  |                  |    immutable borrow occurs here
  |                  mutable borrow occurs here
  |                  mutable borrow later used here

こうなら良い。

fn main() {
    let mut a = [1, 2, 3, 4, 5];
    let n = a.len();
    let x = &mut a[3..n];
    println!("{:?}", x);
}

3..a.len() を評価してから a の可変参照を取るならなんの問題もないような気もするのだが、実際このコードは、

use std::ops::IndexMut;

fn main() {
    let mut a = [1, 2, 3, 4, 5];
    let x = a.index_mut(3..a.len());
    println!("{:?}", x);
}

と同値であり、 &mut aa.len()fn index_mut<'a>(&'a mut self, index: Side) -> &'a mut Self::Output の引数となる。第1引数で可変参照を渡そうとしているのに第2引数で不変参照を用いているので怒られる。

error[E0502]: cannot borrow `a` as immutable because it is also borrowed as mutable
 --> src/main.rs:5:28
  |
5 |     let x = a.index_mut(3..a.len());
  |             - ---------    ^ immutable borrow occurs here
  |             | |
  |             | mutable borrow later used by call
  |             mutable borrow occurs here

まあそもそも、このサンプルのように末尾まで取りたいのなら、以下のように書くのが正しい。

fn main() {
    let mut a = [1, 2, 3, 4, 5];
    let x = &mut a[3..];
    println!("{:?}", x);
}

a.len() - 1 まで取りたいのなら分けて書かないと駄目なのかな。

rustで「サイコロを転がす」を解く

最近 rust を勉強し始めたので、オフラインリアルタイムどう書くで一番簡単な問題(と思っている) サイコロを転がす を解いた。破壊的更新を戦略的に使えるのは、この程度の難易度のサンプルであれば書いてて大変気持ちがいい。

yhpg-rs/lib.rs at ord12rotdice · hiratara/yhpg-rs · GitHub

rust の perl-xs を触る

rust の perl-xs なるリポジトリを見つけたので触ってみた。 Perlcarton と rust の cargo が動く環境1であれば、 README に書かれている通りリポジトリを clone してきて以下で簡単に試せる。

$ carton install
$ carton exec -- 'cd t && perl Makefile.PL && make test'

perl-xsPerl API を rust から使いやすいようにラップしたものという位置づけになる。 Perl API への低レベルなバインディングperl-sys で提供され、 Rust のコードのビルドは Module::Install::Rust で提供される。また、 perl-sys で C のマクロに相当する処理の rust のバインディングを提供するために、内部で Ouroboros を使っている。

perl-xs を使うと、 XS や C を書かなくても rust だけで Perl のモジュールが書ける。ただし、このクレートは発展途上であり、足りない処理も多い。例えば、今日現在のバージョンでは Perlvalues に相当する処理がないので、作ってみる。

src/hash.rs に以下のようにイテレータを実装する。

impl HV {
    ...

    #[inline]
    pub fn iter_values(&self) -> IterHVVal {
        IterHVVal::new(self)
    }

    ...
}

pub struct IterHVVal<'a>(&'a HV);

impl<'a> IterHVVal<'a> {
    fn new(hv: &'a HV) -> Self {
        unsafe { hv.pthx().hv_iterinit(hv.as_ptr()) };
        IterHVVal(hv)
    }
}

impl<'a> Iterator for IterHVVal<'a> {
    type Item = SV;

    fn next(&mut self) -> Option<Self::Item> {
        unsafe {
            let hv = &self.0;
            let pthx = hv.pthx();
            let hv_ptr = hv.as_ptr();
            let he = pthx.hv_iternext(hv_ptr);
            if he.is_null() {
                None
            } else {
                let sv = pthx.hv_iterval(hv_ptr, he);
                Some(SV::from_sv(pthx, sv))
            }
        }
    }
}

そして、これを利用して Perl のモジュールを書く。以下のような lib.rs を書いた。ハッシュの値の二乗の和を取るだけのシンプルな関数である。

#[macro_use]
extern crate perl_xs;
#[macro_use]
extern crate perl_sys;

mod xstest {
    use perl_xs::raw::NV;
    use perl_xs::HV;

    xs! {
        package XSTest;

        sub sum_values(ctx, hv: HV) {
            let n: NV = hv.iter_values().map(|sv| {
                let n = sv.nv();
                n * n
            }).sum();
            ctx.new_sv(n)
        }
    }
}

xs! {
    bootstrap boot_XSTest;
    use xstest;
}

そして、 t/ ディレクトリを参考に Makefile.PLlib/XSTest.pm を置けばそれだけで完成。なんともお手軽である。

$ carton exec -- 'cd my_example && perl Makefile.PL && make && perl -Mblib -E "use XSTest; say XSTest::sum_values({a => 1.0, b => 2.0, c => 3.0})"'

...

14

さて、 perl-xs クレートはこのように大変お手軽でいいのだが、 作者のスライド でも触れられているように残念ながらパフォーマンスに問題がある。手元でもベンチマークを取ってみたが、先程の sum_values は Pure Perl と比べて 30% しか性能を改善できなかった 2 。同じ処理を素の XS で書いたところ 350% 性能が改善したので、ほんとに遅いとしか言いようがない。

原因を探るべく、 perl-sys 側のコードを見てみると、 PerlAPIJMPENV_* でラップされていた。以下はビルド中に自動生成された perl_sys.c からの抜粋である。

int perl_sys_hv_iternext(HE** RETVAL, HV * hv) {
    int rc = 0;
    dJMPENV;
    JMPENV_PUSH(rc);
    if (rc == 0) { *RETVAL = hv_iternext(hv); }
    JMPENV_POP;
    return rc;
}

これらを取り除いたところ、 Pure Perl と比べて 100% 性能が改善するようになった。他、余計なコピー処理などを省いて 200% 程度まで改善できたが、キリがないのでそこで辞めておいた。徹底的に省けば、素の XS と同程度まで改善できるだろう。

なお、イテレータloop に変えると 20% 程度性能が改善したが、この程度の違いで済むのはさすが rust だなと思った。


  1. WSLで試した。

  2. 驚くべきことに、要素数を 10 万個まで増やすと Pure Perl に負けてしまった。

トレイト境界の F: FnMut(u8) -> bool という記法について

rust で where の中でトレイト境界を書ける。そして、クロージャはトレイトで実現されている。ってことなんだけど、クロージャを表す型変数のトレイト境界の書き方はなんとも気持ちが悪い。

struct Cacher<T>
    where T: Fn(u32) -> u32
{
    calculation: T,
    value: Option<u32>,
}

Closures: Anonymous Functions that Can Capture Their Environment - The Rust Programming Language

トレイト境界と言えば T: Display + Clone とか、ジェネリックであれば <> を使って T: AsRef<str> みたいな記法になるはずだが、 () とか -> とはなんなのか。

答えから言うと、これは Fn ファミリー独自の記法のようだ。文法的にはこうなっている。 ->() を使う規則は TypePathFn命名されている。

TypePathFn :
( TypePathFnInputs? ) (-> Type)?

TypePathFnInputs :
Type (, Type)* ,?

Paths - The Rust Reference

これは Foo<(...), Output=B> の糖衣だとする文章も残っている。

rust-rfcs/0587-fn-return-should-be-an-associated-type.md at master · nox/rust-rfcs · GitHub

ここまで来ると、ああ、やっぱりクロージャはトレイトなんだなという思いになってくるが、現行ではこの規則で脱糖してもコンパイルすることはできず、 TypePathFn の記法を使わねばならない。

error[E0658]: the precise format of `Fn`-family traits' type parameters is subject to change. Use parenthetical notation (Fn(Foo, Bar) -> Baz) instead (see issue #29625)
   --> src/main.rs:175:10
    |
175 | where F: FnMut<(u8,), Output=bool>
    |          ^^^^^^^^^^^^^^^^^^^^^^^^^

VSCodeからWSL上のRLSを使うメモ (2)

ある朝突然、 cargo testVSCode Insiders から呼べなくなった。これを踏んでた。どうやら昨日の夕方に 1.36.0-insider がアップデートされて壊れたようだ(アップデートされた記憶はあまりない)。

github.com

以下に従って5月のビルド 1.35.0-insider に戻したら問題なく動いた。 Insider らしくてとてもよい(よくない)。

Access older Insider builds · Issue #46089 · microsoft/vscode · GitHub

VSCodeからWSL上のRLSを使うメモ

悲しいことにうまく動せていない。時間が解決してくれる気はする。

その1. Remote WSL環境

開発者版 VSCode insiders が必要。 VSCode とは別のアプリとしてインストールできるので気にせず入れて良い。Remote WSL 拡張を入れればそれで終わり。

その2 で使う useWSL は使わない。ほぼうまく動くが、なぜか rustfmt による整形が動かない上に、 Output View へログを一行も吐いてくれない*1ので調査もできなくて困る。

その2. RLS拡張のuseWSLを使う

RLS拡張 が想定しているのはおそらくこちらの使い方。

しかし、まず rustup の PATH が通らない。 issue 上がってそう だけど放置されてる感。拡張の設定で rustup の PATH を WSL 内の full path にすればこの問題は対処できる。

次に嫌なのが、 WSL 内の windowsファイルシステムのマウントが /mnt/固定であること/etc/wsl.confroot 設定を書いていると動かせない。

root を変えるのが面倒でここで諦めてしまったが、使っている人がいる機能であるはずなので多分動かせば動くのであろう root を変えたら動いた。あと、当たり前ではあるがこちらだと、 windowsファイルシステムにプロジェクトを置かなければならないという制約がある。

*1:その2でもきちんと動くようになったあとはログが出なくなったので、エラーが起きないとログは吐かないという話のような気もする