JaLMS
最新の AI 研究を日本語で解読

TinyFusion: Diffusion Transformers Learned Shallow

Gongfan Fang,  Kunjun Li11footnotemark: 1,  Xinyin Ma,  Xinchao Wang
National University of Singapore
{gongfan, kunjun, maxinyin}@u.nus.edu, [email protected]
Equal contributionCorresponding author
Abstract

拡散トランスフォーマーは画像生成において驚異的な能力を示してきたが、しばしば過剰なパラメータ化を伴い、実世界のアプリケーションにおいて相当な推論オーバーヘッドをもたらす。本稿では、エンドツーエンド学習を通じて拡散トランスフォーマーから冗長な層を除去するために設計された深さ剪定手法、TinyFusionを提案する。我々のアプローチの核心原理は、高い回復可能性を持つ剪定モデルを作成することであり、これによりファインチューニング後に強力な性能を取り戻すことができる。これを達成するために、我々は剪定を学習可能にする微分可能なサンプリング技術を導入し、将来のファインチューニングをシミュレートする共最適化パラメータと組み合わせた。先行研究が剪定後の損失や誤差の最小化に焦点を当てているのに対し、我々の手法は剪定されたモデルのファインチューニング後の性能を明示的にモデル化し最適化する。実験結果は、この学習可能なパラダイムが拡散トランスフォーマーの層剪定に大きな利点をもたらし、既存の重要度ベースおよび誤差ベースの手法を凌駕することを示している。さらに、TinyFusionはDiT、MAR、SiTなど多様なアーキテクチャにわたって強力な汎化性を示す。DiT-XLを用いた実験では、TinyFusionが事前学習コストの7%未満で浅い拡散トランスフォーマーを作成し、2×\times×倍の高速化を達成しながらFIDスコア2.86を実現し、同等の効率性を持つ競合手法を上回ることを示している。コードはhttps://github.com/VainF/TinyFusionで入手可能である。

1 Introduction

拡散トランスフォーマーは生成タスクにおける基盤的なアーキテクチャとして台頭し、画像[40, 11, 26]や動画合成[59, 25]などの分野で顕著な成功を収めている。この成功により、インターネット上で高品質な事前学習モデルが広く利用可能となり、様々な下流アプリケーションの開発を大きく加速し支援している[53, 5, 16, 55]。 しかしながら、事前学習された拡散トランスフォーマーは通常、巨大なパラメータスケールのために相当な推論コストを伴い、デプロイメントに重大な課題をもたらす。この問題を解決するため、研究コミュニティと産業界の双方から、軽量モデルの開発に対する関心が高まっている[32, 23, 12, 58]

Refer to caption
図1: 本稿は、事前学習された拡散トランスフォーマーの深さを剪定するための学習可能なアプローチを提示する。我々の手法は、層マスクの微分可能なサンプリングプロセスと重み更新を同時に最適化し、高度に回復可能な解を特定することで、剪定されたモデルがファインチューニング後も競争力のある性能を維持することを保証する。

拡散モデルの効率は通常、サンプリングステップ数[45, 46, 33, 43]、オペレータ設計[48, 7, 52]、計算精度[30, 44, 19]、ネットワーク幅[12, 3]、深さ[23, 6, 36]など、様々な要因に影響される。本稿では、ネットワークから全層を除去してレイテンシーを削減する深さプルーニング[54, 36]によるモデル圧縮に焦点を当てる。深さプルーニングは実践的に大きな利点を提供する:並列デバイスと非並列デバイスの両方で、圧縮率に対して線形の加速比を達成できるのである。例えば、本研究で示されるように、50%の幅プルーニング[12]が1.6倍の高速化しか達成できないのに対し、層の50%をプルーニングすると2倍の高速化が得られる。これにより、深さプルーニングはモデル圧縮のための柔軟で実用的な手法となるのである。

本稿は標準的な深さ方向の枝刈りフレームワークに従っている:まず重要でない層を除去し、その後枝刈りされたモデルの性能回復のために微調整を行う。文献において、拡散トランスフォーマーや一般的なトランスフォーマー向けに設計された深さ方向の枝刈り技術は、主に発見的アプローチに焦点を当てている。例えば、慎重に設計された重要度スコア[36, 6]や手動で設定された枝刈りスキーム[23, 54]などである。これらの手法は損失最小化の原則[18, 37]に従い、枝刈り後も低い損失または誤差を維持する解を特定することを目指している。本稿では、深さ方向の圧縮の文脈においてこの広く使用されている原則の有効性を調査する。実験を通じて、我々は枝刈り後に観察される較正損失と微調整後の性能との関係を検証した。これは、ランダムな枝刈りによって100,000のモデルを広範にサンプリングし、探索空間内で異なるレベルの較正損失を示すことで達成された。これに基づいて、我々は特徴類似度[6, 36]や感度分析[18]などの既存の枝刈りアルゴリズムの有効性を分析した。これらは確かに解空間において低い較正損失を達成している。しかし、これらのモデルの微調整後の性能は期待に反して低いことが多い。これは、損失最小化の原則が拡散トランスフォーマーに適していない可能性があることを示している。

これらの洞察に基づき、我々は拡散トランスフォーマーにおける効果的な層の枝刈りの基本原則を再評価した。拡散トランスフォーマーの微調整は非常に時間のかかるプロセスである。枝刈り直後に損失を最小化するモデルを探すのではなく、我々は強い回復可能性を持つ候補モデルを特定し、微調整後により優れた性能を実現することを提案する。この目標の達成は特に困難である。なぜなら、非微分可能な操作を含む枝刈りと微調整という2つの異なるプロセスを統合する必要があり、勾配降下法によって直接最適化することができないからである。

この目的のため、我々は学習可能な深さ方向の枝刈り手法を提案し、枝刈りと微調整を効果的に統合する。図1に示すように、我々は拡散トランスフォーマーの枝刈りと微調整を、層マスクの微分可能なサンプリングプロセスとしてモデル化し[17, 22, 13]、将来の微調整をシミュレートするために共同最適化された重み更新と組み合わせる。我々の目的は、回復可能性の高いネットワークがサンプリングされる可能性が高くなるように、この分布を反復的に改善することである。これは単純な戦略によって達成される:サンプリングされた枝刈り決定が強い回復可能性をもたらす場合、類似の枝刈りパターンがサンプリングされる確率が増加する。このアプローチは、潜在的に価値のある解決策の探索を促進し、効果の低いものを無視する。さらに、本稿で提案する手法は非常に効率的であり、わずか数回の学習ステップで適切な解決策が現れることを示す。

提案手法の有効性を評価するため、我々はDiTs [40]、MARs [29]、SiTs [34]を含む様々なトランスフォーマーベースの拡散モデルに対して広範な実験を行った。学習可能なアプローチは非常に効率的である。データセットに対して1エポックの訓練で拡散トランスフォーマーの冗長な層を特定し、事前学習済みモデルから高い回復可能性を持つ浅い拡散トランスフォーマーを効果的に作成することができる。例えば、TinyFusionによって枝刈りされたモデルは、50%の層を除去した直後は比較的高いキャリブレーション損失を示すが、ファインチューニングを通じて迅速に回復し、即時の損失のみを最小化するベースライン手法と比較して、事前学習コストのわずか1%で大幅に競争力のあるFIDスコア(5.73対22.28)を達成する。さらに、我々はMaskedKD変種を導入することで、回復可能性の向上における知識蒸留の役割も探究した [20, 23]。MaskedKDは、隠れ状態における巨大な、または外れ値の活性化 [47] の悪影響を軽減し、これらはファインチューニングの性能と信頼性に大きく影響する可能性がある。MaskedKDを用いることで、事前学習コストのわずか1%で、FIDスコアは5.73から3.73に改善される。訓練を事前学習コストの7%まで延長すると、FIDはさらに2.86まで低下し、深さが2倍の元のモデルよりもわずか0.4高いだけとなる。

したがって、本稿の主な貢献は、事前学習済みモデルから浅い拡散トランスフォーマーを作成する学習可能な手法にあり、これは枝刈りされたモデルの回復可能性を明示的に最適化する。この手法は、DiTs、MARs、SiTsを含む様々なアーキテクチャに対して一般的に適用可能である。

Refer to caption
図2: 提案されたTinyFusion手法は、候補解の微分可能なサンプリングを学習し、回復可能性を推定するための重み更新と共同で最適化する。このアプローチは、ファインチューニング後の強力な性能を保証する好ましい解の可能性を高めることを目的としている。訓練後、最もサンプリング確率の高い局所構造が保持される。

2 Related Works

Network Pruning and Depth Reduction.

ネットワークの枝刈りは、冗長なパラメータを除去することで事前学習済み拡散モデルを圧縮するために広く用いられているアプローチである[12, 3, 51, 31]。Diff-Pruning [12]は、UNetの幅を効率化するための勾配ベースの手法を導入し、その後、性能を回復するための簡単な微調整を行っている。SparseDM [51]は、Straight-Through Estimator (STE) [2]を通じて事前学習済み拡散モデルにスパース性を適用し、平均してFIDが1.22増加するだけでMACsを50%削減することを達成している。幅の枝刈りとスパース性はメモリオーバーヘッドの削減に役立つが、特にGPUのような並列デバイスでは、速度の向上が限定的であることが多い。そのため、過去数年間で深さの削減が大きな注目を集めている。これは、層全体を削除することで、枝刈り率に比例したより良い高速化が可能になるためである[54, 36, 24, 27, 58, 56, 28]。MoD [41]や深さを考慮したトランスフォーマー[10]のような適応的な深さ削減技術も提案されている。これらの進歩にもかかわらず、既存の手法の多くは依然として、慎重に設計された重要性基準[36, 54]、感度分析[18]、または手動で設計されたスキーム[23]などの経験的またはヒューリスティックな戦略に基づいており、微調整後の強力な性能保証をしばしば提供しない。

Efficient Diffusion Transformers.

