import { DataTexture, FloatType, RedFormat, ShaderMaterial } from "three";
import pointVert from "./shaders/point.vert";
import { COLOR_RAMP } from "./utils";

/**
 * Numeric identifiers for the available point colouring modes.
 *
 * The selected value is what gets written into the material's `colorMode`
 * shader uniform. The keys mirror those of `COLOR_MODE_NAMES`.
 */
export const COLOR_MODE = {
    "rgb": 1,
    "classification": 2,
    "intensity": 3,
    "metaIndex": 4,
    "date": 5,
} as const;

/**
 * Display labels keyed by the same colour-mode names as `COLOR_MODE`
 * (presumably shown in a UI mode selector — confirm with callers).
 */
export const COLOR_MODE_NAMES = {
    "rgb": "RGB",
    "classification": "Class",
    "intensity": "Intensity",
    "metaIndex": "Source",
    "date": "Date",
} as const;

/** Number of texels (index slots) in the shared index → year lookup table. */
const YEAR_LOOKUP_SIZE = 2048;

/**
 * CPU-side backing store for the index → year lookup. It is shared by every
 * CustomPointMaterial instance and rewritten in place by `updateYearLookup`.
 */
const yearLookupData = new Float32Array(YEAR_LOOKUP_SIZE);

/** YEAR_LOOKUP_SIZE × 1 single-channel float texture wrapping `yearLookupData`. */
const yearLookupTexture = new DataTexture(yearLookupData, YEAR_LOOKUP_SIZE, 1, RedFormat, FloatType);

/**
 * Point-cloud material that draws each point as a circular sprite.
 *
 * The uniforms configured below are not referenced by the fragment shader, so
 * they are presumably consumed by the vertex shader (`pointVert`, not visible
 * in this file). The fragment shader discards fragments flagged via
 * `vDiscard`, clips the square point sprite to a circle and writes the
 * interpolated `vColor` as an opaque colour.
 */
export class CustomPointMaterial extends ShaderMaterial {
    constructor() {
        super({
            glslVersion: "300 es",
            uniforms: {
                pointSize: { value: 5.0 },
                colorMode: { value: COLOR_MODE.rgb },
                // All bits set; presumably one bit per classification code — verify against pointVert.
                classificationMask: { value: 0xffffffff },
                // Each COLOR_RAMP component is divided by 255 (assumes 0-255 input) and flattened into one array.
                colorRamp: { value: COLOR_RAMP.flatMap((c) => c.map((v) => v / 255)) },
                colorRampLength: { value: COLOR_RAMP.length },
                // Shared across all instances; refreshed via CustomPointMaterial.updateYearLookup.
                indexToYearLookup: { value: yearLookupTexture },
            },
            vertexShader: pointVert,
            fragmentShader: `
				out vec4 FragColor;

				in vec3 vColor;
				flat in int vDiscard;

				void main() {

					if (vDiscard == 1) {
						discard;
					}

					// make the points circular
					float u = 2.0 * gl_PointCoord.x - 1.0;
					float v = 2.0 * gl_PointCoord.y - 1.0;
					if (u * u + v * v > 1.0) {
						discard;
					}

					FragColor = vec4(vColor, 1.0);
				}
			`,
        });
    }

    /**
     * Replace the shared index → year lookup used by every CustomPointMaterial.
     *
     * `indexToYear[i]` is written to texel `i`. Entries beyond the table's
     * capacity (YEAR_LOOKUP_SIZE) are dropped with a console warning. Texels
     * not covered by `indexToYear` are reset to 0, so values from a previous,
     * longer mapping do not linger.
     *
     * @param indexToYear - Year value for each index, starting at index 0.
     */
    static updateYearLookup(indexToYear: readonly number[]) {
        if (indexToYear.length > YEAR_LOOKUP_SIZE) {
            console.warn(
                `CustomPointMaterial.updateYearLookup: ${indexToYear.length} entries exceed the lookup capacity of ${YEAR_LOOKUP_SIZE}; the excess is ignored.`,
            );
        }
        const count = Math.min(indexToYear.length, YEAR_LOOKUP_SIZE);

        // Update the lookup data in place: the texture was constructed around this same buffer.
        yearLookupData.set(indexToYear.slice(0, count));
        // Clear stale entries left over from an earlier, longer mapping.
        yearLookupData.fill(0, count);

        // Ask three.js to re-upload the texture contents the next time it is used.
        yearLookupTexture.needsUpdate = true;
    }

    /**
     * Select how points are coloured.
     *
     * @param colorMode - Key of `COLOR_MODE`; its numeric value is written to the `colorMode` uniform.
     */
    setColorMode(colorMode: keyof typeof COLOR_MODE) {
        this.uniforms.colorMode.value = COLOR_MODE[colorMode];
    }

    /**
     * Set the `classificationMask` uniform.
     *
     * @param mask - Bit mask forwarded to the shader unchanged (presumably one bit per classification code).
     */
    setClassificationMask(mask: number) {
        this.uniforms.classificationMask.value = mask;
    }

    /**
     * Grow or shrink the `pointSize` uniform by `delta`, never going below 1.0.
     *
     * @param delta - Amount to add to the current point size (negative to shrink).
     */
    adjustPointSize(delta: number) {
        this.uniforms.pointSize.value = Math.max(1.0, this.uniforms.pointSize.value + delta);
    }
}
