剰余の世界での四則演算

[latexpage]

Contents

概要

除算だけ難しいです。

演算

加算

足してmod Mするだけ。

減算

引き算するだけですが、マイナスにならないようにMを足してから剰余演算をしましょう。

(a-b+M)%M

乗算

かけてmod Mするだけ。
ただし、オーバーフローさせないために乗算する毎に剰余演算をする必要があることに注意。

除算

除算時は逆元を用います。

$$a/b \equiv c\quad (mod M)$$

を考えるとき、bの逆元$b^{-1}$は下式の特性を持ちます。

$$b \times b^{-1} \equiv 1\quad (mod M)$$

このとき、aと$b^{-1}$を掛け合わせれば、mod Mの世界ではa/bと等価となります。

(参考: 「1000000007 で割ったあまり」の求め方を総特集! 〜 逆元から離散対数まで 〜 – Qiita)

逆元の計算方法は

  • フェルマーの小定理を用いた方法
  • 拡張ユークリッドの互除法を用いた方法

の2種類がありますが、拡張ユークリッドの互除法のほうが高速、かつ制約が緩くなります。

詳しい理論・実装については下記ページのpdfを参考にしてください。

逆元の存在する条件

bとMが互いに素である時に逆元が存在します。
ですので、Mが素数でbがそれ未満の値であるときは基本的に安心です。

ちなみに、$10^9 + 7$ で割った余りを求めろ、といった問題がよく出題されますが、何故でしょうか。
それは、$10^9 + 7$ が32bitの範囲で表せる、ちょうどよい値の素数だからです。

累乗

愚直に$x^{n}$を求めると$O(n)}$となってしまうので、2分累乗法を用います。
(8乗は2乗の2乗の2乗…的な)

// x^n(mod M)
ll mod_pow(ll x, ll n, ll m) {
  ll ans = 1;
  while (n != 0) {
    if (n & 1)
      ans = ans * x % m;
    x = x * x % m;
    n = n >> 1;
  }
  return ans;
}

このとき、計算量は$O(log(n))$ となります。

modint構造体

上記の演算方法で四則演算が全てできるようになったわけですが、各計算を行うごとに%modの演算をしないといけないのは億劫です。

特に、乗算・除算の場合は1回演算子を呼ぶ毎にやらないとオーバーフローのリスクがあります。
(a*b*c ではなく(a*b%mod)*c%modとなります!)

この問題を解決するため、modint 構造体を定義します。
演算子オーバーロードによって毎演算ごとに自動的に剰余演算を行ってくれます!便利!

AtCoder Live(AtCoderの公式解説YouTubeチャンネル)で紹介されたサンプルコードを参考に実装してみました。

using ll = long long;

template <ll ModVal>
struct ModInt {
  ll x;

  ModInt(ll _x = 0) : x((_x % ModVal + ModVal) % ModVal) {
  }

  ModInt operator-() const {
    return ModInt(-x);
  }
  ModInt& operator+=(const ModInt a) {
    x += a.x;
    if (x >= ModVal)
      x -= ModVal;
    return *this;
  }
  ModInt& operator-=(const ModInt a) {
    x = x + ModVal - a.x;
    if (x >= ModVal)
      x -= ModVal;
    return *this;
  }
  ModInt& operator*=(const ModInt a) {
    x *= a.x;
    x %= ModVal;
    return *this;
  }

  ll ext_gcd(ll a, ll b, ll& x, ll& y) {
    if (b == 0) {
      x = 1;
      y = 0;
      return a;
    }
    ll tmp = a / b;
    ll d = ext_gcd(b, a - b * tmp, y, x);
    y -= tmp * x;
    return d;
  }

  // 逆元
  ModInt inv(const ModInt a) {
    ll u, v;
    ext_gcd(a.x, ModVal, u, v);
    return ModInt(u);
  }

  ModInt& operator/=(const ModInt a) {
    return (*this) *= inv(a);
  }

  ModInt operator+(const ModInt a) const {
    ModInt retval(*this);
    return retval += a;
  }
  ModInt operator-(const ModInt a) const {
    ModInt retval(*this);
    return retval -= a;
  }
  ModInt operator*(const ModInt a) const {
    ModInt retval(*this);
    return retval *= a;
  }
  ModInt operator/(const ModInt a) const {
    ModInt retval(*this);
    return retval /= a;
  }

  ModInt pow(ll n) {
    ModInt ans(1);
    while (n) {
      if (n & 1)
        ans = ans * x;
      *this = (*this) * (*this);
      n = n >> 1;
    }
    return ans;
  }

  constexpr const ll& value() {
    return this->x;
  }
};

template <ll ModVal>
ostream& operator<<(ostream& os, const ModInt<ModVal>& a) {
  os << a.x;
  return os;
}

#define mod (ll)(1e9 + 7)
using mint = ModInt<mod>;

参考サイト

関連問題