効率的な拡散トランスフォーマーの開発は、コミュニティ内で魅力的な焦点となっており、線形注意機構[15, 48, 52]、コンパクトなアーキテクチャ[50]、非自己回帰型トランスフォーマー[4, 49, 38, 14]、プルーニング[23, 12]、量子化[30, 44, 19]、特徴キャッシング[35, 57]など、様々な観点から効率性を向上させるための重要な取り組みがなされている。本稿では、事前学習済み拡散トランスフォーマーの深さを圧縮することに焦点を当て、回復可能性を直接最適化する学習可能な手法を導入する。これにより、低い再学習コストで満足のいく結果を達成することが可能である。

3 Method

3.1 Shallow Generative Transformers by Pruning

本稿の目的は、事前学習済みモデルをプルーニングすることで浅い拡散トランスフォーマーを導出することである。簡略化のため、本稿のすべてのベクトルは列ベクトルとする。L𝐿Litalic_L層のトランスフォーマーを考え、これをΦL×D=[ϕ1,ϕ2,,ϕL]subscriptΦ𝐿𝐷superscriptsubscriptbold-italic-ϕ1subscriptbold-italic-ϕ2subscriptbold-italic-ϕ𝐿\Phi_{L\times D}=\left[\boldsymbol{\phi}_{1},\boldsymbol{\phi}_{2},\cdots,% \boldsymbol{\phi}_{L}\right]^{\intercal}roman_Φ start_POSTSUBSCRIPT italic_L × italic_D end_POSTSUBSCRIPT = [ bold_italic_ϕ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , bold_italic_ϕ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , ⋯ , bold_italic_ϕ start_POSTSUBSCRIPT italic_L end_POSTSUBSCRIPT ] start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPTでパラメータ化する。ここで、各要素ϕisubscriptbold-italic-ϕ𝑖\boldsymbol{\phi}_{i}bold_italic_ϕ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTは、トランスフォーマー層のすべての学習可能なパラメータをD𝐷Ditalic_D次元の列ベクトルとして包含し、これには注意層とMLPの両方の重みが含まれる。深さプルーニングは、以下の方法で層を除去する二値層マスク𝖒L×1=[m1,m2,,mL]subscript𝖒𝐿1superscriptsubscript𝑚1subscript𝑚2subscript𝑚𝐿\boldsymbol{\mathfrak{m}}_{L\times 1}=\left[m_{1},m_{2},\cdots,m_{L}\right]^{\intercal}bold_fraktur_m start_POSTSUBSCRIPT italic_L × 1 end_POSTSUBSCRIPT = [ italic_m start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , italic_m start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , ⋯ , italic_m start_POSTSUBSCRIPT italic_L end_POSTSUBSCRIPT ] start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPTを見つけることを目指す:

xi+1=miϕi(xi)+(1mi)xi={ϕi(xi),ifmi=1,xi,otherwise,subscript𝑥𝑖1subscript𝑚𝑖subscriptbold-italic-ϕ𝑖subscript𝑥𝑖1subscript𝑚𝑖subscript𝑥𝑖casessubscriptbold-italic-ϕ𝑖subscript𝑥𝑖ifsubscript𝑚𝑖1subscript𝑥𝑖otherwisex_{i+1}=m_{i}\boldsymbol{\phi}_{i}(x_{i})+(1-m_{i})x_{i}=\begin{cases}% \boldsymbol{\phi}_{i}(x_{i}),\;&\text{if}\ m_{i}=1,\\ x_{i},\;&\text{otherwise},\\ \end{cases}italic_x start_POSTSUBSCRIPT italic_i + 1 end_POSTSUBSCRIPT = italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) + ( 1 - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = { start_ROW start_CELL bold_italic_ϕ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) , end_CELL start_CELL if italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = 1 , end_CELL end_ROW start_ROW start_CELL italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , end_CELL start_CELL otherwise , end_CELL end_ROW (1)

ここで、xisubscript𝑥𝑖x_{i}italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTϕi(xi)subscriptbold-italic-ϕ𝑖subscript𝑥𝑖\boldsymbol{\phi}_{i}(x_{i})bold_italic_ϕ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )は層ϕisubscriptbold-italic-ϕ𝑖\boldsymbol{\phi}_{i}bold_italic_ϕ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTの入力と出力を指す。マスクを取得するために、先行研究では一般的に、プルーニング後の損失\mathcal{L}caligraphic_Lを最小化するパラダイムが用いられ、これはmin𝖒𝔼x[(x,Φ,𝖒)]subscript𝖒subscript𝔼𝑥delimited-[]𝑥Φ𝖒\min_{\boldsymbol{\mathfrak{m}}}\mathbb{E}_{x}\left[\mathcal{L}(x,\Phi,% \boldsymbol{\mathfrak{m}})\right]roman_min start_POSTSUBSCRIPT bold_fraktur_m end_POSTSUBSCRIPT blackboard_E start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT [ caligraphic_L ( italic_x , roman_Φ , bold_fraktur_m ) ]と定式化できる。しかし、我々が実験で示すように、この目的関数は識別タスクで広く採用されているものの、拡散トランスフォーマーのプルーニングには適していない可能性がある。代わりに、我々はプルーニングされたモデルの回復可能性により関心がある。これを達成するために、我々は最適化問題に追加の重み更新を組み込み、目的関数を以下のように拡張する:

min𝖒minΔΦ𝔼x[(x,Φ+ΔΦ,𝖒)]Recoverability: Post-Fine-Tuning Performance,subscript𝖒subscriptsubscriptΔΦsubscript𝔼𝑥delimited-[]𝑥ΦΔΦ𝖒Recoverability: Post-Fine-Tuning Performance\min_{\boldsymbol{\mathfrak{m}}}\underbrace{\min_{\Delta\Phi}\mathbb{E}_{x}% \left[\mathcal{L}(x,\Phi+\Delta\Phi,\boldsymbol{\mathfrak{m}})\right]}_{% \textit{Recoverability: Post-Fine-Tuning Performance}},roman_min start_POSTSUBSCRIPT bold_fraktur_m end_POSTSUBSCRIPT under⏟ start_ARG roman_min start_POSTSUBSCRIPT roman_Δ roman_Φ end_POSTSUBSCRIPT blackboard_E start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT [ caligraphic_L ( italic_x , roman_Φ + roman_Δ roman_Φ , bold_fraktur_m ) ] end_ARG start_POSTSUBSCRIPT Recoverability: Post-Fine-Tuning Performance end_POSTSUBSCRIPT , (2)

ここで、ΔΦ={Δϕ1,Δϕ2,,ΔϕM}ΔΦΔsubscriptbold-italic-ϕ1Δsubscriptbold-italic-ϕ2Δsubscriptbold-italic-ϕ𝑀\Delta\Phi=\{\Delta\boldsymbol{\phi}_{1},\Delta\boldsymbol{\phi}_{2},\cdots,% \Delta\boldsymbol{\phi}_{M}\}roman_Δ roman_Φ = { roman_Δ bold_italic_ϕ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , roman_Δ bold_italic_ϕ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , ⋯ , roman_Δ bold_italic_ϕ start_POSTSUBSCRIPT italic_M end_POSTSUBSCRIPT }はファインチューニングからの適切な更新を表す。式2で定式化された目的関数は2つの課題を提起する:1) 層選択の微分不可能な性質により、勾配降下法を用いた直接的な最適化が妨げられる;2) 保持された層に対する内部最適化により、候補モデルを選択し評価のためにファインチューニングする必要があるため、全探索空間を探索することが計算上困難になる。これに対処するために、我々はプルーニングと回復可能性の両方を最適化可能にするTinyFusionを提案する。

3.2 TinyFusion: Learnable Depth Pruning

A Probabilistic Perspective.

本稿では、式2を確率論的観点からモデル化する。我々は、「理想的な」枝刈り手法(必ずしも一意ではない)によって生成されるマスク𝖒𝖒\boldsymbol{\mathfrak{m}}bold_fraktur_mが、ある特定の分布に従うという仮説を立てる。これをモデル化するために、可能なすべてのマスク𝖒𝖒\boldsymbol{\mathfrak{m}}bold_fraktur_mに確率値p(𝖒)𝑝𝖒p(\boldsymbol{\mathfrak{m}})italic_p ( bold_fraktur_m )を関連付け、カテゴリカル分布を形成することが直感的である。事前知識がない場合、枝刈りマスクの評価は一様分布から始まる。しかし、この初期分布から直接サンプリングすることは、膨大な探索空間のため非常に非効率的である。例えば、28層のモデルを50%枝刈りする場合、(2814)=40,116,600binomial281440116600\binom{28}{14}=40,116,600( FRACOP start_ARG 28 end_ARG start_ARG 14 end_ARG ) = 40 , 116 , 600の可能な解を評価する必要がある。この課題を克服するため、本稿では評価結果をフィードバックとして使用し、マスク分布を反復的に改善できる高度で学習可能なアルゴリズムを導入する。基本的な考え方は、特定のマスクが良好な結果を示した場合、類似したパターンを持つ他のマスクも潜在的な解である可能性が高く、したがって後続の評価でサンプリングされる可能性を高くすべきであり、有望な解に焦点を当てたより集中的な探索を可能にするというものである。しかし、「類似パターン」の定義はまだ不明確である。

Sampling Local Structures.

