import React, {
  useMemo,
  useRef,
  useCallback,
  useState,
  useEffect,
} from "react";
import { Canvas, useThree } from "@react-three/fiber";
import { OrbitControls, PointMaterial } from "@react-three/drei";
import * as THREE from "three";

/**
 * Returns a darker, more saturated variant of a hex colour.
 *
 * The colour is converted to HSL. Lightness is reduced by `darkenAmount` and
 * saturation is increased by `saturateAmount`, with both results clamped to
 * [0, 1]. The colour is then converted back to hex.
 *
 * Achromatic inputs (greys, black, white) have hue 0, so any added saturation
 * tints them towards red.
 *
 * @param {string} hex - Colour as `#rrggbb` or CSS shorthand `#rgb`. The
 *   leading `#` is optional and digits are case-insensitive.
 * @param {number} [darkenAmount=0.3] - Amount subtracted from HSL lightness.
 * @param {number} [saturateAmount=0.5] - Amount added to HSL saturation.
 * @returns {string} The adjusted colour as a lowercase `#rrggbb` string.
 */
function darkenAndSaturateColor(hex, darkenAmount = 0.3, saturateAmount = 0.5) {
  const digits = hex.startsWith("#") ? hex.slice(1) : hex;
  // Expand CSS shorthand (#abc -> #aabbcc) so both notations are accepted.
  const fullDigits =
    digits.length === 3 ? [...digits].map((c) => c + c).join("") : digits;

  const r = Number.parseInt(fullDigits.slice(0, 2), 16) / 255;
  const g = Number.parseInt(fullDigits.slice(2, 4), 16) / 255;
  const b = Number.parseInt(fullDigits.slice(4, 6), 16) / 255;

  // RGB -> HSL (all components in [0, 1]).
  const max = Math.max(r, g, b);
  const min = Math.min(r, g, b);
  let h = 0; // achromatic default
  let s = 0; // achromatic default
  let l = (max + min) / 2;

  if (max !== min) {
    const d = max - min;
    s = l > 0.5 ? d / (2 - max - min) : d / (max + min);
    if (max === r) {
      h = (g - b) / d + (g < b ? 6 : 0);
    } else if (max === g) {
      h = (b - r) / d + 2;
    } else {
      h = (r - g) / d + 4;
    }
    h /= 6;
  }

  // Darken and saturate, keeping both values inside the valid HSL range.
  const clamp01 = (x) => Math.min(1, Math.max(0, x));
  l = clamp01(l - darkenAmount);
  s = clamp01(s + saturateAmount);

  // HSL -> RGB.
  const hueToChannel = (p, q, t) => {
    let tt = t;
    if (tt < 0) tt += 1;
    if (tt > 1) tt -= 1;
    if (tt < 1 / 6) return p + (q - p) * 6 * tt;
    if (tt < 1 / 2) return q;
    if (tt < 2 / 3) return p + (q - p) * (2 / 3 - tt) * 6;
    return p;
  };
  const q = l < 0.5 ? l * (1 + s) : l + s - l * s;
  const p = 2 * l - q;

  const toHexByte = (channel) =>
    Math.round(channel * 255)
      .toString(16)
      .padStart(2, "0");

  // Red, green and blue are sampled at hue offsets +1/3, 0 and -1/3.
  const channels = [h + 1 / 3, h, h - 1 / 3].map((t) => hueToChannel(p, q, t));
  return `#${channels.map(toHexByte).join("")}`;
}

