import React, { useEffect, useRef, useState } from "react";
import * as d3 from "d3";

/**
 * Numeric ids that mark a datapoint for red highlighting.
 *
 * The view below compares each entry against the integer parsed from the
 * segment after the first "/" in a datapoint's `data_row.id`; a match is
 * painted red whenever at least one model is selected.
 *
 * NOTE(review): presumably these are HumanEval task numbers that failed
 * evaluation — confirm the provenance of this list.
 */
const humanEvalFailureIds = [
  154, 75, 141, 108, 130, 132, 118, 151, 163,
  99, 155, 125, 129, 119, 145, 120, 83,
];

const SkillValueEmbedding2DView = ({
  datapoints,
  tagColorMap,
  selectedDatapointID,
  setSelectedDatapointID,
  hoveredTag,
  hoveredDatapoint,
  setHoveredDatapoint,
  failureMode,
  selectedModels,
}) => {
  const d3Container = useRef(null);
  const [svgInitialized, setSvgInitialized] = useState(false);

  const [zoomScale, setZoomScale] = useState(1);

  const circle_radius = 5;

  useEffect(() => {
    if (!svgInitialized && d3Container.current) {
      const svg = d3.select(d3Container.current);
      const svgWidth = d3Container.current.clientWidth;
      const svgHeight = d3Container.current.clientHeight;

      // Setup scales
      const xScale = d3.scaleLinear().domain([-2, 2]).range([0, svgWidth]);
      const yScale = d3.scaleLinear().domain([-2, 2]).range([svgHeight, 0]);

      // Append a group element to the SVG for the zoom functionality
      const g = svg.append("g");

      // Create grid lines
      const xGrid = g
        .append("g")
        .attr("class", "x grid")
        .attr("transform", `translate(0,${svgHeight})`)
        .call(
          d3.axisBottom(xScale).ticks(20).tickSize(-svgHeight).tickFormat("")
        );

      const yGrid = g
        .append("g")
        .attr("class", "y grid")
        .call(d3.axisLeft(yScale).ticks(20).tickSize(-svgWidth).tickFormat(""));

      xGrid.selectAll("line").attr("stroke", "lightgrey");

      yGrid.selectAll("line").attr("stroke", "lightgrey");

      // Define zoom behavior
      const zoom = d3
        .zoom()
        .scaleExtent([1, 100])
        .translateExtent([
          [0, 0],
          [svgWidth, svgHeight],
        ])
        .on("zoom", (event) => {
          g.attr("transform", event.transform);
          setZoomScale(event.transform.k);
          g.selectAll("circle").attr("r", circle_radius / event.transform.k); // Adjust this value as needed

          // Update grid lines during zoom
          xGrid
            .selectAll("line")
            .attr("stroke", "lightgrey")
            .attr("stroke-width", 1 / event.transform.k); // Decrease stroke width as zooming in

          yGrid
            .attr("stroke-opacity", 1 / event.transform.k)
            .selectAll("line")
            .attr("stroke", "lightgrey")
            .attr("stroke-width", 1 / event.transform.k);
        });

      // Apply zoom behavior to the SVG element
      svg.call(zoom);

      setSvgInitialized(true);
    }
  }, [svgInitialized]);

  useEffect(() => {
    if (svgInitialized && d3Container.current) {
      const svg = d3.select(d3Container.current);
      const g = svg.select("g");

      // Use scales from initial setup, re-calculate if domain/range changes are needed
      const xScale = d3
        .scaleLinear()
        .domain([-2, 2])
        .range([0, d3Container.current.clientWidth]);
      const yScale = d3
        .scaleLinear()
        .domain([-2, 2])
        .range([d3Container.current.clientHeight, 0]);

      // Update pattern for circles
      const circles = g
        .selectAll("circle")
        .data(Object.values(datapoints), (d) => d.id); // Use 'id' for object constancy

      // Enter new circles
      circles
        .enter()
        .append("circle")
        .attr("r", circle_radius / zoomScale) // Adjust this value as needed
        .merge(circles) // Merge with existing circles for updates
        .attr("cx", (d) => xScale(d.vector[0]))
        .attr("cy", (d) => yScale(d.vector[1]))
        .attr("fill", (d) => {
          if (
            selectedModels.length !== 0 &&
            d.data_row.id &&
            humanEvalFailureIds.includes(parseInt(d.data_row.id.split("/")[1]))
          ) {
            return "red";
          } else if (d.tags_list.includes(hoveredTag)) {
            const darkerColor = tagColorMap[hoveredTag];

            return darkerColor;
          } else {
            if (failureMode && d.failure_id) {
              return "red";
            } else {
              return "black";
            }
          }
        }) // Update fill color dynamically
        .on("mouseover", (event, d) => {
          if (event.buttons === 0) setHoveredDatapoint(d);
        })
        .on("mouseout", () => setHoveredDatapoint(null))
        .on("click", (event, d) => {
          if (event.buttons === 0) setSelectedDatapointID(d.id);
        });

      // Remove circles no longer in the data
      circles.exit().remove();
    }
  }, [
    datapoints,
    tagColorMap,
    hoveredTag,
    failureMode,
    svgInitialized,
    setHoveredDatapoint,
    setSelectedDatapointID,
    selectedModels,
  ]);

  return (
    <svg ref={d3Container} style={{ width: "100%", height: "100%" }}></svg>
  );
};

export default SkillValueEmbedding2DView;