本稿では、図2に示すような局所構造が、異なるマスク間の関係をモデル化する効果的な基準点として機能することを示す。枝刈りマスクが特定の局所構造をもたらし、微調整後に競争力のある結果を生み出す場合、同じ局所パターンをもたらす他のマスクも正の解である可能性が高い。これは、元のモデルをK𝐾Kitalic_K個の重複しないブロックに分割することで達成できる。これらをΦ=[Φ1,Φ2,,ΦK]ΦsuperscriptsubscriptΦ1subscriptΦ2subscriptΦ𝐾\Phi=\left[\Phi_{1},\Phi_{2},\cdots,\Phi_{K}\right]^{\intercal}roman_Φ = [ roman_Φ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , roman_Φ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , ⋯ , roman_Φ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ] start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPTと表す。簡単のため、各ブロックΦk=[ϕk1,ϕk2,,ϕkM]subscriptΦ𝑘superscriptsubscriptitalic-ϕ𝑘1subscriptitalic-ϕ𝑘2subscriptitalic-ϕ𝑘𝑀\Phi_{k}=\left[\phi_{k1},\phi_{k2},\cdots,\phi_{kM}\right]^{\intercal}roman_Φ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = [ italic_ϕ start_POSTSUBSCRIPT italic_k 1 end_POSTSUBSCRIPT , italic_ϕ start_POSTSUBSCRIPT italic_k 2 end_POSTSUBSCRIPT , ⋯ , italic_ϕ start_POSTSUBSCRIPT italic_k italic_M end_POSTSUBSCRIPT ] start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPTが正確にM𝑀Mitalic_M層を含むと仮定するが、実際には異なる長さを持つことができる。グローバルな層の枝刈りを行う代わりに、我々は局所的な層の枝刈りのためのN:Mスキームを提案する。ここで、M𝑀Mitalic_M層を持つ各ブロックΦksubscriptΦ𝑘\Phi_{k}roman_Φ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPTに対して、N𝑁Nitalic_N層が保持される。これにより、局所的な二値マスク𝖒=[𝖒1,𝖒2,,𝖒K]𝖒superscriptsubscript𝖒1subscript𝖒2subscript𝖒𝐾\boldsymbol{\mathfrak{m}}=[\boldsymbol{\mathfrak{m}}_{1},\boldsymbol{\mathfrak% {m}}_{2},\ldots,\boldsymbol{\mathfrak{m}}_{K}]^{\intercal}bold_fraktur_m = [ bold_fraktur_m start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , bold_fraktur_m start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , … , bold_fraktur_m start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ] start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPTのセットが生成される。同様に、局所マスク𝖒ksubscript𝖒𝑘\boldsymbol{\mathfrak{m}}_{k}bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPTの分布はカテゴリカル分布p(𝖒k)𝑝subscript𝖒𝑘p(\boldsymbol{\mathfrak{m}}_{k})italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT )を用いてモデル化される。我々は局所的な二値マスクを独立にサンプリングし、それらを組み合わせて枝刈りを行う。これは以下の結合分布を示す:

p(𝖒)=p(𝖒1)p(𝖒2)p(𝖒K)𝑝𝖒𝑝subscript𝖒1𝑝subscript𝖒2𝑝subscript𝖒𝐾p(\boldsymbol{\mathfrak{m}})=p(\boldsymbol{\mathfrak{m}}_{1})\cdot p(% \boldsymbol{\mathfrak{m}}_{2})\cdots p(\boldsymbol{\mathfrak{m}}_{K})italic_p ( bold_fraktur_m ) = italic_p ( bold_fraktur_m start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ) ⋅ italic_p ( bold_fraktur_m start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) ⋯ italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ) (3)

いくつかの局所分布p(𝖒k)𝑝subscript𝖒𝑘p(\boldsymbol{\mathfrak{m}}_{k})italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT )が対応するブロックで高い信頼性を示す場合、システムはそれらの正のパターンを頻繁にサンプリングする傾向があり、他の局所ブロックでアクティブな探索を継続する。この概念に基づき、我々は上記のプロセスを学習可能にするための微分可能なサンプリングを導入する。

Differentiable Sampling.

局所マスク𝖒ksubscript𝖒𝑘\boldsymbol{\mathfrak{m}}_{k}bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPTのサンプリングプロセスを考える。これは局所ブロックΦksubscriptΦ𝑘\Phi_{k}roman_Φ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPTに対応し、カテゴリカル分布p(𝖒k)𝑝subscript𝖒𝑘p(\boldsymbol{\mathfrak{m}}_{k})italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT )によってモデル化される。N:Mスキームにより、(MN)binomial𝑀𝑁\binom{M}{N}( FRACOP start_ARG italic_M end_ARG start_ARG italic_N end_ARG )個の可能なマスクが存在する。我々は、すべての可能なマスクを列挙するための特殊な行列𝖒^N:Msuperscript^𝖒:𝑁𝑀\hat{\boldsymbol{\mathfrak{m}}}^{N:M}over^ start_ARG bold_fraktur_m end_ARG start_POSTSUPERSCRIPT italic_N : italic_M end_POSTSUPERSCRIPTを構築する。例えば、2:3の層枝刈りは候補行列𝖒^2:3=[[1,1,0],[1,0,1],[0,1,1]]superscript^𝖒:23110101011\hat{\boldsymbol{\mathfrak{m}}}^{2:3}=\left[\left[1,1,0\right],\left[1,0,1% \right],\left[0,1,1\right]\right]over^ start_ARG bold_fraktur_m end_ARG start_POSTSUPERSCRIPT 2 : 3 end_POSTSUPERSCRIPT = [ [ 1 , 1 , 0 ] , [ 1 , 0 , 1 ] , [ 0 , 1 , 1 ] ]をもたらす。この場合、各ブロックは3つの確率p(𝖒k)=[pk1,pk2,pk3]𝑝subscript𝖒𝑘subscript𝑝𝑘1subscript𝑝𝑘2subscript𝑝𝑘3p(\boldsymbol{\mathfrak{m}}_{k})=\left[p_{k1},p_{k2},p_{k3}\right]italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) = [ italic_p start_POSTSUBSCRIPT italic_k 1 end_POSTSUBSCRIPT , italic_p start_POSTSUBSCRIPT italic_k 2 end_POSTSUBSCRIPT , italic_p start_POSTSUBSCRIPT italic_k 3 end_POSTSUBSCRIPT ]を持つ。簡単のため、𝖒ksubscript𝖒𝑘\boldsymbol{\mathfrak{m}}_{k}bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPTk𝑘kitalic_kを省略し、pisubscript𝑝𝑖p_{i}italic_p start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTを使用して𝖒^N:Msuperscript^𝖒:𝑁𝑀\hat{\boldsymbol{\mathfrak{m}}}^{N:M}over^ start_ARG bold_fraktur_m end_ARG start_POSTSUPERSCRIPT italic_N : italic_M end_POSTSUPERSCRIPTi𝑖iitalic_i番目の要素をサンプリングする確率を表す。サンプリングプロセスを微分可能にする一般的な方法はGumbel-Softmax[22, 17, 13]である:

y=one-hot(exp((gi+logpi)/τ)jexp((gj+logpj)/τ)).𝑦one-hotsubscript𝑔𝑖subscript𝑝𝑖𝜏subscript𝑗subscript𝑔𝑗subscript𝑝𝑗𝜏y=\text{one-hot}\left(\frac{\exp((g_{i}+\log p_{i})/\tau)}{\sum_{j}\exp((g_{j}% +\log p_{j})/\tau)}\right).italic_y = one-hot ( divide start_ARG roman_exp ( ( italic_g start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + roman_log italic_p start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) / italic_τ ) end_ARG start_ARG ∑ start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT roman_exp ( ( italic_g start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT + roman_log italic_p start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) / italic_τ ) end_ARG ) . (4)

ここで、gisubscript𝑔𝑖g_{i}italic_g start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTはGumbel分布Gumbel(0,1)Gumbel01\textit{Gumbel}(0,1)Gumbel ( 0 , 1 )から抽出されたランダムノイズであり、τ𝜏\tauitalic_τは温度項を指す。出力y𝑦yitalic_yはサンプリングされたマスクのインデックスである。ここでは、Straight-Through Estimator[2]がone-hot操作に適用される。one-hot操作は順伝播時に有効化され、逆伝播時には恒等関数として扱われる。one-hotインデックスy𝑦yitalic_yと候補セット𝖒^N:Msuperscript^𝖒:𝑁𝑀\hat{\boldsymbol{\mathfrak{m}}}^{N:M}over^ start_ARG bold_fraktur_m end_ARG start_POSTSUPERSCRIPT italic_N : italic_M end_POSTSUPERSCRIPTを利用して、単純なインデックス操作によってマスク𝖒p(𝖒)similar-to𝖒𝑝𝖒\boldsymbol{\mathfrak{m}}\sim p(\boldsymbol{\mathfrak{m}})bold_fraktur_m ∼ italic_p ( bold_fraktur_m )を抽出できる:

𝖒=y𝖒^𝖒superscript𝑦^𝖒\boldsymbol{\mathfrak{m}}=y^{\intercal}\hat{\boldsymbol{\mathfrak{m}}}bold_fraktur_m = italic_y start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPT over^ start_ARG bold_fraktur_m end_ARG (5)

注目すべきは、τ0𝜏0\tau\rightarrow 0italic_τ → 0の場合、STEの勾配は真の勾配に近似するが、より高い分散を持ち、これは訓練にとってネガティブである[22]。したがって、通常、高い温度で訓練を開始し、時間とともに徐々に温度を下げるスケジューラが使用される。

