import { useTexture } from "@react-three/drei";
import React, { useMemo, useEffect } from "react";
import * as THREE from "three";
import { DatasetGeometry } from "~dataset/DatasetGeometry";

import pointTextureUrl from "~/assets/textures/sphere.png";
import { useThree } from "@react-three/fiber";

/**
 * Raw GLSL ES 3.00 material that draws dataset vertices as textured point
 * sprites.
 *
 * Because this is a `RawShaderMaterial`, three.js prepends nothing to the
 * shaders. They declare every uniform/attribute they use themselves:
 * `projectionMatrix`, `modelViewMatrix` and `position`.
 *
 * Vertex stage:
 * - Each vertex's float `color` attribute is used as the U coordinate into
 *   the `colors` texture, sampled at V = 0.5.
 * - The point size comes from the `pointSize` uniform.
 *
 * Fragment stage:
 * - The red channel of `dotTexture` at `gl_PointCoord` acts as the sprite
 *   mask. Fragments whose mask is below 0.1 are discarded.
 * - When `edgeSize` is greater than 0.5, the colour is blended towards white
 *   near the sprite rim. The width of that band depends on
 *   `edgeSize / pointSize`.
 *
 * NOTE(review): the `resolution` uniform and the `id` attribute are declared
 * in GLSL but never read. `resolution` also has no entry in `uniforms`.
 * NOTE(review): on the edge path the RGB is rebuilt from `vColor` instead of
 * `vColor * alpha`, so the mask no longer darkens the colour there. Confirm
 * this is intended.
 * NOTE(review): the material is opaque (`transparent: false`). The alpha
 * written to `fragColor.a` therefore presumably doesn't drive blending, and
 * the sprite cut-out relies on `discard`.
 */
class PointsMaterial extends THREE.RawShaderMaterial {
  constructor() {
    super({
      transparent: false,
      // Standard less-than depth test.
      depthFunc: THREE.LessDepth,
      vertexShader: /* glsl */ `#version 300 es
        uniform mat4 projectionMatrix;
        uniform mat4 modelViewMatrix;
        uniform vec2 resolution;
        uniform float pointSize;
        uniform sampler2D colors;
        in vec3 position;
        in float color;
        in uint id;
        out highp vec3 vColor;
        out float vPointSize;

        void main() {
          gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
          gl_PointSize = pointSize;
          vPointSize = pointSize;
          vColor = texture(colors, vec2(color, 0.5)).rgb;
        }
      `,
      fragmentShader: /* glsl */ `#version 300 es
        precision highp float;
        uniform vec2 resolution;
        uniform float edgeSize;
        uniform sampler2D dotTexture;
        in vec3 vColor;
        in float vPointSize;
        out vec4 fragColor;

        void main() {
          float alpha = texture(dotTexture, gl_PointCoord).r;
          fragColor = vec4(vColor * alpha, 1.0);
          float pixelSize = 1.0 / vPointSize;
          float r = length(gl_PointCoord - vec2(0.5));

          if (alpha < 0.1) discard;
          fragColor.a = alpha;

          float edgeAlpha = smoothstep(0.5 - edgeSize * pixelSize, 0.5, 1.0 - edgeSize * 2.0 * pixelSize - r);

          vec3 edgeColor = vec3(1.0);

          if (edgeSize > 0.5) {
            fragColor.rgb = mix(vColor, edgeColor, 1.0 - edgeAlpha);
          }
        }
      `,
      // Initial uniform values. Owners assign the textures (null until then)
      // and may change pointSize / edgeSize. With edgeSize 0 the shader
      // skips the edge branch.
      uniforms: {
        pointSize: new THREE.Uniform(10),
        colors: new THREE.Uniform(null),
        dotTexture: new THREE.Uniform(null),
        edgeSize: new THREE.Uniform(0),
      },
    });
  }
}

/**
 * Renders one layer of the point cloud.
 *
 * The layer is a `<points>` object that reuses the vertex attributes of
 * `geometry` by reference. It can optionally restrict drawing to the vertex
 * indices listed in `selection`.
 *
 * @param geometry - Source geometry whose attribute objects are shared, not
 *   copied.
 * @param pointSize - Value written to the material's `pointSize` uniform.
 * @param edgeSize - Value written to the `edgeSize` uniform. The shader only
 *   draws a rim when this is greater than 0.5.
 * @param colorsTexture - Lookup texture sampled for each vertex colour.
 *   `undefined` is stored as `null`.
 * @param dotTexture - Sprite texture whose red channel masks each point.
 *   `undefined` is stored as `null`.
 * @param selection - Vertex indices to draw, installed as the geometry index.
 *   `undefined` removes the index so every vertex is drawn. An empty array
 *   draws nothing.
 * @param depthFunc - Depth test used for this layer.
 * @param renderOrder - three.js render order of the `<points>` object.
 */
