見出し画像

【ABC340】緑コーダーの競プロ参加記録 #2

「数学的な正しさより雰囲気のつかみやすさを優先した灰コーダー茶コーダー向けの記事があったらなぁ」

本記事は、そんな茶コーダーの頃の私の気持ちをベースにしています。
いつか別記事として分けたい内容もしばらくの間は毎度書く予定なので、冗長に感じられましたら申し訳ございません。

今回はAtCoder Beginner Contest 340。使用言語はPythonです。


【解説】 A問題 - Arithmetic Progression

コンテスト中の提出
https://atcoder.jp/contests/abc340/submissions/50140823


  • 整数$${A、B、D}$$が与えられる。

  • 初項$${A}$$、末項$${B}$$、公差$${D}$$の等差数列を出力してね。


等差数列。
A問題で数学用語が説明もなく現れました。びっくり。

ABCのA問題では、プログラミング言語の基礎的な処理を問われることが多いです。この問題も例にもれず、入力・出力、ループ処理の使い方が問われています。
まずは入力と出力から見ていきましょう。

Pythonではinput関数を利用することで入力を受け取ることができます。そしてinput関数は入力1行をまるまる文字列として受け取ります。今回は与えられる値が整数 A、B、D の3つなので工夫が必要です(後述)。

出力はprint関数を利用します。ABCではprint関数で出力された値をユーザーの答えとして扱うので、必要以上にprint関数を実行して不正解にならないよう気を付けましょう。

""" 入力を受け取る """
S = input()

""" 問題を解く処理 """

""" 答えを出力 """
print(ans)


今回の入力形式に慣れている方はここから2000文字くらい読み飛ばしてもらって構いません。


さて、工夫が必要な整数$${A、B、D}$$の受け取り方ですが、結論から言うと次のコードになります。ほとんどの競プロPythonユーザーはこう書いていると思います。

A, B, D = map(int, input().split())

では、上のコードになるまで流れを1ステップずつ考えていきましょう。

くり返しになりますが、input関数は入力1行をまるまる文字列として受け取ります。整数、リストとして入力を受け取りたい場合でもinput関数で1つの文字列として受け取ることからスタートする点で変わりはありません。

""" 入力を受け取る """
S = input()

print(S)
"""
入力例 1 
>>> '3 9 2'
"""

今回はinput関数で受け取った文字列'3 9 2'」整数「3」「9」「2」に分解したいです。
forループで空白ではない文字だけを取ろうとすると2桁以上の数字を破壊してしまうかもしれません。Pythonには「文字列を特定の文字列で区切ってリストに変換する」split関数なるものが存在するのでこれを利用します。

今回は空白区切りをしたいのでsplit関数の引数に何も指定しなくてOKです。

""" 入力を受け取る """
S = input()

""" 「文字列」を「文字列のリスト」に変換 """
str_list = S.split()

print(str_list)
"""
入力例 1 
>>> ['3', '9', '2']
"""

不要な空白を削除して数字だけに分解できましたが、今のままではひとつひとつの数が文字列のままですので、次はこれらを整数に変換したいです。求めている処理のイメージは次のようになります。

""" 入力を受け取る """
S = input()

""" 「文字列」を「文字列のリスト」に変換 """
str_list = S.split()

""" 「文字列のリスト」を「整数のリスト」に変換 """
int_list = []
for str_num in str_list:
    int_num = int(str_num)
    int_list.append(int_num)

print(int_list)
"""
入力例 1 
>>> [3, 9, 2]
"""

最初にinput関数で文字列として受け取った入力データを整数のリストにまで変換できました。あと一歩です。

Pythonにはリストから複数の変数に値を代入できるアンパックという機能があります。たとえば長さ 3 のリストに対して変数を3つ用意すると、順番通りに値が代入されます。

入力で受け取る数字は左から$${A、B、D}$$なのでこの3つと整数のリストを '=' 演算子で結ぶと、、、完成です!

""" 入力を受け取る """
S = input()

""" 「文字列」を「文字列のリスト」に変換 """
str_list = S.split(' ')