function Particles({
  data,
  setHoveredDatapoint,
  handleClick,
  hoveredTag,
  tagColorMap,
  failureMode,
}) {
  const { size } = useThree();
  const points = useRef();

  const [positions, colors] = useMemo(() => {
    const positions = data.flatMap(({ vector }) => vector.slice(0, 3));

    const colors = data.flatMap(({ tags_list, failure_id }) => {
      if (tags_list.includes(hoveredTag)) {
        const darkerColor = darkenAndSaturateColor(tagColorMap[hoveredTag]);
        const colorArray = new THREE.Color(darkerColor).toArray();

        return colorArray;
      } else {
        if (failureMode && failure_id) {
          return new THREE.Color("red").toArray();
        } else {
          return new THREE.Color("black").toArray();
        }
      }
    });

    return [new Float32Array(positions), new Float32Array(colors)];
  }, [data, hoveredTag, failureMode]);

  const hover = useCallback(
    (e) => {
      if (e.buttons === 0) {
        e.stopPropagation();
        // Temporary color change on hover (example: white)
        new THREE.Color("hotpink").toArray(
          points.current.geometry.attributes.color.array,
          e.index * 3
        );

        points.current.geometry.attributes.color.needsUpdate = true;

        const { index, unprojectedPoint } = e;

        const datapoint = data[index];

        setHoveredDatapoint(datapoint);
      }
    },
    [data, setHoveredDatapoint]
  );

  const unhover = useCallback(
    (e) => {
      const datapoint = data[e.index];

      if (failureMode && datapoint.failure_id) {
        new THREE.Color("red").toArray(
          points.current.geometry.attributes.color.array,
          e.index * 3
        );
      } else {
        new THREE.Color("black").toArray(
          points.current.geometry.attributes.color.array,
          e.index * 3
        );
      }

      setHoveredDatapoint(null);
      points.current.geometry.attributes.color.needsUpdate = true;
    },
    [data, setHoveredDatapoint, failureMode]
  );

  // Inside Particles component
  const click = useCallback(
    (e) => {
      e.stopPropagation();
      const { index } = e;
      const datapoint = data[index];
      const id = datapoint.id;
      // Call the function passed via props to set the selected ID
      handleClick(id);
    },
    [data, handleClick] // Make sure handleClick is passed as a prop and included here
  );

  return (
    <points
      ref={points}
      onPointerOver={(e) => {
        hover(e);
      }}
      onPointerOut={unhover}
      onClick={click}
    >
      <bufferGeometry>
        <bufferAttribute attach="attributes-position" args={[positions, 3]} />
        <bufferAttribute attach="attributes-color" args={[colors, 3]} />
      </bufferGeometry>
      <PointMaterial size={size.width / 100} vertexColors />
    </points>
  );
}

/**
 * Draws X, Y and Z axis lines through the origin.
 *
 * The lines are created imperatively and added directly to the R3F scene.
 * On unmount (or when a prop changes) they are removed and their geometry
 * and material are disposed, so GPU buffers are not leaked.
 *
 * @param {object} props
 * @param {number} [props.length=100] - Each axis extends from -length to +length.
 * @param {string} [props.color="lightgrey"] - Line colour (any THREE.Color input).
 * @returns {null} Renders nothing through React; the lines live in the scene.
 */
const Axes = ({ length = 100, color = "lightgrey" }) => {
  const { scene } = useThree();

  useEffect(() => {
    // LineSegments pairs consecutive vertices, so these six points form three
    // independent lines, one per axis.
    const geometry = new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(-length, 0, 0),
      new THREE.Vector3(length, 0, 0),
      new THREE.Vector3(0, -length, 0),
      new THREE.Vector3(0, length, 0),
      new THREE.Vector3(0, 0, -length),
      new THREE.Vector3(0, 0, length),
    ]);
    // A single shared material: all three axes use the same colour.
    const material = new THREE.LineBasicMaterial({ color });
    const axes = new THREE.LineSegments(geometry, material);

    scene.add(axes);

    return () => {
      scene.remove(axes);
      // Removing an object from the scene does not free its GPU resources.
      geometry.dispose();
      material.dispose();
    };
  }, [scene, length, color]);

  return null;
};

export default function Embedding3DView({
  datapoints,
  tagColorMap,
  selectedDatapointID,
  setSelectedDatapointID,
  hoveredTag,
  hoveredDatapoint,
  setHoveredDatapoint,
  failureMode,
}) {
  const [selectedSemanticViewId, setSelectedSemanticViewId] = useState(null);

  useEffect(() => {
    setSelectedSemanticViewId(selectedDatapointID);
  }, [selectedDatapointID]);

  const handleClick = (id) => {
    if (selectedSemanticViewId && id && id === selectedSemanticViewId) {
      setSelectedSemanticViewId(null);
      setSelectedDatapointID(null);
    } else {
      setSelectedSemanticViewId(id);
      setSelectedDatapointID(id);
    }
  };

  return (
    <>
      <Canvas
        orthographic
        camera={{ zoom: 100, position: [100, 100, 100] }}
        raycaster={{ params: { Points: { threshold: 0.02 } } }}
      >
        <Axes />
        <Particles
          data={datapoints}
          setHoveredDatapoint={setHoveredDatapoint}
          handleClick={handleClick}
          hoveredTag={hoveredTag}
          tagColorMap={tagColorMap}
          failureMode={failureMode}
        />
        <OrbitControls />
      </Canvas>
    </>
  );
}
