全人類modintを使え!!!

昨日のABC253-E問題でmodを取ったけどバグった!!と言っている人がTLにたくさんいました。今回はそんな人たちに言いたい、modintを使え!!という感じの記事です。
自分で作ったmodintを紹介しますが、そんなん見てられっか!って人はACLにもありますのでそちらを活用してください。

コード

まずは完成形のコードをご覧ください。

template <long long mod>
struct modint {
  modint(ll v = 0) : value(normalize(v)) {}
  ll val() const { return value; }
  void normalize() { value = normalize(value); }
  ll normalize(ll v) {
    if (v <= mod && v >= -mod) {
      if (v < 0) v += mod;
      return v;
    }
    if (v > 0 && v < 2 * mod) {
      v -= mod;
      return v;
    }
    if (v < 0 && v > -2 * mod) {
      v += 2 * mod;
      return v;
    }
    v %= mod;
    if (v < 0) v += mod;
    return v;
  }
  modint<mod>& operator=(ll v) {
    value = normalize(v);
    return *this;
  }
  bool operator==(const modint& o) const { return value == o.val(); }
  bool operator!=(const modint& o) const { return value != o.val(); }
  const modint& operator+() const { return *this; }
  const modint& operator-() const { return value ? mod - value : 0; }
  const modint operator+(const modint& o) const {
    return value + o.val();
  }
  modint& operator+=(const modint& o) {
    value += o.val();
    if (value >= mod) value -= mod;
    return *this;
  }
  const modint operator-(const modint& o) const {
    return value - o.val();
  }
  modint& operator-=(const modint& o) {
    value -= o.val();
    if (value < 0) value += mod;
    return *this;
  }
  const modint operator*(const modint& o) const {
    return (value * o.val()) % mod;
  }
  modint& operator*=(const modint& o) {
    value *= o.val();
    value %= mod;
    return *this;
  }
  const modint operator/(const modint& o) const { return operator*(o.inv()); }
  modint& operator/=(const modint& o) { return operator*=(o.inv()); }
  const modint pow(ll n) const {
    modint ans = 1, x(value);
    while (n > 0) {
      if (n & 1) ans *= x;
      x *= x;
      n >>= 1;
    }
    return ans;
  }
  const modint inv() const {
    ll a = value, b = mod, u = 1, v = 0;
    while (b) {
      ll t = a / b;
      a -= t * b;
      swap(a, b);
      u -= t * v;
      swap(u, v);
    }
    return u;
  }
  friend ostream& operator<<(ostream& os, const modint& x) {
    return os << x.val();
  }
  template <typename T>
  friend modint operator+(T t, const modint& o) {
    return o + t;
  }
  template <typename T>
  friend modint operator-(T t, const modint& o) {
    return -o + t;
  }
  template <typename T>
  friend modint operator*(T t, const modint& o) {
    return o * t;
  }
  template <typename T>
  friend modint operator/(T t, const modint& o) {
    return o.inv() * t;
  }

 private:
  ll value;
};
using modint1000000007 = modint<1000000007>;
using modint998244353 = modint<998244353>;

今回の記事作成にあたって高速化をしました。
私の検証では、ACLのものとほぼ同じかちょっと速い程度になりました。

なんのこっちゃと思った方、簡単に解説を残しておきます。

modint構造体

template<long long mod>
struct modint{}
modint<998244353>のように使うことができるようになります。
structとclassは違いはデフォルトでprivate:かpublic:かだけなのでお好みで
メンバ変数としてvalueを持っています。

コンストラクタ

modint(ll v=0):value(normalize(v))で定義された範囲にvを変えた後にvalueに代入しています。=0を忘れるとvector<mint>を使うときに初期値を0とセットしなければならなかったりします。

正規化

normalize()関数です。modintでは0~mod-1の間しか値を取らないので、それに合うようにしています。
後は適当な高速化です。(きたないです)

演算子オーバーロード(+-*)

operator+とかoperator*とか実装しています。
読めばわかるような内容になっているかなと思います。
modint + intなどは自動でmodint + modint(int)にキャストしてくれるので特別に実装する必要はありません。

逆元(inv)

えっと、自分での実装は無理だったので、おとなしくけんちょんさんの記事を見ましょう!下のpowを使ってreturn pow(mod-2)でもできたりします。

累乗(pow)

繰り返し二乗法を使って高速化します。
繰り返し二乗法も上のけんちょんさんの記事が参考になります(丸投げ)

演算子オーバーロード(/)

前述のinvを用いて除算を行います。
と言っても割る方のinvをかけてあげるだけです。

演算子オーバーロード(+=,-=,*=,/=)

計算をして自身への参照を返してあげるようにします。

friendを用いたオーバーロード

以下は素人の適当な解説です
friend修飾をすることでクラス内でオーバーロードすることができます。
きれいです。

出力

みんな大好きACLにできないことの一つです。
ACLではcout<<mint(15).val()<<endl;のように.val()を付けない場合はコンパイルエラーとなってしまいます。
これをオーバーロードしてcout<<mint(15)<<endl;のようにできるのがこの関数です。

演算子オーバーロード(friend)

これはACLにもありますが、int+mintなどの計算をmint(int)+mintのように変換して計算します。面倒なので交換法則なりで適当にあしらっておきます。

まとめ

雑で、コードを読んでください!な記事でしたが、読んでくださりありがとうございます。結局けんちょんさんの記事が最強なので、みんな読みましょう(?)
modintを使うようになると、DPをするときなどで、除算が―オーバーフローが―とあたふたする必要がなくなります!この利点を活用するためにも、一度自作してみるのも良いと思います。よいmodint-lifeを!

余談

ACLとの違いはほぼ出力の部分といっていいでしょう(実際にはもっと違う部分がありますが実用上の違いとしてです)
それでもワイはACLを使うんじゃぁぁぁって人向けにちょっと残しておきます。

pairの比較関数をオーバーロードするというのはAPG4bにもありますが、それと同じことをします。
あらかじめ
using mint = modint98244353;のように型エイリアスすることは必要になります。

ostream& operator<< (ostream& os,const mint& x){
  return os<<x.val();
}

そこ!!
こっちのほうが便利とか言わない!

いいなと思ったら応援しよう!