Refer to caption
図3: 微分可能な枝刈りマスクmisubscript𝑚𝑖m_{i}italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTとLoRAを用いた回復可能性推定の順伝播の例。
Method Depth #Param Iters IS \uparrow FID \downarrow sFID \downarrow Prec. \uparrow Recall \uparrow Sampling it/s \uparrow
DiT-XL/2 [40] 28 675 M 7,000 K 278.24 2.27 4.60 0.83 0.57 6.91
DiT-XL/2 [40] 28 675 M 2,000 K 240.22 2.73 4.46 0.83 0.55 6.91
DiT-XL/2 [40] 28 675 M 1,000 K 157.83 5.53 4.60 0.80 0.53 6.91
U-ViT-H/2 [1] 29 501 M 500 K 265.30 2.30 5.60 0.82 0.58 8.21
ShortGPT [36] 28\Rightarrow19 459 M 100 K 132.79 7.93 5.25 0.76 0.53 10.07
TinyDiT-D19 (KD) 28\Rightarrow19 459 M 100 K 242.29 2.90 4.63 0.84 0.54 10.07
TinyDiT-D19 (KD) 28\Rightarrow19 459 M 500 K 251.02 2.55 4.57 0.83 0.55 10.07
DiT-L/2 [40] 24 458 M 1,000 K 196.26 3.73 4.62 0.82 0.54 9.73
U-ViT-L [1] 21 287 M 300 K 221.29 3.44 6.58 0.83 0.52 13.48
U-DiT-L [50] 22 204 M 400 K 246.03 3.37 4.49 0.86 0.50 -
Diff-Pruning-50% [12] 28 338 M 100 K 186.02 3.85 4.92 0.82 0.54 10.43
Diff-Pruning-75% [12] 28 169 M 100 K 83.78 14.58 6.28 0.72 0.53 13.59
ShortGPT [36] 28\Rightarrow14 340 M 100 K 66.10 22.28 6.20 0.63 0.56 13.54
Flux-Lite [6] 28\Rightarrow14 340 M 100 K 54.54 25.92 5.98 0.62 0.55 13.54
Sensitivity Analysis [18] 28\Rightarrow14 340 M 100 K 70.36 21.15 6.22 0.63 0.57 13.54
Oracle (BK-SDM) [23] 28\Rightarrow14 340 M 100 K 141.18 7.43 6.09 0.75 0.55 13.54
TinyDiT-D14 28\Rightarrow14 340 M 100 K 151.88 5.73 4.91 0.80 0.55 13.54
TinyDiT-D14 28\Rightarrow14 340 M 500 K 198.85 3.92 5.69 0.78 0.58 13.54
TinyDiT-D14 (KD) 28\Rightarrow14 340 M 100 K 207.27 3.73 5.04 0.81 0.54 13.54
TinyDiT-D14 (KD) 28\Rightarrow14 340 M 500 K 234.50 2.86 4.75 0.82 0.55 13.54
DiT-B/2 [40] 12 130 M 1,000 K 119.63 10.12 5.39 0.73 0.55 28.30
U-DiT-B [50] 22 - 400 K 85.15 16.64 6.33 0.64 0.63 -
TinyDiT-D7 (KD) 14\Rightarrow7 173 M 500 K 166.91 5.87 5.43 0.78 0.53 26.81
表1: 事前学習済みDiT-XL/2の層枝刈り結果。我々は2つの設定に焦点を当てる:100Kの最適化ステップによる高速訓練と、500Kステップによる十分な微調整。回復には微調整とMasked Knowledge Distillation(KDの変種、セクション4.4参照)の両方が使用される。

Joint Optimization with Recoverability.

微分可能なサンプリングにより、勾配降下法を用いて基礎となる確率を更新することが可能になる。本稿での訓練目的は、サンプリングされたマスクの回復可能性を最大化することである。我々は、式2の目的関数を、学習可能な分布を組み込んで以下のように再定式化する:

min{p(𝖒k)}minΔΦ𝔼x,{𝖒kp(𝖒k)}[(x,Φ+ΔΦ,{𝖒k}]Recoverability: Post-Fine-Tuning Performance,\min_{\{p(\boldsymbol{\mathfrak{m}}_{k})\}}\underbrace{\min_{\Delta\Phi}\;% \mathbb{E}_{x,\{\boldsymbol{\mathfrak{m}}_{k}\sim p(\boldsymbol{\mathfrak{m}}_% {k})\}}\left[\mathcal{L}(x,\Phi+\Delta\Phi,\{\boldsymbol{\mathfrak{m}}_{k}\}% \right]}_{\textit{Recoverability: Post-Fine-Tuning Performance}},roman_min start_POSTSUBSCRIPT { italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) } end_POSTSUBSCRIPT under⏟ start_ARG roman_min start_POSTSUBSCRIPT roman_Δ roman_Φ end_POSTSUBSCRIPT blackboard_E start_POSTSUBSCRIPT italic_x , { bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ∼ italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) } end_POSTSUBSCRIPT [ caligraphic_L ( italic_x , roman_Φ + roman_Δ roman_Φ , { bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } ] end_ARG start_POSTSUBSCRIPT Recoverability: Post-Fine-Tuning Performance end_POSTSUBSCRIPT , (6)

ここで、{p(𝖒k)}={p(𝖒1),,p(𝖒K)}𝑝subscript𝖒𝑘𝑝subscript𝖒1𝑝subscript𝖒𝐾\{p(\boldsymbol{\mathfrak{m}}_{k})\}=\{p(\boldsymbol{\mathfrak{m}}_{1}),\cdots% ,p(\boldsymbol{\mathfrak{m}}_{K})\}{ italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) } = { italic_p ( bold_fraktur_m start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ) , ⋯ , italic_p ( bold_fraktur_m start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ) }は異なる局所ブロックのカテゴリカル分布を指す。この定式化に基づき、我々はさらに微調整情報を訓練に組み込む方法を調査する。我々は、分布と重み更新ΔΦΔΦ\Delta\Phiroman_Δ roman_Φの共同最適化を提案する。我々の主要なアイデアは、共同訓練のために共同最適化された更新ΔΦΔΦ\Delta\Phiroman_Δ roman_Φを導入することである。更新を作成する直接的な方法は、元のネットワークを直接最適化することである。しかし、拡散トランスフォーマーのパラメータスケールは通常非常に大きく、完全な最適化は訓練プロセスを高コストかつ非効率にする可能性がある。そこで、我々はLoRA[21]などのパラメータ効率の良い微調整方法が、必要なΔΦΔΦ\Delta\Phiroman_Δ roman_Φを得るための良い選択肢となることを示す。ΦΦ\Phiroman_Φ内の単一の線形行列𝐖𝐖\mathbf{W}bold_Wに対して、我々は微調整された重みを以下のようにシミュレートする:

𝐖fine-tuned=𝐖+αΔ𝐖=𝐖+α𝐁𝐀,subscript𝐖fine-tuned𝐖𝛼Δ𝐖𝐖𝛼𝐁𝐀\mathbf{W}_{\text{fine-tuned}}=\mathbf{W}+\alpha\Delta\mathbf{W}=\mathbf{W}+% \alpha\mathbf{B}\mathbf{A},bold_W start_POSTSUBSCRIPT fine-tuned end_POSTSUBSCRIPT = bold_W + italic_α roman_Δ bold_W = bold_W + italic_α bold_BA , (7)

ここで、α𝛼\alphaitalic_αΔ𝐖Δ𝐖\Delta\mathbf{W}roman_Δ bold_Wの寄与をスケーリングするスカラーハイパーパラメータである。LoRAを使用することで、パラメータ数が大幅に削減され、異なる枝刈り決定の効率的な探索が容易になる。図3に示すように、我々はサンプリングされた二値マスク値misubscript𝑚𝑖m_{i}italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTをゲートとして使用し、式1を用いてネットワークを順伝播させる。これにより、現在の層に対してサンプリングされたマスクが0の場合、層の出力が抑制される。さらに、前述のSTEは枝刈りされた層にも非ゼロの勾配を提供し、さらなる更新を可能にする。これは実践的に有用である。なぜなら、一部の層は最初は競争力がないかもしれないが、十分な微調整により競争力のある候補として浮上する可能性があるからである。

Pruning Decision.

訓練後、我々は最も高い確率を持つ局所構造を保持し、追加の更新ΔΦΔΦ\Delta\Phiroman_Δ roman_Φを破棄する。その後、標準的な微調整技術を回復のために適用することができる。

4 Experiments

4.1 Experimental Settings

我々の実験は主に、ImageNet 256 ×\times× 256 [8]におけるクラス条件付き画像生成のためのDiffusion Transformers [40]で実施された。評価については、[9, 40]に従い、公式のリファレンス画像 [9]を用いてFréchet Inception Distance (FID)、Sliding Fréchet Inception Distance (sFID)、Inception Scores (IS)、Precision、およびRecallを報告する。さらに、我々の手法をMARs [29]やSiTs [34]を含む他のモデルにも拡張した。実験の詳細は以下のセクションおよび付録に記載されている。

Refer to caption
図4: 深さの削減は、圧縮率に対する理論的な線形速度向上と密接に一致している。

4.2 Results on Diffusion Transformers

DiT.

本研究はDiTの圧縮に焦点を当てている[40]。我々は2つの主要な戦略をベースラインとして考慮する:1つ目は、手動で作成したパターンを使用してレイヤーを削除する方法である。例えば、BK-SDM [23]は、初期層や最終層などの特定のレイヤーの重要性を決定するために発見的な仮定を用いている。2つ目の戦略は、レイヤーの重要性を評価するために体系的に設計された基準に基づいており、例えばブロックの入力と出力の類似性を分析して冗長性を判断する[36, 6]。この方法は通常、プルーニング後のパフォーマンス低下を最小限に抑えることを目的としている。表1は、両方の戦略の代表例を示しており、ShortGPT [36]、Flux-Lite [6]、Diff-Pruning [12]、Sensitivity Analysis [18]、BK-SDM [23]が比較のためのベースラインとして機能している。さらに、我々の手法を、従来のDiTよりも優れたトレーニング効率を示したUViT [1]、U-DiT [50]、DTR [39]などの革新的なアーキテクチャ設計と比較評価する。

1は、事前学習済みのDiT-XL/2 [40]の圧縮に関する我々の発見を示している。このモデルは、AttentionとMLPレイヤーが交互に構成された28のトランスフォーマーレイヤーを含んでいる。提案手法は、これらの28レイヤーから{7,14,19}71419\{7,14,19\}{ 7 , 14 , 19 }サブレイヤーを持つ浅いトランスフォーマーを特定し、ファインチューニング後のパフォーマンスを最大化することを目指している。元のトレーニングコストのわずか7%(7Mステップに対して500Kステップ)で、TinyDiTはプルーニングベースの手法と新しいアーキテクチャの両方に対して競争力のあるパフォーマンスを達成している。例えば、1Mステップでゼロから学習したDiT-Lモデルは、458Mパラメータで3.73のFIDスコアを達成している。対照的に、圧縮されたTinyDiT-D14モデルは、340Mパラメータのみで、より高速なサンプリング速度(13.54 it/s vs 9.73 it/s)を持ち、大幅に改善されたFID 2.86を生成している。GPUなどの並列デバイスでは、トランスフォーマーの主なボトルネックは各レイヤー内の逐次操作から生じ、レイヤー数が増えるほど顕著になる。深さプルーニングは、トランスフォーマーレイヤー全体を削除することでこのボトルネックを軽減し、計算の深さを減らし、ワークロードを最適化する。対照的に、幅プルーニングは各レイヤー内のニューロン数を減らすだけであり、速度向上の可能性が限られている。図4に示すように、深さプルーニングは圧縮率が増加するにつれて理論的な線形速度向上に近づき、Diff-Pruning [12]などの幅プルーニング手法を上回っている。

