State-of-the-artsな物性GNNモデルALIGNNをまなぶ
はじめに
(これは解説ブログではなく、感想ブログです。)
2023年4月においても、固体物性予測のstate-of-the-artsなGNN(グラフニューラルネット)モデルは、Atomistic Line Graph Neural Network (ALIGNN)です。(5月において首位が入れ替わりました!)
ALIGNNは物質を構成する結晶構造において、原子間距離だけでなく、そのボンド同士の角度情報も取り入れることで、二体および三体の相互作用の情報を取り入れようとするモデルです。
またALIGNNは米国のNational Institute of Standards and Technology(NIST)のMaterials Genome Initiative(MGI)によって発表されたモデルであって、同じくMGIによって開発された物性ライブラリjarvis-toolsを用いています。またgraphを扱うライブラリはdglです。
さて、オーストリアの動物行動学者コンラート・ツァハリアス・ローレンツ(Konrad Zacharias Lorenz、1903-1989)が発見した動物の習性として「刷り込み」があります。つまり、物性ライブラリとして初めて使ったのはpymatgenだし、graphならpyg(pytorch_geometric)な私としては、jarvis-tools, dglよりはpymatgen, pygに親しみを感じます。つまりALIGNNもpymatgenとpygで書きたいわけです。そんなわけで書いてみました。(そしてうまくいきませんでした)
感想とか
ここにあげました。colabノートブック、alignn_pl_in_colab.ipynbは、教師データの取得からfitting, inferenceまでできるようにしてあります。実際、original のALIGNNの性能は再現していません。あとは調整で、というところまできたので公開しました。
元のALIGNNコードでは、かなり色んな機能が実装されています。一方、私はそんなに実装する気力がなかったので、かなりシンプルな実装になっています。主要な機能のみを書き写しているので、元のコードを見るより、ALIGNNの内容が分かりやすくなっているかもしれません。そこが一番の売りです。
また、originalではDGLで書かれている箇所をpytorch_geometric(PyG)で書き直そうというのが当初の目的でしたが、どうにもうまくいきませんでした。うまくいっていない箇所は、核となるEdge gated convolution layerの実装、(多分違うと信じているが)line graphの生成、の二か所が怪しいと思っています。いいわけなのですが、DGLは結構直感的に書けるのに対し、PyGは洗練されているというか、graph convolutionの考え方に沿って実装ができます。これは言い換えると、DGLは書くのが楽なのですが、PyGはPyGの実装の思想に合わせて書く必要があります。ゆえに書き換えに苦労しました。(そしてうまくいっていない)
気づいたこと
・pytorch_geometricのline graph生成コードは、multi graphに対応していない。つまり、入力graphに対して、coalesce(edge indexの重複を消す)をかましている。またcoalesceを消しただけでは、出力のedgeがずれている(原因は不明)。よって自作した。
def gen_line_graph(g:Data)-> Data:
u, v = g.edge_index
l_src, l_dst = zip(*[(i, j)
for i, dst in enumerate(v)
for j, src in enumerate(u)
if dst == src and i != j])
l_src = torch.tensor(l_src)
l_dst = torch.tensor(l_dst)
lg = Data(edge_index=torch.vstack((l_src, l_dst)))
lg.x = g.edge_attr
lg.edge_attr = compute_bond_cosines(lg)
return lg
・結晶構造についてgraphをつくるとき、cutoffを定めて、その距離内にある原子siteをnodeとして取り入れる。primitive cellの格子定数がcutoff以下であれば、graphのnodeとして取ってくる原子siteは、primitive cellの外、すなわちsuper cellの原子siteでなくてはならない。つまりnodeはsuper cellでの固有の数である。しかしALIGNNのコードでは、結晶構造に対するgraphのnodeは、cutoff内にあるsiteの数ではなく、primitive cell内のsite数である。しかし同じnode間のedgeは複数存在し、同じedge_indexであっても、edge_attribute、つまりsite間距離が異なるという表現をしている。これは結局message passingのあと、global_poolingでnodesでの平均をとるために、性質が同じnodeは特に別のnodeとして扱う必要がないためと考える。(ただしALIGNNでは、networkxでgraph表示する際だけは、supercellを作成している)下にMoS2の生成グラフを図示する。nodeはprimitive cell内のsite数=3だが、edgeはcutoff内のsuper cellを勘案した数=62だけある。
MoS2
1.0
1.5939818713880687 -2.7608568870413697 0.0
1.5939818713880687 2.7608568870413697 0.0
0.0 0.0 34.879004
Mo S
1 2
Cartesian
1.59398 -0.9202885072399822 3.7197407130000837
1.59398 0.9202885072399822 2.153121547584768
1.59398 0.9202885072399822 5.286394757415309
この記事が気に入ったらサポートをしてみませんか?