""" 「文字列のリスト」を「整数のリスト」に変換 """
int_list = []
for str_num in str_list:
    int_num = int(str_num)
    int_list.append(int_num)

""" 完成 !! """
A, B, D = int_list

目的は達成されましたがあまりにも面倒です。

そして、この面倒なコードを1行に縮めてくれていたのがmap関数です。

使い方としては「map(関数っぽいやつ, リストっぽいやつ)」というような書き方をします。「関数っぽいやつ」「リストっぽいやつ」はそれぞれcallableとかiterableとか言うのですが今は無視します。

map関数は「リストっぽいやつ」の中身ひとつひとつに「関数っぽいやつ」の処理を行った後の「リストっぽいやつ」を返してくれます。

わけわかんないですね。
ですが、やっていることは上のコードのforループの部分と同じです。

具体例として、最初のコードをもう一度みてみましょう。

A, B, D = map(int, input().split())

input().split()で「受け取った入力を文字列のリストに変換」しました。リストなので当然「リストっぽいやつ」です。
ここでの「関数っぽいやつ」はint関数です。文字列のリストの中身ひとつひとつにint関数の処理を行うので、整数の「リストっぽいやつ」が返ってきます。

「リストっぽいやつ」もリストと同様にアンパックで値を代入することができます。

長くなりましたが、以上の手順により複数の変数で入力を受け取ることができました。


ここから問題の内容に入ります。
「初項$${A}$$、末項$${B}$$、公差$${D}$$の等差数列」を求めてねということでした。

今回は求められている等差数列が必ず存在することが制約で保証されているので、$${A}$$から$${B}$$まで$${D}$$個おきに値を取り続ければそれが答えになります。

$${A}$$から$${B}$$までの数をすべてチェックすればよいのでforループでおなじみrange関数の出番です。range関数は第3引数に「値の間隔」を指定できるのでそれを$${D}$$で指定します。

""" 入力を受け取る """
A, B, D = map(int, input().split())

""" 問題を解く処理 """
ans = []
for i in range(A, B + 1, D):
    ans.append(i)

ほぼ問題が解けたので答えを出力します。
数列を空白区切りで出力するという、A問題にしてはテクニカルな出力形式です。

forループの中で1つずつ出力する方法もありますが、今回はmap関数の説明でアンパックについて触れたのでそれを利用します

リストのアンパックは変数への代入だけでなく、関数に値を渡すときにも使えます。その場合にはアスタリスク(*)をリスト変数にくっつけます。

print関数は複数の引数を渡すとそれぞれの値を空白区切りで出力するので今回の問題にピッタリです。
数列は変数 ans に記録しているので ans をアンパックして出力します。

""" 入力を受け取る """
A, B, D = map(int, input().split())

""" 問題を解く処理 """
ans = []
for i in range(A, B + 1, D):
    ans.append(i)

""" 答えを出力 """
print(*ans)

これでACとなります。
なお、forループの中で1つずつ出力する方法は以下の通りです。

A, B, D = map(int, input().split())
for i in range(A, B + 1, D):
    print(i, end=' ')

入力以外の部分を1行で書くこともできます。

A, B, D = map(int, input().split())
print(*[i for i in range(A, B + 1, D)])

また、range関数を使わなくてもこの問題を解くことができます。「初項 $${A}$$、公差$${D}$$の等差数列」の各項の値は  $${k}$$ を 0 以上の整数として $${A + kD}$$ と表すことができます。これは $${k = 0}$$ から $${A  + kD}$$ が$${B}$$を超えるまで $${k}$$ に 1 を足し続けるwhileループを書けばよいです。

A, B, D = map(int, input().split())
k = 0
ans = []
while A + k * D <= B:
    ans.append(A + k * D)
    k += 1
print(*ans)

$${k}$$ を用意しなくてもOKです。

A, B, D = map(int, input().split())
ans = []
while A <= B:
    ans.append(A)
    A += D
print(*ans)