Method Depth Params Epochs FID IS
MAR-Large 32 479 M 400 1.78 296.0
MAR-Base 24 208 M 400 2.31 281.7
TinyMAR-D16 32\Rightarrow16 277 M 40 2.28 283.4
SiT-XL/2 28 675 M 1,400 2.06 277.5
TinySiT-D14 28\Rightarrow14 340 M 100 3.02 220.1
表2: MARs [29]とSiTs [34]に対する深さプルーニングの結果。

MAR & SiT.

マスク自己回帰(MAR)[29]モデルは、連続値空間で拡散損失ベースの自己回帰フレームワークを採用し、離散的なトークン化を必要とせずに高品質な画像生成を実現している。32のトランスフォーマーブロックを持つMAR-Largeモデルが比較のベースラインとして機能している。我々のプルーニング手法を適用し、MARを16ブロックのバリアントであるTinyMAR-D16に削減し、FID 2.28を達成し、元のトレーニングコストのわずか10%(400エポックに対して40エポック)で24ブロックのMAR-Baseモデルのパフォーマンスを上回った。我々のアプローチは、データとノイズ分布を橋渡しするためにフローベースの補間フレームワークを採用するDiTアーキテクチャの拡張であるスケーラブル補間トランスフォーマー(SiT)[34]にも一般化される。28のトランスフォーマーブロックで構成されるSiT-XL/2モデルを50%プルーニングし、TinySiT-D14モデルを作成した。このプルーニングされたモデルは、元のトレーニングコストのわずか7%(1400エポックに対して100エポック)で競争力のあるパフォーマンスを維持している。表2に示すように、これらの結果は、我々のプルーニング手法が異なる拡散トランスフォーマーのバリアントに適応可能であり、モデルサイズとトレーニング時間を効果的に削減しながら強力なパフォーマンスを維持できることを示している。

4.3 Analytical Experiments

Refer to caption
図5: 候補モデルのランダムサンプリングによるキャリブレーション損失の分布。提案する学習可能な手法は、ファインチューニング後の最良のFIDを達成しているが、他のベースラインと比較して初期損失が比較的高い。
Strategy Loss IS FID Prec. Recall
Max. Loss 37.69 NaN NaN NaN NaN
Med. Loss 0.99 149.51 6.45 0.78 0.53
Min. Loss 0.20 73.10 20.69 0.63 0.58
Sensitivity 0.21 70.36 21.15 0.63 0.57
ShortGPT [36] 0.20 66.10 22.28 0.63 0.56
Flux-Lite [6] 0.85 54.54 25.92 0.62 0.55
Oracle (BK-SDM) 1.28 141.18 7.43 0.75 0.55
Learnable 0.98 151.88 5.73 0.80 0.55
表3: キャリブレーション損失を直接最小化することは、最適でない解につながる可能性がある。すべての剪定されたモデルは、知識蒸留(KD)なしで100,000ステップファインチューニングされている。我々は以下のベースラインを評価する:(1) 損失 - DiT-XLモデルをランダムに剪定して100,000個のモデルを生成し、異なるキャリブレーション損失を持つモデルをファインチューニング用に選択する;(2) メトリックベースの手法 - 感度分析やShortGPTなど;(3) オラクル - [23]に従って、最初と最後の層を保持し、中間層を均一に剪定する;(4) 学習可能 - 提案する学習可能な手法。

Is Calibration Loss the Primary Determinant?

深さ剪定における本質的な問題は、事前学習された拡散トランスフォーマーの冗長な層をどのように特定するかである。一般的なアプローチは、剪定後のキャリブレーション損失が低いモデルがより優れた性能を示すという仮定に基づいて、キャリブレーション損失を最小化することである。しかし、本節では、この仮説が拡散トランスフォーマーには当てはまらない可能性があることを示す。我々はまず、50%の比率でランダムな深さ剪定を行うことによって解空間を調査し、キャリブレーション損失が0.195から37.694の範囲にある100,000個の候補モデルを生成する(図5参照)。これらの候補から、最高および最低のキャリブレーション損失を持つモデルをファインチューニング用に選択する。注目すべきことに、両モデルは不安定な学習(NaN)や最適でないFIDスコア(20.69)などの好ましくない結果をもたらす(表3参照)。さらに、我々は感度分析[18]を実施する。これは層の除去による損失の乱れを測定することで重要な層を特定する一般的に使用される技術であり、0.21という低いキャリブレーション損失を持つモデルを生成する。しかし、このモデルのFIDスコアは最低のキャリブレーション損失を持つモデルのものと同様である。ShortGPT[36]や、入力と出力の状態間の類似性を推定または平均二乗誤差(MSE)を最小化するFluxモデル圧縮の最近のアプローチ[6]などのアプローチも同様の傾向を示す。対照的に、オラクル(しばしば競争力が低いと考えられる)やランダムに剪定されたモデルの1つなど、中程度のキャリブレーション損失を持つ手法は、それぞれ7.43と6.45のFIDスコアを達成し、最小のキャリブレーション損失を持つモデルよりも著しく優れた性能を示している。これらの発見は、キャリブレーション損失がファインチューニング後の性能にある程度影響を与える可能性があるものの、拡散トランスフォーマーにとっては主要な決定要因ではないことを示唆している。代わりに、ファインチューニング中の性能回復能力(「回復可能性」と呼ばれる)がより重要であるように見える。注目すべきことに、回復可能性を従来のメトリックで評価することは困難である。なぜなら、データセット全体にわたる学習プロセスが必要だからである。この観察は、提案手法がベースライン手法と比較して優れた結果(5.73)を達成する理由も説明している。

Pattern 𝚫𝚫\mathbf{\Delta}bold_ΔW IS \uparrow FID \downarrow sFID \downarrow Prec. \uparrow Recall \uparrow
1:2 LoRA 54.75 33.39 29.56 0.56 0.62
2:4 LoRA 53.07 34.21 27.61 0.55 0.63
7:14 LoRA 34.97 49.41 28.48 0.46 0.56
1:2 Full 53.11 35.77 32.68 0.54 0.61
2:4 Full 53.63 34.41 29.93 0.55 0.62
7:14 Full 45.03 38.76 31.31 0.52 0.62
1:2 Frozen 45.08 39.56 31.13 0.52 0.60
2:4 Frozen 48.09 37.82 31.91 0.53 0.62
7:14 Frozen 34.09 49.75 31.06 0.46 0.56
表4: 様々な剪定スキームと回復可能性推定戦略を用いて圧縮されたTinyDiT-D14モデルの性能比較。すべてのモデルは10,000ステップでファインチューニングされ、FIDスコアは64タイムステップで10,000サンプリングされた画像に対して計算される。

Learnable Modeling of Recoverability.

従来のメトリックベースの手法の限界を克服するために、本研究では剪定とモデルの回復可能性を共同で最適化する学習可能なアプローチを導入する。表3は、ローカル剪定スキームと回復可能性推定の更新戦略を含む、学習可能な手法の異なる構成を示している。固定50%の層剪定率を持つ28層のDiT-XL/2に対して、我々は1:2、2:4、7:14の3つの分割スキームを検討する。例えば、1:2スキームでは、2つのトランスフォーマー層ごとにローカルブロックを形成し、1層が剪定される。より大きなブロックはより大きな多様性をもたらすが、探索空間を大幅に拡大する。例えば、7:14スキームはモデルを2つのセグメントに分割し、それぞれ7層を保持するため、(147)×2=6,864binomial14726864\binom{14}{7}\times 2=6{,}864( FRACOP start_ARG 14 end_ARG start_ARG 7 end_ARG ) × 2 = 6 , 864の可能な解が生じる。逆に、より小さなブロックは最適化の難しさを大幅に減少させ、より大きな柔軟性を提供する。1つのブロックの分布が収束すると、他のブロックの学習はまだ進行できる。表3に示すように、1:2構成は10,000回のファインチューニング反復後に最適な性能を達成する。さらに、我々の経験的な発見は、LoRAまたは完全なファインチューニングを使用した回復可能性推定の有効性を強調している。両方の手法はファインチューニング後に肯定的な結果をもたらし、1:2スキームの下でLoRAは完全なファインチューニング(FID = 35.77)と比較して優れた結果(FID = 33.39)を達成する。これは、LoRAが学習可能なパラメータが少なく(完全なパラメータ学習の0.9%相対)、サンプリングのランダム性により効率的に適応できるためである。

Refer to caption
図6: 学習可能な剪定における2:4決定の可視化。各決定の信頼度は透明度の変化によって強調されている。1:2および7:14スキームのより多くの可視化結果は付録で利用可能である。
Refer to caption
図7: DiT-XL/2から剪定および蒸留されたTinyDiT-D14によってImageNet 224×\times×224上で生成された画像。

Visualization of Learnable Decisions.

剪定における学習可能な手法の役割についてより深い洞察を得るために、我々は図6で学習プロセスを可視化する。下から上へ、i番目の曲線は剪定されたモデルのi番目の層を表し、元のDiT-XL/2におけるその層のインデックスを表示している。この可視化は、学習反復にわたる剪定決定のダイナミクスを示しており、各データポイントの透明度はサンプリングされる確率を示している。学習可能な手法は、様々な層の組み合わせを探索し処理する能力を示している。圧縮されたモデルの7番目と8番目の層など、特定の層の剪定決定は迅速に決定され、プロセス全体を通じて安定している。対照的に、0番目の層のような他の層は、その回復可能性を推定するために追加のファインチューニングを必要とする。注目すべきことに、これらの層が十分に最適化された後、後期段階で一部の決定が変更される可能性がある。学習プロセスは最終的に高いサンプリング確率で終了し、分布がone-hot構成に近づく収束した学習プロセスを示唆している。学習後、我々は最高の確率を持つ層を選択し、その後のファインチューニングを行う。

