AtCoder ABC293 E - Geometric Progression
※ a ≦ i ≦ bにおけるA^iの和をsum(a, b)と表現することにする
例えば2から5までのA^iの和は
A^2 + A^3 + A^4 + A^5 = sum(2, 5)
と表現する
考えたこと
・Xが2のN乗なら簡単
例えばX = 8ならsum(0, 7)を求めればいい。このとき
A^4 = A^0 * A^4
A^5 = A^1 * A^4
A^6 = A^2 * A^4
A^7 = A^3 * A^4
なので、4から7までの和sum(4, 7)は
sum(4, 7) = A^4 * (A^0 + A^1 + A^2 + A^3)
= A^4 * sum(0, 3)
で求められる。よってsum(0, 7)は
sum(0, 7) = sum(0, 3) + sum(4, 7)
= sum(0, 3) + A^4 * sum(0, 3)
となる。同様にsum(0, 3), sum(0, 1)も
sum(0, 3) = sum(0, 1) + sum(2, 3)
= sum(0, 1) + A^2 * sum(0, 1)
sum(0, 1) = sum(0, 0) + sum(1, 1)
= sum(0, 0) + A^1 * sum(0, 0)
と求めることができる。sum(0, 0)はA^0で1なので、sum(0, 7)を求める式中のすべての項をA^iで表現することができた。この方法なら計算量はO(logX)で済む。
・端数をどうするか
X = 11の場合を考える。sum(0, 10)を求めればいいが
sum(0, 10) = sum(0, 7) + sum(8, 10)
となるので、sum(8, 11)の部分が求められればいいことになる。ここで上記の2のN乗のときと同様に
A^8 = A^0 * A^8
A^9 = A^1 * A^8
A^10 = A^2 * A^8
と考えると、sum(8, 10)は
sum(8, 10) = A^8 * (A^0 + A^1 + A^2)
= A^8 * sum(0, 2)
となる。あとは同様に
sum(0, 2) = sum(0, 1) + sum(2, 2)
sum(2, 2) = A^2 * sum(0, 0)
というように、2のN乗問題に帰着させることができる。
sum(0, 10)を求めるためにsum(0, 7)とsum(0, 2)が必要で、
sum(0, 2)を求めるためにsum(0, 0)とsum(0, 1)が必要で、
sum(0, 1)を求めるためにsum(0, 0)が必要。
という格好になっているので、再帰関数で書くと楽そうである。
・どんな再帰にするか
上記の通りsum(0, 10)を求めるためにsum(0, 7)とsum(0, 2)に分ければいいので、再帰関数にxが入力されたとすると次のような処理をすればいい。
(1) xを超えない範囲で最大の2のN乗を求める
ここで求めた2のN乗をeと呼ぶことにする
(2) e - 1を入力に与えて再帰関数を実行する
sum(0, 10)におけるsum(0, 7)側の処理にあたる
(3) x - eを入力に与えて再帰関数を実行する
sum(0, 10)におけるsum(0, 2)側の処理にあたる
(4) (2)の結果 + (3)の結果 × A^eを返す
こんな感じでいい。何度も同じ入力で再帰関数を実行することになるので、一度計算した結果は関数の外で定義した配列に保存しておいて、二度目の実行時には配列を参照するだけにするといい。(いわゆるメモ化再帰)
書いたコードと提出結果
#include <bits/stdc++.h>
std::map< long long, long long > memo;
std::map< long long, long long > twoN;
long long mod;
// 0からendまでの和を求める
long long sum(long long end){
// endが0ならば0から0までの和になるので、Aの0乗のみである
if(end == 0){
return 1ll;
}
// すでに計算済みならば同じ計算を何度もやる必要はないので、覚えている結果を返す
if(memo.find(end) != memo.end()){
return memo[end];
}
// endを超えない範囲で最大の2のN乗を求める
long long lbit = 60ll;
while( ( (1ll << lbit) & end ) == 0){
lbit--;
}
long long e = (1ll << lbit);
long long sum1 = sum(end - e);
long long sum2 = sum(e - 1ll);
long long ret = ( (sum1 * twoN[e]) % mod + sum2) % mod;
// 後で使うために計算結果を覚えておく
memo[end] = ret;
return ret;
}
int main(){
long long A, X, M;
std::cin >> A >> X >> M;
mod = M;
if(M == 1){
std::cout << 0 << std::endl;
return 0;
}
// Aの1乗, 2乗, 4乗, 8乗, …は何度も使うのであらかじめ計算しておく
// Aの2N乗 = AのN乗 ×AのN乗
twoN[0] = 1ll;
twoN[1] = A;
for(int i=1; i<=60; i++){
twoN[1ll << i] = (twoN[1ll << (i-1)] * twoN[1ll << (i-1)]) % M;
}
std::cout << sum(X-1) << std::endl;
return 0;
}
終わりに
再帰関数って書くの難しいよね。仕事じゃめったに書かないし。