【余談】
アスタリスク(*)を使うアンパックはいろんなことができます。たとえば以下のコードを実行するとどうなるでしょうか。気になる方は試してみましょう。
参考: ABC167 - C  Skill Up

c, *a = [1, 2, 3, 4, 5]
print(c, a)

x, *y, z = [1, 2]
print(x, y, z)

【解説】 B問題 - Append

コンテスト中の提出
https://atcoder.jp/contests/abc340/submissions/50145248


  • 数列$${A}$$は空の状態から始まる。

  • クエリが$${Q}$$個与えられる。

  • クエリが$${( 1, x )}$$の形式の場合、$${A}$$の末尾に$${x}$$を追加する。

  • クエリが$${( 2, k )}$$の形式の場合、$${A}$$の後ろから$${k}$$番目を出力してね。


「クエリが$${Q}$$個与えられる」
この形式を見たときはまず制約を確認して、問題文のとおりに処理を行ったとき計算量がどうなるか考えます。

今回は「$${Q\leq 100}$$」「数列$${A}$$は空の状態からスタート」であるのと、$${A}$$を変化させる処理が「末尾に$${x}$$を追加する」だけなので$${A}$$の要素数は最大でも 100 です。
要素数 100 の配列を100回全探索したとしても実行時間 2 sec は余裕です。

というわけで問題文のとおりに処理を行って大丈夫な予感がするので、クエリに素直に従ったコードを書くことにします。はじめ$${A}$$を空のリストとして宣言しておきます。

クエリ 1 の「$${A}$$の末尾に$${x}$$を追加する」処理はB問題のタイトルにもなっている通り、リスト$${A}$$に$${x}$$をappendすればOKです。

クエリ 2 の「$${A}$$の後ろから$${k}$$番目を出力する」のはPythonであれば簡単です。リスト$${A}$$の一番後ろを参照する場合$${A[-1]}$$とするのと同じように、後ろから$${k}$$番目を参照する場合は$${A[-k]}$$でOKです。普通のリストアクセスのように1-indexから0-indexに変換しようとして $${k-1}$$ としないようにしましょう。
今回は制約で常に後ろから$${k}$$番目の要素が存在することが保証されているのでリスト外参照は気にしなくてよいです。

Q = int(input())
A = []
for _ in range(Q):
    t, x = map(int, input().split())
    if t == 1:
        A.append(x)
    else:
        print(A[-x])

後ろから$${k}$$番目の参照$${A[-k]}$$を知らなかった、あるいはそれを使えない言語を使用している場合でも「要素数$${n}$$の数列$${A}$$において、後ろから$${k}$$番目である要素は前から$${n -k+1}$$番目である」ことを考えることで解くことができます。


すこし別の視点で解法を考えてみましょう。
クエリ 1 の「末尾に追加」を「先頭に追加」と読み換えると、クエリ 2 の「後ろから$${k}$$番目」を「前から k 番目」とすることができます。この問題では必須ではないですが、問題文に書かれている操作を別の操作に置き換えることで解きやすくするテクニックは競プロにおいてとても重要です。

リストの先頭に$${x}$$を追加する処理はappendではなくinsertで行います。insert関数は指定されたインデックス位置に要素を追加する関数なので、 0 を指定することで先頭に追加できます。クエリ 2 で要素を前から見るときは1-indexを0-indexに変換するのを忘れないようにしましょう。

Q = int(input())
A = []
ans = []
for _ in range(Q):
    t, x = map(int, input().split())
    if t == 1:
        A.insert(0, x)
    else:
        print(A[x - 1])

今回の問題はこれでもACになります。

ですが、実はinsert関数には大きな落とし穴があります。
それは計算量です。長さ$${N}$$のリストに対してinsert関数を使うと、1回のクエリに通常$${O(N)}$$の計算量がかかってしまいます。仮に$${Q = 2 × 10 ^ 5}$$でずっとinsertクエリばかりのテストケースがあれば、たとえば$${10 ^ 5 + 1}$$回目以降は$${O(10 ^ 5)}$$が$${10 ^ 5}$$回繰り返されることになりTLEになります。

