Vala プログラミング

WebGPU プログラミング

おなが@京都先端科学大

Babylon.js WebGPU Fluid MAC ( Chrome Canary )

Babylon.js WebGPU で流体シミュレーションを行なってみました。

計算には、MAC法を用いています。
以下のサイトと本を参考にしました。
1 「数学とか語学とか楽しいよね」
 【Navier-Stokes方程式】MAC法によるNavier-Stokes方程式の離散化
  https://mathlang.hatenablog.com/entry/2018/11/14/001001
 【差分法】MAC法で中心差分を用いてNavier-Stokes方程式を解きました
    C++コード付き
2 「流れ解析(東京工芸大学)」
  http://www.cs.t-kougei.ac.jp/nsim/study/flow.htm
  速度-圧力法でキャビティ流れ
3 酒井幸市著「WebGLによる「流れ」と「波」のシミュレーション」( 工学社 )

上のサイト等には、Navier-Stokes方程式MAC法、差分法について詳しく
書いてありますので、ここでは式等は省略しています。
また、実際のプログラム記述には、サンプルプログラムが大変参考になります。
WebGPUのプログラムより、300行ほど短く書けています。

実行結果
f:id:onagat12:20210529193629g:plain

プログラム
Fluid-MAC-2d.html

<!DOCTYPE html>
<html>
<head>
    <title>Babylon.js WebGPU Fluid MAC 2D</title>
    <script src="https://preview.babylonjs.com/babylon.js"></script>
</head>
<body>
<canvas id="renderCanvas" width="400" height="400"></canvas>