4.4 Knowledge Distillation for Recovery

本稿では、強化された微調整手法として知識蒸留(KD)についても探究する。表5に示すように、我々はHinton[20]によって提案された通常の知識蒸留アプローチを適用し、事前学習済みのDiT-XL/2の出力を教師モデルとして監督に用いてTinyDiT-D14を微調整する。浅い学生モデルと深い教師モデルの間の出力を整合させるために平均二乗誤差(MSE)損失を採用し、これにより100Kステップでのフレシェ開始距離(FID)を5.79から4.66に効果的に削減する。

Refer to caption
(a) DiT-XL/2(教師)
Refer to caption
(b) TinyDiT-D14(学生)
図8: DiTsにおける大規模な活性化[47]の可視化。教師モデルと学生モデルの両方が隠れ状態で大きな活性化値を示している。これらの大規模な活性化を直接蒸留すると、過度に大きな損失と不安定な訓練につながる可能性がある。
fine-tuning Strategy Init. Distill. Loss FID @ 100K
fine-tuning - 5.79
Logits KD - 4.66
RepKD 2840.1 NaN
Masked KD (0.1σ0.1𝜎0.1\sigma0.1 italic_σ) 15.4 NaN
Masked KD (2σ2𝜎2\sigma2 italic_σ) 387.1 3.73
Masked KD (4σ4𝜎4\sigma4 italic_σ) 391.4 3.75
表5: 回復のための異なる微調整戦略の評価。マスクされたRepKDは、教師と学生の両方における大規模な活性化(|x|>kσx𝑥𝑘subscript𝜎𝑥|x|>k\sigma_{x}| italic_x | > italic_k italic_σ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT)を無視し、拡散トランスフォーマー間の効果的な知識転移を可能にする。

Masked Knowledge Distillation.

さらに、我々は教師から学生への隠れ状態の転移のために表現蒸留(RepKD)[42, 23]を評価する。深さの剪定は拡散トランスフォーマーの隠れ次元を変更しないため、中間隠れ状態を直接整合させることができることに注意することが重要である。実際の実装では、セクション3.2で定義されたブロックを基本単位として使用し、剪定されたDiTの局所構造が教師モデルの元の構造の出力と整合することを保証する。しかし、この単純なRepKDアプローチでは、隠れ状態における大規模な活性化により、重大な訓練の困難に直面した。図8に示すように、教師モデルと学生モデルの両方が時折大きな活性化値を示す。これらの極端な活性化を直接蒸留すると、過度に高い損失値をもたらし、学生モデルの性能を損なう可能性がある。この問題は、特定のLLM[47]など、他のトランスフォーマーベースの生成モデルでも観察されている。これに対処するため、我々は知識転移中にこれらの大規模な活性化を選択的に除外するマスクされたRepKD変種を提案する。我々は単純な閾値処理法|xμx|<kσx𝑥subscript𝜇𝑥𝑘subscript𝜎𝑥|x-\mu_{x}|<k\sigma_{x}| italic_x - italic_μ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT | < italic_k italic_σ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPTを採用し、これらの極端な活性化に関連する損失を無視する。表5に示すように、2σ2𝜎2\sigma2 italic_σおよび4σ4𝜎4\sigma4 italic_σの適度な閾値を持つマスクされたRepKDアプローチは満足のいく結果を達成し、我々の手法の堅牢性を示している。

Generated Images.

7では、既製のDiT-XL/2モデルから蒸留された学習済みTinyDiT-D14の生成画像を可視化している。SiTsとMARsのさらなる可視化結果は付録に記載されている。

5 Conclusions

本稿では、TinyFusionを紹介した。これは冗長な層を除去することによって拡散トランスフォーマーを加速させる学習可能な手法である。本手法は、プルーニングされたモデルの回復可能性を最適化可能な目的関数としてモデル化し、微分可能なサンプリングを組み込むことでエンドツーエンドの学習を可能にしている。我々の手法は、DiT、MAR、SiTなど、様々なアーキテクチャに一般化可能である。