「じゃあ先頭に追加する処理が O(1) のデータ構造を使えばいいじゃん!dequeの出番だ!」

from collections import deque
Q = int(input())
A = deque()
ans = []
for _ in range(Q):
    t, x = map(int, input().split())
    if t == 1:
        A.appendleft(x)
    else:
        print(A[x - 1])

これもACになりますが、dequeにも同様に計算量の落とし穴があります。(ここではPython標準ライブラリのcollections.dequeに限った話をします。)
dequeはインデックスによる要素へのアクセスが遅いです。公式ドキュメントによると両端は$${O(1)}$$でアクセスできる一方で、配列の中央付近では$${O(N)}$$になるようです。リストであれば常に$${O(1)}$$です。

dequeを使うとTLEする問題例: ABC335 - C  Loong Tracking

【解説】 C問題 - Divide and Divide

コンテスト中の提出
https://atcoder.jp/contests/abc340/submissions/50151425


  • 最初、整数$${N}$$が黒板に1つ書かれている。

  • 黒板に書かれている 2 以上の整数 $${x}$$ を1つ消して$${\Large\lfloor \frac{x}{2} \rfloor}$$, $${\Large\lceil \frac{x}{2} \rceil}$$に分解する。

  • 操作にコストとして$${x}$$円かかる。

  • 最終的に何円かかるか出力してね。


問題文に素直に従ってコードを書いてみます。

from math import ceil, floor
N = int(input())
A = [N]
ans = 0
while A:
    x = A.pop()
    if x >= 2:
        ans += x
        if floor(x / 2) >= 2:
            A.append(floor(x / 2))
        if ceil(x / 2) >= 2:
            A.append(ceil(x / 2))
print(ans)

入力例 3 の$${N = 100000000000000000}$$で全く結果が返ってきません。TLEです。

$${dp[x]}$$: 整数$${x}$$から操作を始めるときにかかる金額の総和
のようなリストを前計算で用意しようにも 2 secでは$${x \leq 10^8}$$程度しか確保できず、この問題を解くのに不十分です。

よく分からないのでいったん図を書いてみます。

図1. 同じ形が現れる

同じ整数の下には同じ形がぶら下がっています。つまり同じ整数から操作を始めたとき払う金額も同じになるので、同じ形の部分の計算は省略したいです。
すなわち一度見た整数$${x}$$について払った金額の総和を記録しながら操作を進め、二度目にその整数を見たときは記録した金額の総和を$${O(1)}$$で取り出せるようにすることを目指します。なお、こういった記録のことを「メモ化」と言います。

$${N \leq 10^{17}}$$の長さのリストを用意できないので連想配列が必要です。dictでもいいですが、使い勝手のよいdefaultdictを使うことにします。

from collections import defaultdict

ここで、$${x}$$に払う金額の総和を知るためには$${\Large\lfloor \frac{x}{2} \rfloor}$$, $${\Large\lceil \frac{x}{2} \rceil}$$のそれぞれに払う金額の総和を知っている必要があります。
つまり、操作は大きい数から小さい数へと進みますが、メモ化は小さい数から大きい数へ進んでいく必要があります。

この一連の処理の流れを図にしてみます。(見やすさのため一部の整数を省略しています。)

図2. 操作の進み方

「上の処理から先に始まるが、下の処理から先に終わっていく」。First-In-Last-Out 的な処理の進み方はまさに再帰関数の挙動です。この図をグラフとして捉えるならば深さ優先探索(DFS)とも言えます。

よって再帰関数を実装します。メモ化しながら再帰処理するのでメモ化再帰と言ったりします。$${values[x]}$$に「$${x}$$に払う金額の総和」をメモしていきます。

from math import ceil, floor
from collections import defaultdict
def calc_value(x):
    if x < 2:
        return 0
    if values[x] == 0:
        a = floor(x / 2)
        b = ceil(x / 2)
        values[a] = calc_value(a)
        values[b] = calc_value(b)
        values[x] = values[a] + values[b] + x
    return values[x]
