import React, { useState, useEffect } from "react";
import { Divider, Typography, Box, Chip, Stack } from "@mui/material";
import CircularProgress from "@mui/material/CircularProgress";
import getDatapoint from "../../../../functions/dataset_analysis_demo_functions/getDatapoint";
import { ImageList, ImageListItem } from "@mui/material";
import { DataGrid, GridRowsProp } from "@mui/x-data-grid";
import DatumTagTree from "./DatumTagTree";
import Skeleton from "@mui/material/Skeleton";

/**
 * Numeric datapoint ids flagged as failure cases (per the name, from human
 * evaluation). DatapointInformation compares these against the number that
 * follows the "/" in `data_row.id` to decide whether to show the
 * "Failure Case" chip. Kept as an array because callers use `.includes`.
 */
const humanEvalFailureIds = [
  154, 75, 141, 108, 130, 132,
  118, 151, 163, 99, 155, 125,
  129, 119, 145, 120, 83,
];

const ImageDataGrid = ({ imageList }) => {
  return (
    <ImageList
      sx={{
        width: "100%",
        overflowY: "hidden",
      }}
      cols={3}
      gap={8}
    >
      {imageList.map((img, index) => (
        <ImageListItem key={index}>
          <img
            src={img}
            alt={img}
            loading="lazy"
            style={{
              borderRadius: "4px",
              border: "1px solid #e0e0e0",
            }}
          />
        </ImageListItem>
      ))}
    </ImageList>
  );
};

const CaptionDataGrid = ({ captionList }) => {
  const [pageSize, setPageSize] = useState(5);
  const [expandedRows, setExpandedRows] = useState({});

  // Add a state to track the row heights
  const [rowHeights, setRowHeights] = useState({});

  const rows = captionList.map((caption, index) => ({
    id: index,
    index: index + 1,
    text: caption,
  }));

  const standardHeight = 32;

  const columns = [
    { field: "index", headerName: "ID", flex: 1 },
    {
      field: "text",
      headerName: "Text",
      flex: 3,
      renderCell: (params) => {
        const isExpanded = expandedRows[params.row.id];
        return (
          <div
            style={{
              whiteSpace: isExpanded ? "normal" : "nowrap",
              overflow: "hidden",
              textOverflow: "ellipsis",
              lineHeight: "normal",
              cursor: "pointer",
            }}
            onClick={() => {
              const newExpandedRows = {
                ...expandedRows,
                [params.row.id]: !expandedRows[params.row.id],
              };
              setExpandedRows(newExpandedRows);
              // Update the row heights
              const newHeights = { ...rowHeights };
              if (newExpandedRows[params.row.id]) {
                newHeights[params.row.id] = "auto"; // 'auto' for dynamic height
              } else {
                newHeights[params.row.id] = standardHeight; // Default row height
              }
              setRowHeights(newHeights);
            }}
          >
            {params.value}
          </div>
        );
      },
    },
  ];

  return (
    <div style={{ width: "100%" }}>
      <DataGrid
        rows={rows}
        columns={columns}
        pageSize={pageSize}
        onPageSizeChange={(newPageSize) => setPageSize(newPageSize)}
        density="compact"
        hideFooter
        // Specify a function that returns the height for each row
        getRowHeight={(params) => rowHeights[params.id] || standardHeight} // Return the height from state or default to 52
      />
    </div>
  );
};

const DatapointInformation = ({
  projectId,
  tagColorMap,
  selectedDatapointID,
  circleGraph,
}) => {
  const [selectedDatapoint, setSelectedDatapoint] = useState(null);

  useEffect(() => {
    const fetchDatapoint = async () => {
      setSelectedDatapoint(null);
      if (selectedDatapointID) {
        const datapoint = await getDatapoint({
          id: selectedDatapointID,
          projectId,
        });
        console.log(datapoint);
        setSelectedDatapoint(datapoint);
      } else {
        setSelectedDatapoint(null);
      }
    };
    fetchDatapoint();
  }, [selectedDatapointID, setSelectedDatapoint]);

  return (
    <Box
      sx={{
        width: "100%",
        height: "100%",
        overflow: "auto",
        padding: 2,
      }}
    >
      {selectedDatapoint ? (
        <Box
          sx={{
            position: "relative",
            margin: "auto",
          }}
        >
          <Stack
            spacing={1}
            divider={<Divider orientation="horizontal" flexItem />}
          >
            <Box>
              {" "}
              <Box display="flex">
                {" "}
                <Typography variant="body" component="div">
                  ID {selectedDatapoint.data_row.id}
                </Typography>
                {/* Check if the ID is in humanEvalFailureIds*/}
                {humanEvalFailureIds.includes(
                  parseInt(selectedDatapoint.data_row.id.split("/")[1])
                ) && (
                  <Chip
                    label="Failure Case"
                    color="error"
                    sx={{
                      ml: 1,
                      borderRadius: "4px",
                      height: "20px",
                      opacity: ".8",
                    }}
                  />
                )}
              </Box>
              <Typography
                gutterBottom
                variant="caption"
                color="grey"
                component="div"
              >
                Acadia ID {selectedDatapoint.id}
              </Typography>
            </Box>

            <Box>
              <Typography variant="body" component="div" noWrap>
                Image(s)
              </Typography>
              <ImageDataGrid imageList={selectedDatapoint.img_path_list} />
            </Box>
            <Box>
              <CaptionDataGrid captionList={selectedDatapoint.caption_list} />
            </Box>
            <Box>
              <Typography variant="body" component="div" noWrap>
                Topic Tags
              </Typography>
              <DatumTagTree
                circleGraph={circleGraph}
                tagsList={selectedDatapoint.tags_list}
                leafTagValueDict={selectedDatapoint.leaf_tag_value_dict}
              />
            </Box>
          </Stack>
        </Box>
      ) : (
        <Stack spacing={1}>
          <Skeleton variant="rounded" width="100px" height="30px" />
          <Skeleton variant="rounded" width="200px" height="15px" />
          <Stack spacing={1} direction="row">
            <Skeleton variant="rectangular" width="200px" height="200px" />
            <Skeleton variant="rectangular" width="200px" height="200px" />
          </Stack>
          <Skeleton variant="text" sx={{ fontSize: "1rem" }} />
        </Stack>
      )}
    </Box>
  );
};

export default DatapointInformation;