<script>
async function init() {
    // Rendering surface and WebGPU engine. Engine initialisation is asynchronous
    // and is awaited before anything below touches the engine.
    const canvas = document.getElementById("renderCanvas");
    const engine = new BABYLON.WebGPUEngine(canvas);
    await engine.initAsync();

    // Simulation parameters. These values are interpolated as text into the
    // shader sources further down, so changing them here changes the shaders.
    // deltaT: time step used by the poisRHS, velocity-update and integrate shaders.
    // Re: used as 1.0/Re in front of the diffusion terms (Reynolds number).
    const deltaT = 0.002;
    const Re = 100.0;

    // Mapping applied by the vertex shader: screen = left0 + particlePos * scale0.
    const left0 = [-0.8, -0.8];
    const scale0 = 1.4;

    // Grid resolution and spacing (spacing is 1 / (N - 1) in each direction).
    const NX = 21;
    const NY = 21;
    const DX = 1.0 / (NX - 1);
    const DY = 1.0 / (NY - 1);

    // Array extents. Pressure, Poisson right-hand side and node velocity use
    // (NX+1) x (NY+1) entries; the x-velocity array is one column wider and the
    // y-velocity array is one row taller.
    const nGridX = NX + 1;
    const nGridY = NY + 1;
    const gridNum = nGridX * nGridY;
    const velX_nGridX = NX + 2;
    const velX_nGridY = NY + 1;
    const velX_gridNum = velX_nGridX * velX_nGridY;
    const velY_nGridX = NX + 1;
    const velY_nGridY = NY + 2;
    const velY_gridNum = velY_nGridX * velY_nGridY;
    // Coefficients consumed by the poisson shader:
    //   p = C1 * (p[i+1] + p[i-1]) + C2 * (p[j+1] + p[j-1]) - C3 * rhs
    const DX2 = DX * DX;
    const DY2 = DY * DY;
    const C1 = 0.5 * DY2 / (DX2 + DY2);
    const C2 = 0.5 * DX2 / (DX2 + DY2);
    const C3 = 0.5 * DY2 / (1.0 + DY2/DX2);

    // One particle per grid index pair i = 1..NX-1, j = 1..NY-1 (see FluidSim).
    const numParticles = (NX-1)*(NY-1);

    // Builds the Babylon.js scene: an orbit camera bound to the canvas plus a
    // FluidSim instance that is stepped once before every rendered frame.
    const createScene = () => {
        const scene = new BABYLON.Scene(engine);

        const alpha = -Math.PI / 2;
        const beta = Math.PI / 2;
        const radius = 10;
        const camera = new BABYLON.ArcRotateCamera(
            "camera", alpha, beta, radius, BABYLON.Vector3.Zero(), scene
        );
        camera.setTarget(BABYLON.Vector3.Zero());
        camera.attachControl(canvas, true);

        // Advance the simulation right before each frame is drawn.
        const sim = new FluidSim(numParticles, scene);
        const stepSimulation = () => {
            sim.update();
        };
        scene.onBeforeRenderObservable.add(stepSimulation);

        return scene;
    };

    /**
     * Owns the GPU state of the fluid demo: the particle and grid storage
     * buffers, the compute shaders that advance them, and the instanced quad
     * mesh that draws one small square per particle.
     */
    class FluidSim {

        /**
         * @param numParticles Number of particle instances to allocate and draw.
         * @param scene Babylon.js scene that supplies the engine and hosts the mesh.
         */
        constructor(numParticles, scene) {
            const engine = scene.getEngine();

            this.numParticles = numParticles;

            // A single mesh rendered numParticles times; the per-instance
            // position comes from the particle storage buffer (see vertexBuffers).
            const pointMesh = BABYLON.MeshBuilder.CreatePlane("plane", { size: 1 }, scene);
            this.mesh = pointMesh;
            pointMesh.forcedInstanceCount = numParticles;

            pointMesh.material = new BABYLON.ShaderMaterial("mat", scene, {
                vertexSource: renderShader.point_vs,
                fragmentSource: renderShader.point_fs,
            }, {
                attributes: ["a_pos", "a_particlePos"]
            });

            // Replace the plane geometry with a small 2D quad (two triangles).
            const side = 0.02;
            const quadCorners = [
                -side, -side,
                 side, -side,
                 side,  side,
                -side,  side
            ];
            const quadVertexBuffer = new BABYLON.VertexBuffer(
                engine,
                quadCorners, "a_pos", false, false, 2, false
            );
            pointMesh.setIndices([0, 1, 2, 2, 3, 0]);
            pointMesh.setVerticesBuffer(quadVertexBuffer);

            // Particle layout: [pos.x, pos.y, vel.x, vel.y] per particle.
            // Float32Array is zero-initialised, so only positions are written.
            const initialParticleData = new Float32Array(numParticles * 4);
            let k = 0;
            for (let j = 1; j < NY; j++) {
                for (let i = 1; i < NX; i++) {
                    initialParticleData[4 * k + 0] = i * DX;
                    initialParticleData[4 * k + 1] = j * DY;
                    k++;
                }
            }
            this.particleBuffers = this._createStorageBuffer(engine, initialParticleData);

            // Grid fields all start at zero (a new Float32Array is already zero-filled).
            this.gridXVelocityBuffers = this._createStorageBuffer(engine, new Float32Array(velX_gridNum));
            this.gridYVelocityBuffers = this._createStorageBuffer(engine, new Float32Array(velY_gridNum));
            this.gridVelocityBuffers = this._createStorageBuffer(engine, new Float32Array(gridNum * 2));
            this.poisRHSBuffers = this._createStorageBuffer(engine, new Float32Array(gridNum));
            this.gridPresBuffers = this._createStorageBuffer(engine, new Float32Array(gridNum));

            // Instanced vertex view over the particle buffer: stride 4 floats,
            // offset 0, 2 floats read per instance (the position).
            this.vertexBuffers = new BABYLON.VertexBuffer(
                engine,
                this.particleBuffers.getBuffer(),
                "a_particlePos", false, false, 4, true, 0, 2
            );

            // Compute shaders. The order of each buffer list defines the binding
            // index (0, 1, 2, ...) within group 0 and must match the WGSL source.
            this.cs_gridBC = this._createComputeShader(engine, computeShader.gridBC, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
            ]);
            this.cs_poisRHS = this._createComputeShader(engine, computeShader.poisRHS, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
                ["poisRHS", this.poisRHSBuffers],
            ]);
            this.cs_presBC = this._createComputeShader(engine, computeShader.presBC, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
                ["gridPres", this.gridPresBuffers],
            ]);
            this.cs_poisson = this._createComputeShader(engine, computeShader.poisson, [
                ["poisRHS", this.poisRHSBuffers],
                ["gridPres", this.gridPresBuffers],
            ]);
            this.cs_updateXVel = this._createComputeShader(engine, computeShader.updateXVel, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
                ["gridPres", this.gridPresBuffers],
            ]);
            this.cs_updateYVel = this._createComputeShader(engine, computeShader.updateYVel, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
                ["gridPres", this.gridPresBuffers],
            ]);
            this.cs_gridVel = this._createComputeShader(engine, computeShader.gridVelocity, [
                ["gridXVelocity", this.gridXVelocityBuffers],
                ["gridYVelocity", this.gridYVelocityBuffers],
                ["gridVelocity", this.gridVelocityBuffers],
            ]);
            this.cs_getPVel = this._createComputeShader(engine, computeShader.getParticleVelocity, [
                ["particle", this.particleBuffers],
                ["gridVelocity", this.gridVelocityBuffers],
            ]);
            this.cs_integrate = this._createComputeShader(engine, computeShader.integrate, [
                ["particle", this.particleBuffers],
            ]);
        }

        /**
         * Allocates a storage buffer sized for `data` and uploads `data` into it.
         * Every buffer uses the same VERTEX | WRITE creation flags.
         */
        _createStorageBuffer(engine, data) {
            const buffer = new BABYLON.StorageBuffer(
                engine,
                data.byteLength,
                BABYLON.Constants.BUFFER_CREATIONFLAG_VERTEX | BABYLON.Constants.BUFFER_CREATIONFLAG_WRITE
            );
            buffer.update(data);
            return buffer;
        }

        /**
         * Creates a compute shader from `source` and binds each [name, buffer]
         * pair of `buffers` at group 0, binding = its position in the list.
         */
        _createComputeShader(engine, source, buffers) {
            const bindingsMapping = {};
            buffers.forEach((entry, binding) => {
                bindingsMapping[entry[0]] = { group: 0, binding: binding };
            });

            const shader = new BABYLON.ComputeShader("compute", engine, {
                computeSource: source
            }, {
                bindingsMapping: bindingsMapping
            });
            for (const entry of buffers) {
                shader.setStorageBuffer(entry[0], entry[1]);
            }
            return shader;
        }

        /**
         * Runs one simulation step: velocity boundary values, Poisson
         * right-hand side, 10 pressure sweeps (boundary + interior), velocity
         * updates, node velocities, then particle velocity lookup and
         * position integration. Dispatch counts match the index guards in
         * the corresponding shaders.
         */
        update() {
            const interiorCount = (NX-1)*(NY-1);

            this.cs_gridBC.dispatch(velX_gridNum);
            this.cs_poisRHS.dispatch(interiorCount);
            for (let k = 0; k < 10; k++) {
                this.cs_presBC.dispatch(nGridX);
                this.cs_poisson.dispatch(interiorCount);
            }
            this.cs_updateXVel.dispatch((NX-2)*(NY-1));
            this.cs_updateYVel.dispatch((NX-1)*(NY-2));
            this.cs_gridVel.dispatch(gridNum);
            this.cs_getPVel.dispatch(this.numParticles);
            this.cs_integrate.dispatch(this.numParticles);
            // Re-attach the particle buffer as the per-instance position attribute.
            this.mesh.setVerticesBuffer(this.vertexBuffers, false);
        }
    }

    // GLSL sources for drawing the particles. scale0 and left0 are interpolated
    // into the vertex shader text when this object is created.
    const renderShader = {
    // Vertex shader: maps the per-instance particle position to
    // left0 + a_particlePos * scale0 and translates the quad corner a_pos by it.
    // The result is written to gl_Position directly; no camera or projection
    // matrix is applied here.
    point_vs:`
    attribute vec2 a_pos;
    attribute vec2 a_particlePos;

    const float scale0 = ${scale0};
    const vec2 left0 = vec2(${left0[0]}, ${left0[1]});
    
    void main() {
        vec2 position0 = vec2(
            left0.x + a_particlePos.x * scale0,
            left0.y + a_particlePos.y * scale0
        );
        mat4 scaleMTX = mat4(
            1.0,         0.0,         0.0, 0.0,
            0.0,         1.0,         0.0, 0.0,
            0.0,         0.0,         1.0, 0.0,
            position0.x, position0.y, 0.0, 1.0
        );

        gl_Position = scaleMTX * vec4(a_pos, 0.0, 1.0);
    }`,

    // Fragment shader: every particle quad is drawn opaque white.
    point_fs:`
    void main() {
        gl_FragColor = vec4(1.0, 1.0, 1.0, 1.0);
    }`
    };

    const computeShader = {
    // gridBC: writes boundary values into the x- and y-velocity arrays.
    // Bindings: 0 = gridXVelocity, 1 = gridYVelocity.
    // Each invocation derives a column/row pair for both arrays from one linear
    // index (guarded by velX_gridNum) and then writes the left, right, bottom
    // and top boundary entries for that column/row:
    //   - the wall-normal component is set to 0.0 at index 1 and at NX / NY,
    //     and the entry outside it copies the entry two steps inside;
    //   - the tangential component in the outer row/column is the negated
    //     neighbour, except the top row of the x-velocity, which is
    //     2 * vxwall - neighbour with vxwall = 1.0.
    // NOTE(review): many invocations write the same elements; the same index
    // guard is used for both arrays, which matches only while
    // velX_gridNum == velY_gridNum — confirm if NX and NY ever differ.
    gridBC:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let vxwall: f32 = 1.0;

        var index: u32 = GlobalInvocationID.x;
        if (index >= ${velX_gridNum}u) {
            return;
        }
        
        // velX
        var I: u32 = index % velX_nGridX;
        var j: u32 = index / velX_nGridX;
        // velY
        var i: u32 = index % velY_nGridX;
        var J: u32 = index / velY_nGridX;

        // 左
        // XVel
        gridXVelocity.gridXVel[ 1u + j * velX_nGridX ] = 0.0;
        gridXVelocity.gridXVel[ 0u + j * velX_nGridX ] =
            gridXVelocity.gridXVel[ 2u + j * velX_nGridX ];
        // YVel
        gridYVelocity.gridYVel[ 0u + J * velY_nGridX ] =
            -gridYVelocity.gridYVel[ 1u + J * velY_nGridX ];
        gridYVelocity.gridYVel[ 0u + (NY+1u) * velY_nGridX ] =
            -gridYVelocity.gridYVel[ 1u + (NY+1u) * velY_nGridX ];
        
        // 右
        // XVel
        gridXVelocity.gridXVel[ NX + j * velX_nGridX ] = 0.0;
        gridXVelocity.gridXVel[ (NX + 1u) + j * velX_nGridX ] =
            gridXVelocity.gridXVel[ (NX - 1u) + j * velX_nGridX ];
        // YVel
        gridYVelocity.gridYVel[ NX + J * velY_nGridX ] =
            -gridYVelocity.gridYVel[ (NX - 1u) + J * velY_nGridX ];
        gridYVelocity.gridYVel[ NX + (NY + 1u) * velY_nGridX ] =
            -gridYVelocity.gridYVel[ (NX - 1u) + (NY + 1u) * velY_nGridX ];

        // 下
        // YVel
        gridYVelocity.gridYVel[ i + 1u * velY_nGridX ] = 0.0;
        gridYVelocity.gridYVel[ i + 0u * velY_nGridX ] =
            gridYVelocity.gridYVel[ i + 2u * velY_nGridX ];
        // XVel
        gridXVelocity.gridXVel[ I + 0u * velX_nGridX ] =
            -gridXVelocity.gridXVel[ I + 1u * velX_nGridX ];
        gridXVelocity.gridXVel[ (NX + 1u) + 0u * velX_nGridX ] =
            -gridXVelocity.gridXVel[ NX + 0u * velX_nGridX ];
        
        // 上
        // YVel
        gridYVelocity.gridYVel[ i + NY * velY_nGridX ] = 0.0;
        gridYVelocity.gridYVel[ i + (NY + 1u) * velY_nGridX ] =
            gridYVelocity.gridYVel[ i + (NY - 1u) * velY_nGridX ];
        // XVel
        gridXVelocity.gridXVel[ I + NY * velX_nGridX ] =
            2.0 * vxwall - gridXVelocity.gridXVel[ I + (NY - 1u) * velX_nGridX ];
        gridXVelocity.gridXVel[ (NX + 1u) + NY * velX_nGridX ] =
            -gridXVelocity.gridXVel[ NX + NY * velX_nGridX ];
    }`,

    // poisRHS: computes the right-hand side of the pressure Poisson equation.
    // Bindings: 0 = gridXVelocity, 1 = gridYVelocity, 2 = poisRHS (output).
    // One invocation per (i, j) with i = 1..NX-1, j = 1..NY-1. From the two
    // velocity arrays it forms the differences udiv and vdiv and four-point
    // averages ua/ub/va/vb, then stores
    //   -udiv^2 - 2*(ua-ub)*(va-vb)/DX/DY - vdiv^2 + (udiv+vdiv)/deltaT
    // at prhs[i + j * (NX + 1)].
    poisRHS:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    [[block]] struct PoisRHSBuffer {
        prhs: array<f32>;
    };
    [[group(0), binding(2)]]
    var<storage> poisRHS: [[access(read_write)]] PoisRHSBuffer;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let deltaT: f32 = ${deltaT};

        var index: u32 = GlobalInvocationID.x;
        // dispatch (NX-1)*(NY-1)
        if (index >= (NX - 1u)*(NY - 1u)) {
            return;
        }
        
        // index i, j
        var ii: u32 = index % (NX - 1u);
        var jj: u32 = index / (NX - 1u);
        var i: u32 = ii + 1u;
        var j: u32 = jj + 1u;

        // velX
        var XIndexIJ: u32   = i + j * velX_nGridX;
        var XIndexIPJ: u32  = (i + 1u) + j * velX_nGridX;
        var XIndexIPJP: u32 = (i + 1u) + (j + 1u) * velX_nGridX;
        var XIndexIJP: u32  = i + (j + 1u) * velX_nGridX;
        var XIndexIJM: u32  = i + (j - 1u) * velX_nGridX;
        var XIndexIPJM: u32 = (i + 1u) + (j - 1u) * velX_nGridX;

        // velY
        var YIndexIJ: u32   = i + j * velY_nGridX;
        var YIndexIJP: u32  = i + (j + 1u) * velY_nGridX;
        var YIndexIPJ: u32  = (i + 1u) + j * velY_nGridX;
        var YIndexIPJP: u32 = (i + 1u) + (j + 1u) * velY_nGridX;
        var YIndexIMJ: u32  = (i - 1u) + j * velY_nGridX;
        var YIndexIMJP: u32 = (i - 1u) + (j + 1u) * velY_nGridX;

        var udiv: f32 = (gridXVelocity.gridXVel[XIndexIPJ] - gridXVelocity.gridXVel[XIndexIJ]) / DX;
        var vdiv: f32 = (gridYVelocity.gridYVel[YIndexIJP] - gridYVelocity.gridYVel[YIndexIJ]) / DY;

        var ua: f32 = (gridXVelocity.gridXVel[XIndexIJ] + gridXVelocity.gridXVel[XIndexIPJ]
                + gridXVelocity.gridXVel[XIndexIPJP] + gridXVelocity.gridXVel[XIndexIJP]) / 4.0;
        var ub: f32 = (gridXVelocity.gridXVel[XIndexIJ] + gridXVelocity.gridXVel[XIndexIPJ]
                + gridXVelocity.gridXVel[XIndexIPJM] + gridXVelocity.gridXVel[XIndexIJM]) / 4.0;
        var va: f32 = (gridYVelocity.gridYVel[YIndexIJ] + gridYVelocity.gridYVel[YIndexIJP]
                + gridYVelocity.gridYVel[YIndexIPJP] + gridYVelocity.gridYVel[YIndexIPJ]) / 4.0;
        var vb: f32 = (gridYVelocity.gridYVel[YIndexIJ] + gridYVelocity.gridYVel[YIndexIJP]
                + gridYVelocity.gridYVel[YIndexIMJP] + gridYVelocity.gridYVel[YIndexIMJ]) / 4.0;

        poisRHS.prhs[i + j * (NX + 1u)] = -udiv*udiv - 2.0*(ua - ub)*(va - vb)/DX/DY
            -vdiv*vdiv + (udiv + vdiv) / deltaT;
    }`,

    // presBC: sets the pressure values on the four outer edges of the pressure array.
    // Bindings: 0 = gridXVelocity, 1 = gridYVelocity, 2 = gridPres.
    // One invocation per index 0..nGridX-1. The same index is used as row j for
    // the left/right columns (0 and NX) and as column i for the bottom/top rows
    // (0 and NY). Each edge value is the adjacent inner pressure plus or minus
    // (1/Re) * 2 * velocity / spacing, using a nearby wall-normal velocity entry.
    // NOTE(review): using one index range for both rows and columns covers
    // every edge entry only while NX == NY — confirm before changing either.
    presBC:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    [[block]] struct GridPres {
        pres: array<f32>;
    };
    [[group(0), binding(2)]]
    var<storage> gridPres: [[access(read_write)]] GridPres;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let nGridX: u32 = ${nGridX}u;
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let Re: f32 = f32(${Re});

        var index: u32 = GlobalInvocationID.x;
        // dispatch index i or j , nGridX(= NX + 1 )
        if (index >= nGridX) {
            return;
        }
        
        // 左右 index = j
        var j: u32 = index;
        gridPres.pres[0u + j * nGridX] = gridPres.pres[1u + j * nGridX]
            - 1.0/Re * 2.0 * gridXVelocity.gridXVel[2u + j * velX_nGridX] / DX;

        gridPres.pres[NX + j * nGridX] = gridPres.pres[(NX - 1u) + j * nGridX]
            + 1.0/Re * 2.0 * gridXVelocity.gridXVel[(NX - 1u) + j * velX_nGridX] / DX;

        // 上下 index = i
        var i: u32 = index;
        gridPres.pres[i + 0u * nGridX] = gridPres.pres[i + 1u * nGridX]
            - 1.0/Re * 2.0 * gridYVelocity.gridYVel[i + 2u * velY_nGridX] / DY;
        gridPres.pres[i + NY * nGridX] = gridPres.pres[i + (NY - 1u) * nGridX]
            + 1.0/Re * 2.0 * gridYVelocity.gridYVel[i + (NY - 1u) * velY_nGridX] / DY;
    }`,

    // poisson: one relaxation sweep of the pressure field.
    // Bindings: 0 = poisRHS (input), 1 = gridPres (read and written in place).
    // One invocation per (i, j) with i = 1..NX-1, j = 1..NY-1; each replaces
    //   p[i,j] = c1*(p[i+1,j] + p[i-1,j]) + c2*(p[i,j+1] + p[i,j-1]) - c3*rhs[i,j]
    // where c1, c2, c3 are the JavaScript constants C1, C2, C3.
    // NOTE(review): neighbours are read from the same buffer that other
    // invocations are writing, so a sweep may mix old and new values.
    poisson:`
    [[block]] struct PoisRHSBuffer {
        prhs: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> poisRHS: [[access(read_write)]] PoisRHSBuffer;

    [[block]] struct GridPres {
        pres: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridPres: [[access(read_write)]] GridPres;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let nGridX: u32 = ${nGridX}u;
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let Re: f32 = f32(${Re});
        let c1: f32 = ${C1};
        let c2: f32 = ${C2};
        let c3: f32 = ${C3};

        var index: u32 = GlobalInvocationID.x;
        // dispatch (NX-1)*(NY-1)
        if (index >= (NX - 1u)*(NY - 1u)) {
            return;
        }
        
        var ii: u32 = index % (NX - 1u);
        var jj: u32 = index / (NX - 1u);
        var i: u32 = ii + 1u;
        var j: u32 = jj + 1u;

        var pp: f32 = c1 * (gridPres.pres[(i + 1u) + j * nGridX] + gridPres.pres[(i - 1u) + j * nGridX])
            + c2 * (gridPres.pres[i + (j + 1u) * nGridX] + gridPres.pres[i + (j - 1u) * nGridX])
            - c3 * poisRHS.prhs[i + j * nGridX];

        gridPres.pres[i + j * nGridX] = pp;
    }`,

    // updateXVel: advances the x-velocity by one time step, in place.
    // Bindings: 0 = gridXVelocity (read/write), 1 = gridYVelocity, 2 = gridPres.
    // One invocation per (i, j) with i = 2..NX-1, j = 1..NY-1. It computes
    //   vmid: four-point average of the surrounding y-velocities,
    //   uad:  advection term from central differences of the x-velocity,
    //   udif: second differences of the x-velocity in x and y,
    // and then adds deltaT * (-uad - (p[i,j] - p[i-1,j])/DX + udif/Re).
    // NOTE(review): neighbouring x-velocities are read from the buffer that
    // other invocations update in the same dispatch.
    updateXVel:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    [[block]] struct GridPres {
        pres: array<f32>;
    };
    [[group(0), binding(2)]]
    var<storage> gridPres: [[access(read_write)]] GridPres;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let nGridX: u32 = ${nGridX}u;
        let deltaT: f32 = ${deltaT};
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let Re: f32 = f32(${Re});

        var index: u32 = GlobalInvocationID.x;
        // dispatch index (NX-2)*(NY-1)
        // XVel: i=2 ~ NX-1, j=1 ~ NY-1
        if (index >= (NX - 2u)*(NY - 1u)) {
            return;
        }
        
        // XVel
        var i: u32 = index % (NX - 2u) + 2u;
        var j: u32 = index / (NX - 2u) + 1u;

        var YIndexIJ: u32   = i + j * velY_nGridX;
        var YIndexIJP: u32  = i + (j + 1u) * velY_nGridX;
        var YIndexIMJ: u32  = (i - 1u) + j * velY_nGridX;
        var YIndexIMJP: u32 = (i - 1u) + (j + 1u) * velY_nGridX;
        //
        var XIndexIJ: u32  = i + j * velX_nGridX;
        var XIndexIPJ: u32 = (i + 1u) + j * velX_nGridX;
        var XIndexIMJ: u32 = (i - 1u) + j * velX_nGridX;
        var XIndexIJP: u32 = i + (j + 1u) * velX_nGridX;
        var XIndexIJM: u32 = i + (j - 1u) * velX_nGridX;

        var vmid: f32 = (gridYVelocity.gridYVel[YIndexIJ] + gridYVelocity.gridYVel[YIndexIJP]
            + gridYVelocity.gridYVel[YIndexIMJP] + gridYVelocity.gridYVel[YIndexIMJ]) / 4.0;
        var uad: f32 = gridXVelocity.gridXVel[XIndexIJ] * (gridXVelocity.gridXVel[XIndexIPJ] - gridXVelocity.gridXVel[XIndexIMJ])/2.0/DX
            + vmid * (gridXVelocity.gridXVel[XIndexIJP] - gridXVelocity.gridXVel[XIndexIJM])/2.0/DY;
        var udif: f32 = (gridXVelocity.gridXVel[XIndexIPJ] - 2.0*gridXVelocity.gridXVel[XIndexIJ] + gridXVelocity.gridXVel[XIndexIMJ])/DX/DX
            + (gridXVelocity.gridXVel[XIndexIJP] - 2.0*gridXVelocity.gridXVel[XIndexIJ] + gridXVelocity.gridXVel[XIndexIJM])/DY/DY;

        gridXVelocity.gridXVel[XIndexIJ] = gridXVelocity.gridXVel[XIndexIJ]
            + deltaT * (-uad
            - (gridPres.pres[i + j * nGridX] - gridPres.pres[(i - 1u) + j * nGridX])/DX
            + 1.0/Re*udif);
    }`,

    // updateYVel: advances the y-velocity by one time step, in place.
    // Bindings: 0 = gridXVelocity, 1 = gridYVelocity (read/write), 2 = gridPres.
    // One invocation per (i, j) with i = 1..NX-1, j = 2..NY-1. It computes
    //   umid: four-point average of the surrounding x-velocities,
    //   vad:  advection term from central differences of the y-velocity,
    //   vdif: second differences of the y-velocity in x and y,
    // and then adds deltaT * (-vad - (p[i,j] - p[i,j-1])/DY + vdif/Re).
    // NOTE(review): this runs after updateXVel in FluidSim.update(), so umid is
    // built from x-velocities that were already advanced in the same step.
    updateYVel:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    [[block]] struct GridPres {
        pres: array<f32>;
    };
    [[group(0), binding(2)]]
    var<storage> gridPres: [[access(read_write)]] GridPres;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let nGridX: u32 = ${nGridX}u;
        let deltaT: f32 = ${deltaT};
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let Re: f32 = f32(${Re});

        var index: u32 = GlobalInvocationID.x;
        // dispatch index (NX-1)*(NY-2)
        // YVel: i=1 ~ NX-1, j=2 ~ NY-1
        if (index >= (NX - 1u)*(NY - 2u)) {
            return;
        }
        
        // YVel
        var i: u32 = index % (NX - 1u) + 1u;
        var j: u32 = index / (NX - 1u) + 2u;

        var XIndexIJ: u32   = i + j * velX_nGridX;
        var XIndexIPJ: u32  = (i + 1u) + j * velX_nGridX;
        var XIndexIJM: u32  = i + (j - 1u) * velX_nGridX;
        var XIndexIPJM: u32 = (i + 1u) + (j - 1u) * velX_nGridX;

        var YIndexIJ: u32  = i + j * velY_nGridX;
        var YIndexIPJ: u32 = (i + 1u) + j * velY_nGridX;
        var YIndexIMJ: u32 = (i - 1u) + j * velY_nGridX;
        var YIndexIJP: u32 = i + (j + 1u) * velY_nGridX;
        var YIndexIJM: u32 = i + (j - 1u) * velY_nGridX;

        let umid: f32 = (gridXVelocity.gridXVel[XIndexIJ] + gridXVelocity.gridXVel[XIndexIPJ]
            + gridXVelocity.gridXVel[XIndexIPJM] + gridXVelocity.gridXVel[XIndexIJM]) / 4.0;
        let vad: f32 = umid * (gridYVelocity.gridYVel[YIndexIPJ] - gridYVelocity.gridYVel[YIndexIMJ])/2.0/DX
            + gridYVelocity.gridYVel[YIndexIJ]*(gridYVelocity.gridYVel[YIndexIJP] - gridYVelocity.gridYVel[YIndexIJM])/2.0/DY;
        let vdif: f32 = (gridYVelocity.gridYVel[YIndexIPJ] - 2.0*gridYVelocity.gridYVel[YIndexIJ] + gridYVelocity.gridYVel[YIndexIMJ])/DX/DX
                + (gridYVelocity.gridYVel[YIndexIJP] - 2.0*gridYVelocity.gridYVel[YIndexIJ] + gridYVelocity.gridYVel[YIndexIJM])/DY/DY;
        
        gridYVelocity.gridYVel[YIndexIJ] = gridYVelocity.gridYVel[YIndexIJ]
            + deltaT * ( -vad
            - (gridPres.pres[i + j * nGridX] - gridPres.pres[i + (j - 1u) * nGridX])/DY
            + 1.0/Re*vdif);
    }`,

    // gridVelocity: combines the two velocity arrays into one vec2 per entry.
    // Bindings: 0 = gridXVelocity, 1 = gridYVelocity, 2 = gridVelocity (output).
    // One invocation per (i, j) over the full (NX+1) x (NY+1) array. The x
    // component is the average of x-velocity entries i and i+1 in row j; the
    // y component is the average of y-velocity entries j and j+1 in column i.
    gridVelocity:`
    [[block]] struct GridXVelBuffer {
        gridXVel: array<f32>;
    };
    [[group(0), binding(0)]]
    var<storage> gridXVelocity: [[access(read_write)]] GridXVelBuffer;
    
    [[block]] struct GridYVelBuffer {
        gridYVel: array<f32>;
    };
    [[group(0), binding(1)]]
    var<storage> gridYVelocity: [[access(read_write)]] GridYVelBuffer;
    
    struct GridVelocity {
        gridVel: vec2<f32>;
    };
    [[block]] struct GridVelocities {
        gridVelocities: array<GridVelocity, ${gridNum}>;
    };
    [[group(0), binding(2)]]
    var<storage> gridVelocity: [[access(read_write)]] GridVelocities;
    
    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let velX_nGridX: u32 = ${velX_nGridX}u;
        let velY_nGridX: u32 = ${velY_nGridX}u;
        let nGridX: u32 = ${nGridX}u;
        let deltaT: f32 = ${deltaT};
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let Re: f32 = f32(${Re});

        var index: u32 = GlobalInvocationID.x;
        // dispatch index (NX+1)*(NY+1) gridNum
        if (index >= ${gridNum}u) {
            return;
        }
        
        var i: u32 = index % nGridX;
        var j: u32 = index / nGridX;

        var velXIndexIJ: u32  = i + j * velX_nGridX;
        var velXIndexIPJ: u32 = (i + 1u) + j * velX_nGridX;
        var velYIndexIJ: u32  = i + j * velY_nGridX;
        var velYIndexIJP: u32 = i + (j + 1u) * velY_nGridX;

        gridVelocity.gridVelocities[i + j * nGridX ].gridVel.x =
            (gridXVelocity.gridXVel[velXIndexIJ] + gridXVelocity.gridXVel[velXIndexIPJ]) / 2.0;
        gridVelocity.gridVelocities[i + j * nGridX ].gridVel.y =
            (gridYVelocity.gridYVel[velYIndexIJ] + gridYVelocity.gridYVel[velYIndexIJP]) / 2.0;
    }`,

    getParticleVelocity:`
    struct Particle {
        pos: vec2<f32>;
        vel: vec2<f32>;
    };
    [[block]] struct Particles {
        particles: array<Particle, ${numParticles}>;
    };
    [[group(0), binding(0)]]
    var<storage> particle: [[access(read_write)]] Particles;

    struct GridVelocity {
        gridVel: vec2<f32>;
    };
    [[block]] struct GridVelocities {
        gridVelocities: array<GridVelocity, ${gridNum}>;
    };
    [[group(0), binding(1)]]
    var<storage> gridVelocity: [[access(read_write)]] GridVelocities;

    fn getParticleVelocity(position: vec2<f32>) -> vec2<f32> {
        let NX: u32 = ${NX}u;
        let NY: u32 = ${NY}u;
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};
        let nGridX: u32 = ${nGridX}u;

        var I: u32;
        var J: u32;

        I = 0u;
        J = 0u;

        for (var i: u32 = 0u; i < NX; i = i + 1u) {
            if (f32(i) * DX < position.x && f32(i+1u) * DX > position.x) { I = i; }
        }
        for (var j: u32 = 0u; j < NY; j = j +1u) {
            if (f32(j) * DY < position.y && f32(j+1u) * DY > position.y) { J = j; }
        }

        var a: f32 =  position.x / DX - f32(I);
        var b: f32 =  position.y / DY - f32(J);
        
        var gIndexIJ00: u32 = I + J * nGridX;
        var gIndexIJ10: u32 = (I + 1u) + J * nGridX;
        var gIndexIJ01: u32 = I + (J + 1u) * nGridX;
        var gIndexIJ11: u32 = (I + 1u) + (J + 1u) * nGridX;

        var velXIJ00: f32 = gridVelocity.gridVelocities[ gIndexIJ00 ].gridVel.x;
        var velXIJ10: f32 = gridVelocity.gridVelocities[ gIndexIJ10 ].gridVel.x;
        var velXIJ01: f32 = gridVelocity.gridVelocities[ gIndexIJ01 ].gridVel.x;
        var velXIJ11: f32 = gridVelocity.gridVelocities[ gIndexIJ11 ].gridVel.x;
        var velYIJ00: f32 = gridVelocity.gridVelocities[ gIndexIJ00 ].gridVel.y;
        var velYIJ10: f32 = gridVelocity.gridVelocities[ gIndexIJ10 ].gridVel.y;
        var velYIJ01: f32 = gridVelocity.gridVelocities[ gIndexIJ01 ].gridVel.y;
        var velYIJ11: f32 = gridVelocity.gridVelocities[ gIndexIJ11 ].gridVel.y;

        var vel_x: f32 = (1.0 - b) * ( (1.0 - a) * velXIJ00 + a * velXIJ10 )
                        + b * ( (1.0 - a) * velXIJ01 + a * velXIJ11 );
        var vel_y: f32 = (1.0 - b) * ( (1.0 - a) * velYIJ00 + a * velYIJ10 )
                        + b * ( (1.0 - a) * velYIJ01 + a * velYIJ11 );

        var vel: vec2<f32> = vec2<f32>(
            vel_x,
            vel_y
        );
        return vel;
    }

    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        var index: u32 = GlobalInvocationID.x;
        if (index >= ${numParticles}u) {
            return;
        }

        var particlePosition: vec2<f32> = particle.particles[index].pos;
        particle.particles[index].vel = getParticleVelocity(particlePosition);
    }`,
    
    // integrate: moves each particle by its stored velocity.
    // Binding: 0 = particle (pos read and written, vel read).
    // One invocation per particle: pos += vel * deltaT. If x leaves [0, 1] it
    // is reset to DX/8 inside the nearest side (1 - DX/8 or DX/8).
    // NOTE(review): only x is limited; y is written back unchanged even when
    // it leaves [0, 1] — confirm that this is intended.
    integrate:`
    struct Particle {
        pos : vec2<f32>;
        vel : vec2<f32>;
    };
    [[block]] struct Particles {
        particles : array<Particle, ${numParticles}>;
    };
    [[group(0), binding(0)]]
    var<storage> particle : [[access(read_write)]] Particles;

    [[stage(compute)]]
    fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
        let deltaT: f32 = ${deltaT};
        let DX: f32 = ${DX};
        let DY: f32 = ${DY};

        var index : u32 = GlobalInvocationID.x;
        if (index >= ${numParticles}u) {
            return;
        }

        var vPos : vec2<f32> = particle.particles[index].pos;
        var vVel : vec2<f32> = particle.particles[index].vel;

        vPos = vPos + vVel * deltaT;

        if (vPos.x > 1.0) {
            vPos.x = 1.0 - DX/8.0;
        }
        if (vPos.x < 0.0) {
            vPos.x = 0.0 + DX/8.0;
        }

        particle.particles[index].pos = vPos;
    }`
    };

    // Build the scene once, then have the engine redraw it on every frame.
    const scene = createScene();
    const renderFrame = () => {
        scene.render();
    };
    engine.runRenderLoop(renderFrame);
};

// Start the demo. init() is async, so report a failed start-up (for example
// when the WebGPU engine cannot be initialised) instead of leaving an
// unhandled promise rejection.
init().catch((err) => {
    console.error("Babylon.js WebGPU Fluid MAC: initialisation failed", err);
});
</script>
</body>
</html>