N = int(input())
values = defaultdict(int)
ans = calc_value(N)
print(ans)

以上でこの問題を解くことができました。
と、言いたいのですがWAが出てしまいました。しかもテストケース11個中5個も。

原因は切り上げ、切り捨てをmathモジュールのfloor関数、ceil関数で行っている部分です。正確には、'/' 演算子で小数 (float型) を経由していることが原因です。

やや極端な話をします。たとえば$${\Large\frac{6}{2}}$$を人間が計算すると$${3}$$になりますが、これをコンピュータが計算するとfloat型の誤差で$${2.99999…}$$になったりすることがあります。$${3}$$は切り上げしても切り捨てしても$${3}$$なのに対して$${2.99999…}$$だと$${3}$$と$${2}$$になります。
実際には大きい数を扱う時に起こる現象で、浮動小数点数の精度が15桁であることに起因します。今回の問題はそれを狙って$${N \leq 10^{17}}$$の制約になっているんですね。

from math import floor
print(floor(9999999999999999 / 2))

"""
>>> 5000000000000000 # ?!
"""

競プロにおいて整数を扱う問題ではなるべく常に整数で扱うようにし、小数を経由することによる誤差を発生させないよう気を付ける必要があります。

小数を経由しない切り上げ、切り捨て処理は次のコードで可能です。

""" 切り捨て """
x = b // a

""" 切り上げ """
y = (b + a - 1) // a

これを利用してWAになったコードを修正します。

from collections import defaultdict
def calc_value(x):
    if x < 2:
        return 0
    if values[x] == 0:
        a = x // 2
        b = (x + 1) // 2
        values[a] = calc_value(a)
        values[b] = calc_value(b)
        values[x] = values[a] + values[b] + x
    return values[x]
N = int(input())
values = defaultdict(int)
ans = calc_value(N)
print(ans)

これでACになりました。

メモ化再起を勝手にいい感じにやってくれるlru_cacheなるものも存在します。関数の頭に「@lru_cache」と付け加えることで利用できます。公式解説によるとこれより使い勝手のよいcacheなるものもあるらしいです。何がどう違うか分かりませんが、とにかく知っていると便利です。

from functools import lru_cache

@lru_cache
def dfs(x):
    if x < 2:
        return 0
    a = x // 2
    b = (x + 1) // 2
    return dfs(a) + dfs(b) + x
N = int(input())
ans = dfs(N)
print(ans)

Twitterを見ているとメモ化再起をせずに数学力で解く面白そうな解法を取っている方もいましたが、私は数学が苦手なのでよく理解できませんでした。

再帰関数の処理順序の話については、より詳しく分かりやすい記事があるのでご紹介します。私は茶コーダーの頃にこの記事を読んで再帰関数に対して考えがクリアになりました。


【余談】
C問題のページを開いたとき問題文に猛烈な既視感を覚えました。あとあと調べてみるとARC135のA問題だったようで、2週間くらい前にたまたま解いてたみたいです。初見は3WA 3TLEだったので運に救われてるかも。

類題: ARC135 - A  Floor, Ceil - Decomposition

【解説】 D問題 - Super Takahashi Bros.

コンテスト中の提出
https://atcoder.jp/contests/abc340/submissions/50160025


  • ステージ$${1}$$からスタート

  • $${A_{i}}$$​ 秒でステージ$${i}$$をクリアする。ステージ$${i+1}$$を遊べるようになる。

  • $${B_{i}}$$​ 秒でステージ$${i}$$をクリアする。ステージ$${X_{i}}$$を遊べるようになる。

  • ステージ$${N}$$を遊べるのは最短で何秒後か出力してね


一見DPっぽい問題です。
しかし制約をよく見ると$${i \le X_{i}}$$が保証されておらず、ステージ10からステージ1に戻ることもありそうです。「$${i}$$までの確定した情報を利用して$${i + 1}$$以降の情報を確定させていく」実装をしようとして、ステージ10まで確定した後にステージ1の情報が更新されては困ります。