References

  • Bao et al. [2023] Fan Bao, Shen Nie, Kaiwen Xue, Yue Cao, Chongxuan Li, Hang Su, and Jun Zhu. All are worth words: A vit backbone for diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 22669–22679, 2023.
  • Bengio et al. [2013] Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Castells et al. [2024] Thibault Castells, Hyoung-Kyu Song, Bo-Kyeong Kim, and Shinkook Choi. Ld-pruner: Efficient pruning of latent diffusion models using task-agnostic insights. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 821–830, 2024.
  • Chang et al. [2022] Huiwen Chang, Han Zhang, Lu Jiang, Ce Liu, and William T Freeman. Maskgit: Masked generative image transformer. In Conference on Computer Vision and Pattern Recognition, pages 11315–11325, 2022.
  • Chen et al. [2023] Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li. Pixart-α𝛼\alphaitalic_α: Fast training of diffusion transformer for photorealistic text-to-image synthesis, 2023.
  • Daniel Verdú [2024] Javier Martín Daniel Verdú. Flux.1 lite: Distilling flux1.dev for efficient text-to-image generation. 2024.
  • Dao et al. [2022] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • Deng et al. [2009] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. Ieee, 2009.
  • Dhariwal and Nichol [2021] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780–8794, 2021.
  • Elbayad et al. [2019] Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. Depth-adaptive transformer. arXiv preprint arXiv:1910.10073, 2019.
  • Esser et al. [2024] Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Müller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, et al. Scaling rectified flow transformers for high-resolution image synthesis. In Forty-first International Conference on Machine Learning, 2024.
  • Fang et al. [2023] Gongfan Fang, Xinyin Ma, and Xinchao Wang. Structural pruning for diffusion models. In Advances in Neural Information Processing Systems, 2023.
  • Fang et al. [2024] Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, and Xinchao Wang. Maskllm: Learnable semi-structured sparsity for large language models. arXiv preprint arXiv:2409.17481, 2024.
  • Fei et al. [2024a] Zhengcong Fei, Mingyuan Fan, Changqian Yu, Debang Li, and Junshi Huang. Scaling diffusion transformers to 16 billion parameters. arXiv preprint arXiv:2407.11633, 2024a.
  • Fei et al. [2024b] Zhengcong Fei, Mingyuan Fan, Changqian Yu, Debang Li, Youqiang Zhang, and Junshi Huang. Dimba: Transformer-mamba diffusion models. arXiv preprint arXiv:2406.01159, 2024b.
  • Gao et al. [2023] Shanghua Gao, Zhijie Lin, Xingyu Xie, Pan Zhou, Ming-Ming Cheng, and Shuicheng Yan. Editanything: Empowering unparalleled flexibility in image editing and generation. In Proceedings of the 31st ACM International Conference on Multimedia, Demo track, 2023.
  • Gumbel [1954] Emil Julius Gumbel. Statistical theory of extreme values and some practical applications: a series of lectures. US Government Printing Office, 1954.
  • Han et al. [2015] Song Han, Jeff Pool, John Tran, and William Dally. Learning both weights and connections for efficient neural network. Advances in neural information processing systems, 28, 2015.
  • He et al. [2024] Yefei He, Luping Liu, Jing Liu, Weijia Wu, Hong Zhou, and Bohan Zhuang. Ptqd: Accurate post-training quantization for diffusion models. Advances in Neural Information Processing Systems, 36, 2024.
  • Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, Jeff Dean, et al. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2(7), 2015.
  • Hu et al. [2022] Edward J Hu, yelong shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In International Conference on Learning Representations, 2022.
  • Jang et al. [2016] Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
  • Kim et al. [2023] Bo-Kyeong Kim, Hyoung-Kyu Song, Thibault Castells, and Shinkook Choi. Bk-sdm: Architecturally compressed stable diffusion for efficient text-to-image generation. In Workshop on Efficient Systems for Foundation Models@ ICML2023, 2023.
  • Kim et al. [2024] Bo-Kyeong Kim, Geonmin Kim, Tae-Ho Kim, Thibault Castells, Shinkook Choi, Junho Shin, and Hyoung-Kyu Song. Shortened llama: A simple depth pruning for large language models. arXiv preprint arXiv:2402.02834, 11, 2024.
  • Lab and etc. [2024] PKU-Yuan Lab and Tuzhan AI etc. Open-sora-plan, 2024.
  • Labs [2024] Black Forest Labs. FLUX, 2024.
  • [27] Youngwan Lee, Yong-Ju Lee, and Sung Ju Hwang. Dit-pruner: Pruning diffusion transformer models for text-to-image synthesis using human preference scores.
  • Lee et al. [2023] Youngwan Lee, Kwanyong Park, Yoorhim Cho, Yong-Ju Lee, and Sung Ju Hwang. Koala: self-attention matters in knowledge distillation of latent diffusion models for memory-efficient and fast image synthesis. arXiv e-prints, pages arXiv–2312, 2023.
  • Li et al. [2024a] Tianhong Li, Yonglong Tian, He Li, Mingyang Deng, and Kaiming He. Autoregressive image generation without vector quantization. arXiv preprint arXiv:2406.11838, 2024a.
  • Li et al. [2023] Xiuyu Li, Yijiang Liu, Long Lian, Huanrui Yang, Zhen Dong, Daniel Kang, Shanghang Zhang, and Kurt Keutzer. Q-diffusion: Quantizing diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 17535–17545, 2023.
  • Li et al. [2024b] Yanyu Li, Huan Wang, Qing Jin, Ju Hu, Pavlo Chemerys, Yun Fu, Yanzhi Wang, Sergey Tulyakov, and Jian Ren. Snapfusion: Text-to-image diffusion model on mobile devices within two seconds. Advances in Neural Information Processing Systems, 36, 2024b.
  • Lin et al. [2024] Shanchuan Lin, Anran Wang, and Xiao Yang. Sdxl-lightning: Progressive adversarial diffusion distillation. arXiv preprint arXiv:2402.13929, 2024.
  • Lu et al. [2022] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022.
  • Ma et al. [2024a] Nanye Ma, Mark Goldstein, Michael S Albergo, Nicholas M Boffi, Eric Vanden-Eijnden, and Saining Xie. Sit: Exploring flow and diffusion-based generative models with scalable interpolant transformers. arXiv preprint arXiv:2401.08740, 2024a.
  • Ma et al. [2024b] Xinyin Ma, Gongfan Fang, Michael Bi Mi, and Xinchao Wang. Learning-to-cache: Accelerating diffusion transformer via layer caching, 2024b.
  • Men et al. [2024] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, and Weipeng Chen. Shortgpt: Layers in large language models are more redundant than you expect. arXiv preprint arXiv:2403.03853, 2024.
  • Molchanov et al. [2016] Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning convolutional neural networks for resource efficient inference. arXiv preprint arXiv:1611.06440, 2016.
  • Ni et al. [2024] Zanlin Ni, Yulin Wang, Renping Zhou, Jiayi Guo, Jinyi Hu, Zhiyuan Liu, Shiji Song, Yuan Yao, and Gao Huang. Revisiting non-autoregressive transformers for efficient image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 7007–7016, 2024.
  • Park et al. [2023] Byeongjun Park, Sangmin Woo, Hyojun Go, Jin-Young Kim, and Changick Kim. Denoising task routing for diffusion models. arXiv preprint arXiv:2310.07138, 2023.
  • Peebles and Xie [2023] William Peebles and Saining Xie. Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205, 2023.
  • Raposo et al. [2024] David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, and Adam Santoro. Mixture-of-depths: Dynamically allocating compute in transformer-based language models. arXiv preprint arXiv:2404.02258, 2024.
  • Romero et al. [2014] Adriana Romero, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta, and Yoshua Bengio. Fitnets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550, 2014.
  • Salimans and Ho [2022] Tim Salimans and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. arXiv preprint arXiv:2202.00512, 2022.
  • Shang et al. [2023] Yuzhang Shang, Zhihang Yuan, Bin Xie, Bingzhe Wu, and Yan Yan. Post-training quantization on diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 1972–1981, 2023.
  • Song et al. [2020] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502, 2020.
  • Song et al. [2023] Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. arXiv preprint arXiv:2303.01469, 2023.
  • Sun et al. [2024] Mingjie Sun, Xinlei Chen, J Zico Kolter, and Zhuang Liu. Massive activations in large language models. arXiv preprint arXiv:2402.17762, 2024.
  • Teng et al. [2024] Yao Teng, Yue Wu, Han Shi, Xuefei Ning, Guohao Dai, Yu Wang, Zhenguo Li, and Xihui Liu. Dim: Diffusion mamba for efficient high-resolution image synthesis. arXiv preprint arXiv:2405.14224, 2024.
  • Tian et al. [2024a] Keyu Tian, Yi Jiang, Zehuan Yuan, Bingyue Peng, and Liwei Wang. Visual autoregressive modeling: Scalable image generation via next-scale prediction. 2024a.
  • Tian et al. [2024b] Yuchuan Tian, Zhijun Tu, Hanting Chen, Jie Hu, Chao Xu, and Yunhe Wang. U-dits: Downsample tokens in u-shaped diffusion transformers. arXiv preprint arXiv:2405.02730, 2024b.
  • Wang et al. [2024] Kafeng Wang, Jianfei Chen, He Li, Zhenpeng Mi, and Jun Zhu. Sparsedm: Toward sparse efficient diffusion models. arXiv preprint arXiv:2404.10445, 2024.
  • Xie et al. [2024] Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Yujun Lin, Zhekai Zhang, Muyang Li, Yao Lu, and Song Han. Sana: Efficient high-resolution image synthesis with linear diffusion transformers. arXiv preprint arXiv:2410.10629, 2024.
  • Yang et al. [2023] Ling Yang, Zhilong Zhang, Yang Song, Shenda Hong, Runsheng Xu, Yue Zhao, Wentao Zhang, Bin Cui, and Ming-Hsuan Yang. Diffusion models: A comprehensive survey of methods and applications. ACM Computing Surveys, 56(4):1–39, 2023.
  • Yu et al. [2022] Fang Yu, Kun Huang, Meng Wang, Yuan Cheng, Wei Chu, and Li Cui. Width & depth pruning for vision transformers. In Conference on Artificial Intelligence (AAAI), 2022.
  • Yu et al. [2023] Tao Yu, Runseng Feng, Ruoyu Feng, Jinming Liu, Xin Jin, Wenjun Zeng, and Zhibo Chen. Inpaint anything: Segment anything meets image inpainting. arXiv preprint arXiv:2304.06790, 2023.
  • Zhang et al. [2024] Dingkun Zhang, Sijia Li, Chen Chen, Qingsong Xie, and Haonan Lu. Laptop-diff: Layer pruning and normalized distillation for compressing diffusion models. arXiv preprint arXiv:2404.11098, 2024.
  • Zhao et al. [2024] Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You. Real-time video generation with pyramid attention broadcast. arXiv preprint arXiv:2408.12588, 2024.
  • Zhao et al. [2023] Yang Zhao, Yanwu Xu, Zhisheng Xiao, and Tingbo Hou. Mobilediffusion: Subsecond text-to-image generation on mobile devices. arXiv preprint arXiv:2311.16567, 2023.
  • Zheng et al. [2024] Zangwei Zheng, Xiangyu Peng, Tianji Yang, Chenhui Shen, Shenggui Li, Hongxin Liu, Yukun Zhou, Tianyi Li, and Yang You. Open-sora: Democratizing efficient video production for all, 2024.

6 Experimental Details

Models.

我々の実験では、DiT-XL、MAR-Large、およびSiT-XLの3つのモデルの有効性を評価する。Diffusion Transformers (DiTs)は、Vision Transformer (ViT)の原理に触発され、空間的な入力をパッチのシーケンスとして処理する。DiT-XLモデルは28のトランスフォーマー層、1152のhidden size、16の注意ヘッド、および2 ×\times× 2のパッチサイズを特徴とする。トレーニングの安定性を向上させるために適応的層正規化(AdaLN)を採用しており、6億7500万のパラメータを持ち、1400エポックにわたってトレーニングされる。Masked Autoregressive models (MARs)は、自己回帰的な画像生成に特化したdiffusion transformerの変種である。これらは、離散的なトークン化なしに高品質の出力を生成するために、連続値のdiffusionロスフレームワークを利用する。MAR-Largeモデルは32のトランスフォーマー層、1024のhidden size、16の注意ヘッド、および双方向注意を含む。DiTと同様に、安定したトレーニングと効果的なトークンモデリングのためにAdaLNを組み込んでおり、4億7900万のパラメータを400エポックにわたってトレーニングする。最後に、Scalable Interpolant Transformers (SiTs)は、フローベースの補間手法を導入することでDiTフレームワークを拡張し、データとノイズ分布間のより柔軟な橋渡しを可能にする。SiT-XLモデルはDiT-XLと構造的に同一であるが、この補間アプローチを活用して、補間の選択とサンプリングダイナミクスのモジュラーな実験を容易にする。

Datasets.

我々はImageNet 256 ×\times× 256データセットを、中央クロッピングと適応的リサイズを適用して元のアスペクト比を維持し、歪みを最小限に抑えるように準備した。その後、画像は平均0.5、標準偏差0.5に正規化された。データセットを拡張するために、0.5の確率でランダムな水平フリップを適用した。Variational Autoencoder (VAE)を使用せずにトレーニングを加速するために、事前に訓練されたVAEを使用して画像から特徴を抽出した。画像はその潜在表現にマッピングされ、正規化され、結果として得られた特徴配列はトレーニング中に直接使用するために保存された。

Refer to caption
図9: 1:2プルーニング決定
Refer to caption
図10: 2:4プルーニング決定
Refer to caption
図11: 7:14プルーニング決定
Refer to caption
図12: ローカルブロックにおける学習可能な深さプルーニング
Refer to caption
図13: 2:4ブロックを用いたマスク付き知識蒸留

Training Details

学習プロセスは、図12に示されている提案された学習可能な剪定方法を用いて剪定されたモデルを取得することから始まった。剪定の決定は、ブロックサイズを持つLoRAを通じて剪定と重み更新の共同最適化によって行われた。実際には、簡略化のためにブロックサイズは2とし、MARを除いて100エポックの学習を行った。MARは40エポックの学習を行った。剪定後のパフォーマンスを向上させるために、回復段階では教師モデルから剪定された生徒モデルへの知識転移のためにMasked Knowledge Distillation(RepKD)法を採用した。RepKDアプローチは、剪定されたモデルと教師モデルの出力予測と中間隠れ状態を整合させる。詳細は次のセクションで提供される。さらに、指数移動平均(EMA)が画像生成中に更新され使用されるため、過度に小さい学習率はEMAの効果を弱め、最適でない結果につながる可能性がある。この問題に対処するために、学習全体を通じて学習率を徐々に半減させる段階的な学習率スケジューラを実装した。各ハイパーパラメータの詳細は表6に示されている。