const PointsViewInstance = ({
  geometry,
  pointSize = 10,
  edgeSize = 0,
  colorsTexture,
  dotTexture,
  selection,
  depthFunc = THREE.LessDepth,
  renderOrder = 0,
}: {
  geometry: DatasetGeometry;
  pointSize?: number;
  edgeSize?: number;
  colorsTexture?: THREE.Texture;
  dotTexture?: THREE.Texture;
  selection?: number[];
  depthFunc?: THREE.DepthModes;
  renderOrder?: number;
}) => {
  const { invalidate } = useThree();

  // A private BufferGeometry per source geometry. Each layer can carry its
  // own index while reusing the same attribute objects, and therefore the
  // same GPU buffers.
  const instanceGeometry = useMemo(() => {
    const layerGeometry = new THREE.BufferGeometry();
    layerGeometry.attributes = geometry.attributes;
    return layerGeometry;
  }, [geometry]);

  const material = useMemo(() => new PointsMaterial(), []);

  // Release the material's GPU program when this layer unmounts. R3F does not
  // dispose objects passed in as plain props.
  // The instance geometry is deliberately NOT disposed here:
  // BufferGeometry.dispose() frees the GPU buffers of its attributes, and
  // those buffers are shared with `geometry` and the sibling layers.
  useEffect(() => () => material.dispose(), [material]);

  // Push prop changes into the material's uniforms and render state.
  useEffect(() => {
    material.uniforms.pointSize.value = pointSize;
    material.uniforms.colors.value = colorsTexture || null;
    material.uniforms.dotTexture.value = dotTexture || null;
    material.uniforms.edgeSize.value = edgeSize;
    material.depthFunc = depthFunc;
    material.needsUpdate = true;
    invalidate();
  }, [
    material,
    pointSize,
    colorsTexture,
    dotTexture,
    depthFunc,
    edgeSize,
    invalidate,
  ]);

  // Keep the attributes in sync and install the selection as the draw index.
  useEffect(() => {
    instanceGeometry.attributes = geometry.attributes;
    instanceGeometry.setIndex(selection ?? null);
    invalidate();
  }, [instanceGeometry, geometry, selection, invalidate]);

  // Live reload material update. The constant dependency makes this run on
  // mount. It presumably also relies on Fast Refresh re-running effects so an
  // edited shader gets recompiled.
  useEffect(() => {
    material.version++;
  }, ["hot"]);

  return (
    <points
      frustumCulled={false}
      geometry={instanceGeometry}
      material={material}
      renderOrder={renderOrder}
    />
  );
};

/**
 * Shared, never-mutated empty selection used by the outline layers when no
 * selection is given. It is a module constant so its identity is stable and
 * the layer's index effect doesn't re-run on every render.
 */
const EMPTY_SELECTION: number[] = [];

/**
 * Point-cloud view built from three stacked layers:
 * 1. every point at `pointSize`;
 * 2. the `selected` points, `pointSize + 4`, edge size 3, drawn on top;
 * 3. the `highlighted` points, `pointSize + 16`, edge size 4, drawn on top.
 *
 * @param geometry - Dataset geometry shared by all three layers.
 * @param pointSize - Base point size (default 10).
 * @param colorsTexture - Colour lookup texture forwarded to every layer.
 * @param dotTexture - Optional sprite texture. When omitted, the bundled
 *   sphere sprite is used.
 * @param selected - Vertex indices drawn in the selection layer. Omitted
 *   means nothing is selected.
 * @param highlighted - Vertex indices drawn in the highlight layer. Omitted
 *   means nothing is highlighted.
 * @param edgeSize - Accepted for compatibility but currently ignored.
 */
export const PointsView = ({
  geometry,
  pointSize = 10,
  selected,
  highlighted,
  colorsTexture,
  dotTexture: dotTextureOverride,
}: {
  geometry: DatasetGeometry;
  pointSize?: number;
  colorsTexture?: THREE.Texture;
  dotTexture?: THREE.Texture;
  selected?: number[];
  highlighted?: number[];
  edgeSize?: boolean;
}) => {
  // Hooks must run unconditionally, so the default sprite is always loaded.
  // A caller-supplied texture takes precedence over it.
  const defaultDotTexture = useTexture(pointTextureUrl);
  const dotTexture = dotTextureOverride ?? defaultDotTexture;

  return (
    <group>
      <PointsViewInstance
        geometry={geometry}
        pointSize={pointSize}
        colorsTexture={colorsTexture}
        dotTexture={dotTexture}
      />
      {/*
        An undefined selection would clear the index and outline EVERY point.
        Default to an empty selection instead.
      */}
      <PointsViewInstance
        geometry={geometry}
        pointSize={pointSize + 4}
        edgeSize={3}
        colorsTexture={colorsTexture}
        dotTexture={dotTexture}
        selection={selected ?? EMPTY_SELECTION}
        depthFunc={THREE.AlwaysDepth}
        renderOrder={1}
      />
      <PointsViewInstance
        geometry={geometry}
        pointSize={pointSize + 16}
        edgeSize={4}
        colorsTexture={colorsTexture}
        dotTexture={dotTexture}
        selection={highlighted ?? EMPTY_SELECTION}
        depthFunc={THREE.AlwaysDepth}
        renderOrder={2}
      />
    </group>
  );
};