「ステージ$${S}$$からステージ$${T}$$に$${X}$$秒で移動可能」という設定からグラフの最短経路問題として解くことができそうなので入力を隣接リストとして受け取ります。
ステージを頂点として移動可能なステージ間に有向辺を張り、ステージクリアにかかる時間を辺のコストとします。

N = int(input())
G = [[] for _ in range(N + 1)]
for i in range(1, N):
    a, b, x = map(int, input().split())
    G[i].append((i + 1, a))
    G[i].append((x, b))

「グラフの最短経路問題だから幅優先探索(BFS)で解けるかも?」と思いきや、全ての辺のコストが等しいとは限らないグラフでBFSを使うとWAになることがあります。今回は負のコストをもつ辺がないのでダイクストラ法でいけそうです。

BFSとダイクストラ法の違いについては下記の記事が分かりやすいです。

やりたいことはダイクストラ法も幅優先探索も同じで、「辺長が非負で始点が1点固定の最短路問題」を効率的に解くために、暫定最短距離が最も小さい頂点を選んで、そこから伸びる辺で他の頂点の暫定最短距離を更新することを繰り返したいのです。

01-BFSのちょっと丁寧な解説 - ARMERIA

というわけでダイクストラ法を実装します。

from heapq import *
inf = 1 << 60
def dijkstra(start):
    hq = [(0, start)]
    seen = [False] * (N + 1)
    dist = [inf] * (N + 1)
    dist[start] = 0
    while hq:
        _, curr = heappop(hq)
        if seen[curr]:
            continue
        seen[curr] = True
        for nxt, cost in G[curr]:
            if dist[nxt] > dist[curr] + cost:
                dist[nxt] = dist[curr] + cost
                heappush(hq, (dist[nxt], nxt))
    print(dist[N])

ダイクストラ法は「探索候補の頂点のうち、現時点でスタート地点から最小コストで訪れることができる頂点」をひとつ選んで「その頂点と繋がっているコスト未確定の頂点を探索候補に追加する」ことを探索候補がなくなるまでくり返すアルゴリズムです。コストが確定するタイミングはその頂点を探索したときです。

「探索候補の頂点のうち、現時点でスタート地点から最小コストで訪れることができる頂点」を取り出すのは優先度付きキュー(heapq)を使うと効率的です。
heapqは最小値を$${O(1)}$$で参照、$${O(logN)}$$で取り出しができるデータ構造です。要素の追加は$${O(logN)}$$、ヒープ化は$${O(N)}$$です。
毎回ソートするのと比べるとめちゃめちゃ高速です。

以上で解けました。


【余談】
あとあと気づきましたが問題の設定・タイトルともにマリオを意識した問題でしたね。マリオシリーズでダイクストラといえば、RTA in Japan Winter 2022にてチャート構築にダイクストラを使用した走者さんがいたことを思い出しました。


解説はここまでです。ありがとうございました。

 E - Mancala 2


  • $${N}$$個の箱がある

  • 箱$${i}$$には$${A_{i}}$$個のボールが入っている。

  • 箱$${B_{i}}$$の中のボールをすべて取り出して1個ずつ他の箱に移す

  • 最終的な箱の状態を出力してね


「E問題で475点ということはセグ木じゃない簡単な解法もあるのかな?」

最初はimos法で解けるものだと思っていて途中でダメなことに気付く。いろいろ考えたところで区間加算の遅延セグ木でしかなかったです。
ほとんど実装したことのないものは本番でも実装できません。

