import { DataTexture, FloatType, RedFormat, ShaderMaterial } from "three";
import pointVert from "./shaders/point.vert";
import { COLOR_RAMP } from "./utils";

/**
 * Available point coloring modes.
 * `key` is the numeric id stored in the material's `colorMode` uniform;
 * `name` is a human-readable label.
 */
export const COLOR_MODE = {
    rgb: { key: 1, name: "RGB" },
    classification: { key: 2, name: "Classification" },
    intensity: { key: 3, name: "Intensity" },
    sourceID: { key: 4, name: "Point Source ID" },
    metaIndex: { key: 5, name: "File" },
    date: { key: 6, name: "Date" },
} as const;

/** Name of a coloring mode, i.e. one of the property names of {@link COLOR_MODE}. */
export type ColorMode = keyof typeof COLOR_MODE;

/**
 * Reverse lookup: map a numeric color-mode id back to its {@link ColorMode} name.
 *
 * @param id the `key` of one of the {@link COLOR_MODE} entries
 * @returns the name of the matching mode
 * @throws Error if no mode uses the given id
 */
export function getColorModeById(id: number): ColorMode {
    const hit = Object.entries(COLOR_MODE).find(([, mode]) => mode.key === id);
    if (hit === undefined) {
        throw new Error(`Unknown color mode ID: ${id}`);
    }
    // Object.entries widens keys to string; COLOR_MODE is an exact literal object,
    // so every key is a ColorMode.
    return hit[0] as ColorMode;
}

/** Width and height, in texels, of the square index-to-year lookup texture. */
const LOOKUP_TEXTURE_SIZE = 256;

/**
 * CPU-side backing store of the lookup texture: one float per texel, laid out row by row
 * (LOOKUP_TEXTURE_SIZE² entries in total). Module-level, so every material shares it.
 */
const yearLookupData = new Float32Array(LOOKUP_TEXTURE_SIZE ** 2);

/** Single-channel float texture that exposes `yearLookupData` to the shader. */
const yearLookupTexture = new DataTexture(yearLookupData, LOOKUP_TEXTURE_SIZE, LOOKUP_TEXTURE_SIZE, RedFormat, FloatType);

/**
 * Point-cloud material: a GLSL 3.00 ES ShaderMaterial combining the project's point vertex
 * shader with a fragment shader that discards points flagged by the vertex stage and
 * trims the remaining point sprites to circles.
 *
 * Render settings live in uniforms and are changed through the setter methods below.
 * The index-to-year lookup texture is module-level, so it is shared by all instances and
 * is filled through the static {@link CustomPointMaterial.updateYearLookup}.
 */
export class CustomPointMaterial extends ShaderMaterial {
    constructor() {
        super({
            glslVersion: "300 es",
            uniforms: {
                pointSize: { value: 5.0 },
                // setColorMode() stores the numeric key of a COLOR_MODE entry, so the initial
                // value must be the key too, not the whole { key, name } entry object.
                colorMode: { value: COLOR_MODE.rgb.key },
                // All 32 bits set; presumably "every classification visible" — semantics live in point.vert.
                classificationMask: { value: 0xffffffff },
                // Flattened ramp, every component scaled from 0..255 to 0..1.
                colorRamp: { value: COLOR_RAMP.flatMap((c) => c.map((v) => v / 255)) },
                colorRampLength: { value: COLOR_RAMP.length },
                indexToYearLookup: { value: yearLookupTexture },
                intensityRange: { value: [0, 1] },
                zRange: { value: [-Infinity, Infinity] },
                // Fixed capacity; NOTE(review): presumably must match the array size declared in point.vert.
                filteredPoseId: { value: new Uint32Array(256) },
                filteredPoseIdLength: { value: 0 },
            },
            vertexShader: pointVert,
            fragmentShader: `
				out vec4 FragColor;

				in vec3 vColor;
				flat in int vDiscard;

				void main() {

					if (vDiscard == 1) {
						discard;
					}

					// make the points circular
					float u = 2.0 * gl_PointCoord.x - 1.0;
					float v = 2.0 * gl_PointCoord.y - 1.0;
					if (u * u + v * v > 1.0) {
						discard;
					}

					FragColor = vec4(vColor, 1.0);
				}
			`,
        });
    }

    /**
     * Fill the shared index-to-year lookup texture and flag it for re-upload.
     *
     * Entry `i` belongs at texel `(i % LOOKUP_TEXTURE_SIZE, floor(i / LOOKUP_TEXTURE_SIZE))`.
     * In the row-major backing array that is simply flat index `i`, so the list is copied
     * verbatim. Texels past `indexToYear.length` keep their previous values.
     *
     * @param indexToYear year value for each file index
     * @throws Error if there are more entries than the texture has texels
     */
    static updateYearLookup(indexToYear: readonly number[]) {
        // The texture holds exactly LOOKUP_TEXTURE_SIZE² entries, so a list of exactly that
        // length still fits; only strictly longer lists are rejected.
        if (indexToYear.length > yearLookupData.length) {
            // TODO: use a larger lookup texture (it already spans all rows) if more files are needed
            throw new Error(`too many files (${indexToYear.length}) for the index-to-year lookup table texture`);
        }

        yearLookupData.set(indexToYear);
        yearLookupTexture.needsUpdate = true;
    }

    /**
     * Restrict rendering to a z interval.
     *
     * NOTE: The range must be given in raw point coordinates, i.e. without any
     * possible offset applied.
     *
     * @param zRange [minZ, maxZ] range of z values to render, in raw point coordinates without any offset.
     */
    setVisibleZRange(zRange: number[]) {
        this.uniforms.zRange.value = zRange;
    }

    /** Switch the coloring mode by storing the mode's numeric key in the `colorMode` uniform. */
    setColorMode(colorMode: ColorMode) {
        this.uniforms.colorMode.value = COLOR_MODE[colorMode].key;
    }

    /** Set the classification bitmask uniform (how bits are interpreted is up to the vertex shader). */
    setClassificationMask(mask: number) {
        this.uniforms.classificationMask.value = mask;
    }

    /** Change the point size by `delta`, never letting it drop below 1.0. */
    adjustPointSize(delta: number) {
        this.uniforms.pointSize.value = Math.max(1.0, this.uniforms.pointSize.value + delta);
    }

    /**
     * Set the intensity range used by the shader.
     *
     * @param min lower bound in raw 16-bit intensity units (0..65535)
     * @param max upper bound in raw 16-bit intensity units (0..65535)
     * Both bounds are stored normalized to 0..1 by dividing by 65535.
     */
    setIntensityRange(min: number, max: number) {
        const u16_max = 65535;
        this.uniforms.intensityRange.value = [min / u16_max, max / u16_max];
    }

    /**
     * Upload the set of pose IDs used for filtering (applied by the vertex shader).
     *
     * @param ids pose IDs; iteration order determines each ID's slot in the uniform array
     * @throws Error if `ids` holds more entries than the fixed-size uniform array; the
     *   uniforms are left unchanged in that case
     */
    setFilteredPoseIDs(ids: ReadonlySet<number>) {
        const slots: Uint32Array = this.uniforms.filteredPoseId.value;
        // Typed arrays silently ignore out-of-range writes, so without this check the length
        // uniform would claim more IDs than were actually stored.
        if (ids.size > slots.length) {
            throw new Error(`too many filtered pose IDs (${ids.size}); at most ${slots.length} are supported`);
        }

        let i = 0;
        for (const id of ids) {
            slots[i++] = id;
        }
        this.uniforms.filteredPoseIdLength.value = ids.size;
    }
}
