import React from "react";
import PropTypes from "prop-types";
import * as d3 from "d3";

import "./scatter-plot.css";
import { round } from "../utils/utils";
import { recordShape } from "../config";

// Space (px) kept free around the plotting area; the axes are drawn on the left and bottom edges.
const MARGIN = {
  top: 0,
  right: 10,
  bottom: 20,
  left: 20,
};
// Padding (in embedding/data units) added on both sides of the data extent of each axis.
const OFFSET = 5;
// Legend box anchor (px, relative to the plot group) and vertical spacing between legend rows.
const LEGEND = {
  x: 10,
  y: 10,
  rowLength: 20,
};

function Tooltip(props) {
  const { activeRecord, style } = props;
  if (activeRecord === null) return null;

  const text = (
    `x=${round(activeRecord.emb_x, 1)}; `
    + `y=${round(activeRecord.emb_x, 1)}; `
    + `Genus ID=${activeRecord.genus_id}; `
    + `Expected Genus ID=${activeRecord.pred_genus_id}`
  );
  return (
    <div id={activeRecord.id} className="svg-tooltip" style={style}>
      {text}
    </div>
  );
}

Tooltip.propTypes = {
  activeRecord: recordShape,
  style: PropTypes.object,
};
Tooltip.defaultProps = {
  activeRecord: null,
  style: null,
};

/**
 * Computes padded x/y axis domains covering every record's embedding coordinates.
 *
 * Records whose `emb_x` / `emb_y` is not a number never win a min/max comparison and are
 * therefore ignored. When no usable coordinate exists (e.g. `data` is empty), the axis falls
 * back to a domain centred on 0 instead of the unusable `[Infinity, -Infinity]`.
 *
 * @param {Array<{emb_x: number, emb_y: number}>} data - records to plot.
 * @param {number} [offset=OFFSET] - padding added below the minimum and above the maximum.
 * @returns {{xAxisRange: number[], yAxisRange: number[]}} `[min - offset, max + offset]`
 *   for each axis.
 */
function _getAxisRanges(data, offset = OFFSET) {
  let xmin = Infinity;
  let xmax = -Infinity;
  let ymin = Infinity;
  let ymax = -Infinity;
  data.forEach((item) => {
    if (item.emb_x < xmin) xmin = item.emb_x;
    if (item.emb_x > xmax) xmax = item.emb_x;
    if (item.emb_y < ymin) ymin = item.emb_y;
    if (item.emb_y > ymax) ymax = item.emb_y;
  });
  // min > max only if no comparison ever succeeded, i.e. there was no usable coordinate.
  if (xmin > xmax) {
    xmin = 0;
    xmax = 0;
  }
  if (ymin > ymax) {
    ymin = 0;
    ymax = 0;
  }
  const xAxisRange = [xmin - offset, xmax + offset];
  const yAxisRange = [ymin - offset, ymax + offset];
  return { xAxisRange, yAxisRange };
}

export default class ScatterPlot extends React.Component {
  constructor() {
    super();
    this.state = {
      activeRecord: null,
      tooltipPosition: { top: 0, left: 0 },
    };
  }

  componentDidMount() {
    this.plot();
  }

  showTooltip(event) {
    const { onActiveRecord } = this.props;
    const record = event.target.__data__;
    if (typeof record !== "undefined") {
      this.setState({ activeRecord: record }, () => {
        if (typeof onActiveRecord === "function") onActiveRecord(record);
      });
    }
  }

  hideTooltip() {
    const { onActiveRecordLeave } = this.props;
    this.setState({ activeRecord: null }, onActiveRecordLeave);
  }

  moveTooltip(event) {
    this.setState({
      tooltipPosition: {
        top: event.pageY + 10,
        left: event.pageX + 10,
      },
    });
  }

  click(event) {
    const { onClick } = this.props;
    const record = event.target.__data__;
    if (typeof onClick === "function") onClick(record);
  }