いろいろな方の記事を参考にupsolveしました。(提出コード
区間和を求めるqueryメソッドが思ったような動作にならず、まだまだ理解不足です。今回は区間和を求める必要がなかったのでコードから省きました。
いま記事を書いてて思ったんですが op や e を外から渡すような汎用的なものを最初から作ろうとしない方がよかった気がします。

自分の整理のためコードをひとつずつ見てみます。まずは初期化から。

def op(a, b): return a + b
def e(): return 0
class LazySegtree:
    def __init__(self, init_val: list[int], op: callable, e: callable) -> None:
        self.root = 1 # 1-indexの明示
        self.m = len(init_val)
        self.n = 1 << (self.m - 1).bit_length()
        self.e = e()
        self.op = op
        self.lazy = [0] * (2 * self.n)
        self.node = [self.e] * (2 * self.n)
        for i, val in enumerate(init_val):
            self.node[i + self.n] = val
        for k in reversed(range(1, self.n)):
            self.node[k] = self.op(self.node[2 * k], self.node[2 * k + 1])

1-indexかつ区間を$${[l, r)}$$で実装しました。

self.n = 1 << (self.m - 1).bit_length()

元の配列$${A}$$の長さを$${2^p}$$の形になるように引き延ばし、その長さを$${n}$$とします。ツリー全体の大きさは$${2n}$$になります。

次に、親ノードから子ノードへの伝搬処理。遅延セグ木の「遅延」の部分です。やっていることは親が貯め込んでいた値を子に振り分けているだけなのでまだ分かりやすいです。

def op(a, b): return a + b
def e(): return 0
class LazySegtree:
    def _eval_all(self) -> None:
        for k in range(self.root, self.n):
            self._propagates(k)
    def _eval(self, l: int, r: int) -> None:
        for k in reversed(list(self._gindex(l, r))):
            self._propagates(k)
    def _propagates(self, k: int) -> None:
        if self.lazy[k] == 0:
            return
        self.node[2 * k] = self.op(self.node[2 * k], self.lazy[k])
        self.node[2 * k + 1] = self.op(self.node[2 * k + 1], self.lazy[k])
        self.lazy[2 * k] = self.op(self.lazy[2 * k], self.lazy[k])
        self.lazy[2 * k + 1] = self.op(self.lazy[2 * k + 1], self.lazy[k])
        self.lazy[k] = 0

次は伝搬処理を効率的に行うコードです。
操作内容によっては親が値を貯め込んだままでもよい区間があります。(あるそうです。)そういった区間の伝搬処理は後回しにして、伝搬処理が必要なノードだけを抽出します。
見よう見まねで実装しましたがgindexの g ってなんなんでしょう。generator?

class LazySegtree:
    def _gindex(self, l: int, r: int) -> Generator[int, None, None]:
        l += self.n
        r += self.n
        lm = l >> (l & -l).bit_length()
        rm = r >> (r & -r).bit_length()
        while l < r:
            if l <= lm:
                yield l
            if r <= rm:
                yield r
            l >>= 1
            r >>= 1
        while l:
            yield l
            l >>= 1

まずこれです。初見で何をしてるか分かりませんでした。

lm = l >> (l & -l).bit_length()
rm = r >> (r & -r).bit_length()
図3. 伝搬処理の効率化

この$${lm, rm}$$以下の$${l, r}$$、つまり$${lm, rm}$$またはその先祖ノードが伝搬処理の対象です。$${l = r}$$になった後も根まで登っていきます。

まだ自分の中でもうまく言語化できてないんですが、区間の両端それぞれの先祖のうち初めて現れる右ノード (上図 1101) の親 (上図 110) まで伝搬を行えば、その子 (上図 1101) は一部区間を完全に被覆しているから伝搬をしなくてよい・・・感じなんですかね。

最後に区間加算のコードです。いかにも加算してそうな部分と、子に加算した後に親に伝搬させてる部分はなんとなく分かります。

def op(a, b): return a + b
class LazySegtree:
    def add(self, l: int, r: int, x: int) -> None:
        ids = self._gindex(l, r)
        l += self.n
        r += self.n
        while l < r:
            if l & 1 == 1:
                self.node[l] = self.op(self.node[l], x)
                self.lazy[l] = self.op(self.lazy[l], x)
                l += 1
            if r & 1 == 1:
                self.node[r - 1] = self.op(self.node[r - 1], x)
                self.lazy[r - 1] = self.op(self.lazy[r - 1], x)
            l >>= 1
            r >>= 1
        for k in ids:
            self.node[k] = self.op(self.node[2 * k], self.node[2 * k + 1])
            self.node[k] = self.op(self.node[k], self.lazy[k])

なんですかこのif文は。

while l < r:
    if l & 1 == 1:
        self.node[l] = self.op(self.node[l], x)
        self.lazy[l] = self.op(self.lazy[l], x)
        l += 1
    if r & 1 == 1:
        self.node[r - 1] = self.op(self.node[r - 1], x)
        self.lazy[r - 1] = self.op(self.lazy[r - 1], x)
    l >>= 1
    r >>= 1

そもそもの話として、セグ木は再帰関数を使って比較的シンプルな実装をすることもできますが、PyPyは再帰関数が遅いのでそのような実装だとTLEするかもしれなくて非再帰で書く必要があるんですよね。
実際に「競プロ典型90問」のセグ木回で公式解説の想定ソースコード通りにPythonで再帰型を書くとTLEになった記憶があります。

非再帰で書くと上のようなwhileループになります。gindexメソッドも非再帰を実現するために必要なものです。

ではあのif文は何かというと、私の認識ではこんなイメージです。

図4. 非再帰加算のイメージ

あとは問題通りになんやかんやしてACしました。

節のはじめに言った通り、区間に対するクエリ処理が実装できなかったのでここまで載せたコードのどこかに誤りがあるのかもしれません。
遅延セグ木初めての実装なので許してください。


以下の記事を参考にしました。ありがとうございます。


【追記  (2024/03/02)】

「区間に対するクエリ処理が実装できなかった」と言っていましたが、「区間最小 + 区間加算」の実装を参考にしており「区間和 + 区間加算」の実装になっていなかったです。

区間和を取る場合は

  • 加算時に「現在のノードが完全に被覆している最下段のノード」の個数分の値を加算する。

  • 伝搬時にlazy値を子ノードに半分ずつ渡す。

ように実装する必要がありました。

今回の問題は区間和をとる必要がなく直接的な問題となりませんでしたが、意図と異なる実装を紹介していたため補足いたします。

""" 区間和 + 区間加算 """
class LazySegtreeSum:
    def _progagates(self, k: int):
        if self.lazy[k] == 0:
            return
        v = self.lazy[k] // 2
        self.node[2 * k] += v
        self.node[2 * k + 1] += v
        self.lazy[2 * k] += v
        self.lazy[2 * k + 1] += v
        self.lazy[k] = 0
    def _child_count(self, k: int):
        return 1 << (self.n.bit_length() - k.bit_length())
    def add(self, l: int, r: int, x: int):
        ids = self._gindex(l, r)
        l += self.n
        r += self.n
        while l < r:
            if l & 1 == 1:
                p = self._child_count(l)
                self.node[l] += x * p
                self.lazy[l] += x * p
                l += 1
            if r & 1 == 1:
                p = self._child_count(r - 1)
                self.node[r - 1] += x * p
                self.lazy[r - 1] += x * p
            l >>= 1
            r >>= 1
        for k in ids:
            self.node[k] = self.node[2 * k]  + self.node[2 * k + 1] + self.lazy[k]
    def query(self, l: int, r: int):
        self._eval(l, r)
        l += self.n
        r += self.n
        res = 0
        while l < r:
            if l & 1 == 1:
                res += self.node[l] 
                l += 1
            if r & 1 == 1:
                res += self.node[r - 1] 
            l >>= 1
            r >>= 1
        return res

あとがき

今回のコンテストの結果はABCD4完26分、順位は2238位、パフォーマンスは1111でした。
自分にとっては早く解けた方だと思ったのに水パフォ出ず。順位表を見るに今回は5完できてないと水パフォに届かないみたい。

前回のE問題がセグ木で緑diffだったり今回のE問題も遅延セグ木で緑diffなのを考えると、セグ木は緑コーダー必須のデータ構造なのかも。
水以上のイメージでした。ちゃんと習得しないとなぁ。


前回8,000字くらい書いてこりゃ大変だなぁと思っていたら今回17,000字になってしまいました。

ではまた~。


この記事が気に入ったらサポートをしてみませんか?