Model Optimizer Cosine Sched. Teacher αKDsubscript𝛼KD\alpha_{\text{KD}}italic_α start_POSTSUBSCRIPT KD end_POSTSUBSCRIPT αGTsubscript𝛼GT\alpha_{\text{GT}}italic_α start_POSTSUBSCRIPT GT end_POSTSUBSCRIPT β𝛽\betaitalic_β Grad. Clip Pruning Configs
DiT-D19 AdamW(lr=2e-4, wd=0.0) ηmin=1e-4subscript𝜂min1e-4\eta_{\text{min}}=1\text{e-4}italic_η start_POSTSUBSCRIPT min end_POSTSUBSCRIPT = 1 e-4 DiT-XL 0.9 0.1 1e-2 \rightarrow 0 1.0 LoRA-1:2
DiT-D14 AdamW(lr=2e-4, wd=0.0 ηmin=1e-4subscript𝜂min1e-4\eta_{\text{min}}=1\text{e-4}italic_η start_POSTSUBSCRIPT min end_POSTSUBSCRIPT = 1 e-4 DiT-XL 0.9 0.1 1e-2 \rightarrow 0 1.0 LoRA-1:2
DiT-D7 AdamW(lr=2e-4, wd=0.0) ηmin=1e-4subscript𝜂min1e-4\eta_{\text{min}}=1\text{e-4}italic_η start_POSTSUBSCRIPT min end_POSTSUBSCRIPT = 1 e-4 DiT-D14 0.9 0.1 1e-2 \rightarrow 0 1.0 LoRA-1:2
SiT-D14 AdamW(lr=2e-4, wd=0.0) ηmin=1e-4subscript𝜂min1e-4\eta_{\text{min}}=1\text{e-4}italic_η start_POSTSUBSCRIPT min end_POSTSUBSCRIPT = 1 e-4 SiT-XL 0.9 0.1 2e-4 \rightarrow 0 1.0 LoRA-1:2
MAR-D16 AdamW(lr=2e-4, wd=0.0) ηmin=1e-4subscript𝜂min1e-4\eta_{\text{min}}=1\text{e-4}italic_η start_POSTSUBSCRIPT min end_POSTSUBSCRIPT = 1 e-4 MAR-Large 0.9 0.1 1e-2 \rightarrow 0 1.0 LoRA-1:2
表6: マスク学習のための学習詳細とハイパーパラメータ

7 Visualization of Pruning Decisions

1111および 11は、1:2、2:4、および7:14のプルーニング方式における訓練中のプルーニング決定の動態を可視化したものである。異なる分割は異なる探索空間をもたらし、それによって様々な解決策が生まれる。1:2および2:4の方式では、わずか1エポックで良好な決定を学習できるが、7:14の方式では最適化の困難に直面する。これは (147)binomial147\binom{14}{7}( FRACOP start_ARG 14 end_ARG start_ARG 7 end_ARG )=3,432の候補が存在し、1エポック内で適切にサンプリングするには膨大すぎるためである。したがって、実際の応用では、我々は学習可能な層のプルーニングに1:2または2:4の方式を使用する。

8 Details of Masked Knowledge Distillation

Training Loss.

本稿では、事前訓練された教師モデルを模倣することで優れた生徒モデルを学習するために、標準的な知識蒸留を採用している。損失関数は以下のように形式化される:

=αKDKD+αDiffDiff+βRepsubscript𝛼KDsubscriptKDsubscript𝛼DiffsubscriptDiff𝛽subscriptRep\mathcal{L}=\alpha_{\text{KD}}\cdot\mathcal{L}_{\text{KD}}+\alpha_{\text{Diff}% }\cdot\mathcal{L}_{\text{Diff}}+\beta\cdot\mathcal{L}_{\text{Rep}}caligraphic_L = italic_α start_POSTSUBSCRIPT KD end_POSTSUBSCRIPT ⋅ caligraphic_L start_POSTSUBSCRIPT KD end_POSTSUBSCRIPT + italic_α start_POSTSUBSCRIPT Diff end_POSTSUBSCRIPT ⋅ caligraphic_L start_POSTSUBSCRIPT Diff end_POSTSUBSCRIPT + italic_β ⋅ caligraphic_L start_POSTSUBSCRIPT Rep end_POSTSUBSCRIPT (8)

ここで、KDKD\mathcal{L}{\text{KD}}caligraphic_L KDは生徒モデルと教師モデルの出力間の平均二乗誤差を表す。DiffDiff\mathcal{L}{\text{Diff}}caligraphic_L Diffは元の事前訓練損失関数を表す。最後に、RepsubscriptRep\mathcal{L}_{\text{Rep}}caligraphic_L start_POSTSUBSCRIPT Rep end_POSTSUBSCRIPTは図13に示されているように、隠れ状態に適用されるマスク蒸留損失に相当し、これは剪定されたモデルと元のモデルの中間表現間の整合性を促進する。対応するハイパーパラメータαKDsubscript𝛼KD\alpha_{\text{KD}}italic_α start_POSTSUBSCRIPT KD end_POSTSUBSCRIPTαDiffsubscript𝛼Diff\alpha_{\text{Diff}}italic_α start_POSTSUBSCRIPT Diff end_POSTSUBSCRIPTおよびαRepsubscript𝛼Rep\alpha_{\text{Rep}}italic_α start_POSTSUBSCRIPT Rep end_POSTSUBSCRIPTは表6に記載されている。

Hidden State Alignment.

マスク蒸留損失RepsubscriptRep\mathcal{L}_{\text{Rep}}caligraphic_L start_POSTSUBSCRIPT Rep end_POSTSUBSCRIPTは、生徒モデルと教師モデルの中間表現を整合させるために重要である。回復段階において、生徒モデルの各層は、教師モデルの対応する2層のローカルブロックからの出力隠れ状態を複製するように設計されている。深さの剪定は層の内部次元を変更しないため、追加の投影層なしで直接整合が可能である。SiTsのような、その独自の補間子ベースのアーキテクチャにより隠れ状態の損失がより顕著なモデルの場合、潜在的な訓練の不安定性を緩和するために、β𝛽\betaitalic_βに対してより小さな係数が適用される。訓練を通じてRepsubscriptRep\mathcal{L}_{\text{Rep}}caligraphic_L start_POSTSUBSCRIPT Rep end_POSTSUBSCRIPTを徐々に減少させることで、β𝛽\betaitalic_βの収束への悪影響のリスクをさらに軽減する。

Iterative Pruning and Distillation.

7は、反復的な枝刈りと教師選択戦略の有効性を評価している。TinyDiT-D7を得るためには、28層のDiT-XLを直接枝刈りするか、まずTinyDiT-D14を作成し、その後反復的に小さなモデルを生成するかのいずれかの方法がある。教師の選択と学生モデルの初期重みを得る方法の影響を調査するために、我々は事前学習済みモデルと作成された中間モデルの両方を枝刈りすることでTinyDiT-D7の初期重みを導出した。その後、学習済みモデルと作成されたモデルの両方を、枝刈りされた学生モデルの教師として使用した。4つの実験設定において、作成された中間モデルを用いた枝刈りと蒸留が最も良い性能を示した。特筆すべきは、作成されたモデルから枝刈りされたモデルが、蒸留プロセスで使用された教師モデルに関係なく、事前学習済みモデルから枝刈りされたモデルを上回ったことである。我々はこの優れた性能を2つの要因に帰している。第一に、作成されたモデルの構造が知識蒸留に適しているのは、蒸留法を用いて学習されたためである。第二に、探索空間が縮小されることで、学生モデルのより有利な初期状態を見つけやすくなっている。

Teacher Model Pruned From IS FID sFID Prec. Recall
DiT-XL/2 DiT-XL/2 29.46 56.18 26.03 0.43 0.51
DiT-XL/2 TinyDiT-D14 51.96 36.69 28.28 0.53 0.59
TinyDiT-D14 DiT-XL/2 28.30 58.73 29.53 0.41 0.50
TinyDiT-D14 TinyDiT-D14 57.97 32.47 26.05 0.55 0.60
表7: TinyDiT-D7は、異なる教師モデルを用いて10k回枝刈りと蒸留が行われ、サンプルステップは64、サンプリングにはEMAではなく元の重みが使用されている。

9 Analytical Experiments

Training Strategies

14 は、標準的なファインチューニングと知識蒸留(KD)の有効性を示している。ここでは、DiT-XLを14層に削減し、その後さまざまなファインチューニング手法を適用している。図3は、100Kから500KステップにわたるFIDスコアを示している。標準的なファインチューニング手法により、TinyDiT-D14がDiT-Lと同等の性能を達成しつつ、より高速な推論を提供できることは明らかである。さらに、我々は蒸留の顕著な有効性を確認した。これにより、モデルは100KステップでわずかにDiT-Lを上回り、500Kの標準ファインチューニングを行ったTinyDiT-D14よりも優れたFIDスコアを達成することができる。これは、隠れ層の蒸留がより強力な監督を提供するためである。訓練ステップをさらに500Kまで増やすと、著しく良好な結果が得られる。

Refer to caption
図14: FIDと訓練ステップ。

Learning Rate.

我々は、表 8 に示すように、学習率などのいくつかの重要なハイパーパラメータについても探索を行った。lr=2e-4の有効性を確認し、すべてのモデルと実験にこれを適用した。

Learning Rate IS FID sFID Prec. Recall
lr=2e-4 207.27 3.73 5.04 0.8127 0.5401
lr=1e-4 194.31 4.10 5.01 0.8053 0.5413
lr=5e-5 161.40 6.63 6.69 0.7419 0.5705
表8: 知識蒸留を行わないTinyDiT-D14ファインチューニングにおける学習率の効果

10 Visulization

16 および 16 は、公式チェックポイントから圧縮されたTinySiT-D14とTinyMAR-D16から生成された画像を示している。これらのモデルは、それぞれ元の事前学習コストの7%と10%のみを使用して学習され、提案されたマスク付き知識蒸留法を用いて蒸留された。圧縮にもかかわらず、これらのモデルは深さの50%のみで妥当な結果を生成することが可能である。

11 Limitations

本稿では、条件付き画像生成のための拡散トランスフォーマーモデルを加速するための学習可能な深さ剪定手法を探究している。拡散トランスフォーマーはテキストから画像への生成において著しい進歩を示しているため、テキストから画像へのタスク内でのレイヤー除去の影響を体系的に分析することは価値がある。さらに、注目すべき他の深さ剪定戦略も存在する。例えば、トランスフォーマーブロック全体を除去するのではなく、注意層とMLP層を独立して除去するようなより細かい粒度の剪定戦略などである。我々はこれらの調査を今後の研究課題として残している。

Refer to caption
図15: TinySiT-D14から生成された画像
Refer to caption
図16: TinyMAR-D16から生成された画像