  plot() {
    // thanks to https://d3-graph-gallery.com/graph/scatter_basic.html

    // check if the scatter plot was already rendered
    const numRenderedElements = d3.select("#svg").selectAll("g").size();
    if (numRenderedElements > 0) return;

    const {
      data, width, height, onBrush,
    } = this.props;

    // compute x- and y- axis ranges
    const { xAxisRange, yAxisRange } = _getAxisRanges(data);

    // add svg element
    const svg = d3.select("#svg")
      .attr("width", width + MARGIN.left + MARGIN.right)
      .attr("height", height + MARGIN.top + MARGIN.bottom)
      .append("g")
      .attr("transform", `translate(${MARGIN.left},${MARGIN.top})`);

    // add brush (needs to be added before dots, otherwise hover/click events are not working)
    const brush = d3.brush();
    svg.call(brush);

    // add x axis
    const xAxis = d3.scaleLinear()
      .domain(xAxisRange)
      .range([0, width]);
    svg.append("g")
      .attr("transform", `translate(0,${height})`)
      .call(d3.axisBottom(xAxis));

    // add y axis
    const yAxis = d3.scaleLinear()
      .domain(yAxisRange)
      .range([height, 0]);
    svg.append("g")
      .call(d3.axisLeft(yAxis));

    // add legend
    const genusId2color = data.reduce((prevItem, item) => { // get unique items by key
      const key = item.genus_id;
      const value = { color: item.color, text: `(${item.genus_id}) ${item.genus}` };
      return ({ ...prevItem, [key]: value });
    }, {});
    const genusId2colorArr = Object.values(genusId2color);
    const legend = svg.append("g");
    legend.append("rect")
      .attr("x", LEGEND.x)
      .attr("y", LEGEND.y)
      .attr("rx", 8)
      .attr("height", LEGEND.rowLength * genusId2colorArr.length)
      .attr("width", 120)
      .style("stroke", "black")
      .style("stroke-width", 1.5)
      .style("fill", "white");

    legend.selectAll("circle")
      .data(genusId2colorArr)
      .enter()
      .append("circle")
      .attr("cx", LEGEND.x + 8)
      .attr("cy", (d, i) => LEGEND.y + 8 + i * LEGEND.rowLength)
      .attr("r", 4)
      .style("fill", (d) => d.color);

    legend.selectAll("text")
      .data(genusId2colorArr)
      .enter()
      .append("text")
      .attr("x", LEGEND.x + 8 + 8)
      .attr("y", (d, i) => LEGEND.y + 14 + i * LEGEND.rowLength)
      .text((d) => d.text)
      .style("fill", "currentColor")
      .style("font-size", "15px");

    // add dots
    const dot = svg.append("g")
      .selectAll("circle")
      .data(data)
      .enter()
      .append("circle")
      .attr("cx", (d) => xAxis(d.emb_x))
      .attr("cy", (d) => yAxis(d.emb_y))
      .attr("r", 1.5)
      .style("fill", (d) => d.color)
      .on("mouseover", (e) => this.showTooltip(e))
      .on("mouseleave", (e) => this.hideTooltip(e))
      .on("mousemove", (e) => this.moveTooltip(e))
      .on("click", (e) => this.click(e));

    // add brush functionality
    brush.on("start brush end", ({ selection }) => {
      let selectedRecords = [];
      if (selection) {
        const [[xmin, ymin], [xmax, ymax]] = selection;
        selectedRecords = dot
          .style("stroke", null)
          .filter((d) => (
            xmin <= xAxis(d.emb_x)
            && xAxis(d.emb_x) < xmax
            && ymin <= yAxis(d.emb_y)
            && yAxis(d.emb_y) < ymax
          ))
          .style("stroke", "gray")
          .data();
      } else {
        dot.style("stroke", null);
      }
      svg.property("value", selectedRecords).dispatch("input");
      if (typeof onBrush === "function") onBrush(selectedRecords);
    });
  }

  render() {
    const { tooltipPosition, activeRecord } = this.state;
    return [
      <svg key="scatterPlot" id="svg" className="svg" />,
      <Tooltip
        key="tooltip"
        style={tooltipPosition}
        activeRecord={activeRecord}
      />,
    ];
  }
}

// Prop contract: only `data` is required. The plotting area defaults to 720x720 px and every
// callback defaults to null.
ScatterPlot.propTypes = {
  // records to plot, one dot each
  data: PropTypes.arrayOf(recordShape).isRequired,
  // size (px) of the plotting area, excluding the margins
  width: PropTypes.number,
  height: PropTypes.number,
  // interaction callbacks
  onClick: PropTypes.func,
  onBrush: PropTypes.func,
  onActiveRecord: PropTypes.func,
  onActiveRecordLeave: PropTypes.func,
};
ScatterPlot.defaultProps = {
  width: 720,
  height: 720,
  onClick: null,
  onBrush: null,
  onActiveRecord: null,
  onActiveRecordLeave: null,
};
