/* sankey-diagram.ts */

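/**
 * Helpers for Sankey diagrams: computing the depth (longest link path) of the
 * node graph and building the "TL<n>" treatment-line label annotations that
 * are laid out across the diagram's columns.
 */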

/* Interfaces */
import { SankeyProps } from "src/shared/interfaces/charts/SankeyDiagram";

/**
 * A single text annotation, as produced by the treatment-line helpers in this
 * file.
 * NOTE(review): the field names (`showarrow`, `font`) appear to mirror Plotly
 * `layout.annotations` entries — confirm against the chart consumer.
 */
export interface AnnotationProps {
  /** Horizontal position of the label. */
  x: number;
  /** Vertical position of the label (the helpers here use -0.125). */
  y: number;
  /** Label text, e.g. "TL1". */
  text: string;
  /** Presumably toggles an arrow pointing at (x, y); the helpers here always pass false. */
  showarrow: boolean;
  /** Font used to render `text`. */
  font: FontProps;
}

/** Font settings for an annotation label. */
interface FontProps {
  /** Font size (the helpers here use 10). */
  size: number;
  /** Color string (the helpers here use the hex value "#91a4bf"). */
  color: string;
}

/**
 * Adjacency list for the Sankey link graph: maps a source node index to the
 * indices of every node it links to.
 */
interface Transitions {
  [key: number]: number[];
}

/**
 * Compute the depth of a Sankey diagram: the largest number of nodes on any
 * path that follows links from source to target.
 *
 * The longest path starting at each node is computed once and memoised, so a
 * node reachable by several routes is always measured along its longest one,
 * independent of the order in which links are listed. A link back to a node
 * that is still on the current search path (a cycle) is ignored so that cyclic
 * input terminates. For cyclic graphs the result may depend on link order.
 *
 * @param data - Sankey data; `node.label` determines the node count and
 *   `link.source[i] -> link.target[i]` describes each link by node index.
 * @returns The maximum path length counted in nodes (1 when a node has no
 *   outgoing links), or 0 when the diagram has no nodes.
 */
export const computeMaxDepth = (data: SankeyProps): number => {
  const nodeCount = data.node.label.length;
  // Math.max() of an empty list is -Infinity; report an empty diagram as 0.
  if (nodeCount === 0) return 0;

  // Adjacency list: source index -> target indices.
  const transitions: Transitions = {};
  data.link.source.forEach((source, i) => {
    const target = data.link.target[i];
    // Skip links whose target is missing (mismatched source/target arrays).
    if (target === undefined) return;
    const targets = transitions[source] ?? [];
    targets.push(target);
    transitions[source] = targets;
  });

  // Longest path (in nodes) starting at a node, once fully explored.
  const longestFrom = new Map<number, number>();
  // Nodes on the current DFS path; an edge back into this set closes a cycle.
  const onPath = new Set<number>();

  const depthFrom = (node: number): number => {
    const cached = longestFrom.get(node);
    if (cached !== undefined) return cached;

    onPath.add(node);
    let best = 1;
    for (const next of transitions[node] ?? []) {
      if (onPath.has(next)) continue; // back edge: following it would loop
      best = Math.max(best, 1 + depthFrom(next));
    }
    onPath.delete(node);

    longestFrom.set(node, best);
    return best;
  };

  let maxDepth = 0;
  for (let node = 0; node < nodeCount; node++) {
    maxDepth = Math.max(maxDepth, depthFrom(node));
  }
  return maxDepth;
};

/**
 * Build one "TL<n>" label per Sankey column, centred horizontally on it.
 *
 * The horizontal range is split into `2 * maxDepth` half-column steps, and the
 * k-th label (k = 1..maxDepth) sits on odd step `2k - 1`, i.e. at
 * x = (2k - 1) / (2 * maxDepth).
 *
 * @param maxDepth - Number of columns (treatment lines) to label.
 * @returns One annotation per column; empty when `maxDepth` is 0 or negative.
 */
export const createTreatmentLines = (maxDepth: number): AnnotationProps[] => {
  const slotCount = 2 * maxDepth + 1;
  const span = slotCount - 1;
  const labels: AnnotationProps[] = [];

  let slot = 1;
  while (slot < slotCount) {
    // Odd slot 2k - 1 corresponds to treatment line k.
    const lineNumber = (slot + 1) / 2;
    labels.push({
      x: slot / span,
      y: -0.125,
      text: `TL${lineNumber}`,
      showarrow: false,
      font: { size: 10, color: "#91a4bf" },
    });
    slot += 2;
  }

  return labels;
};

/**
 * Build one "TL<n>" label per treatment line, spread evenly from x = 0 (first
 * line) to x = 1 (last line).
 *
 * @param deepestFlow - Number of treatment lines to label.
 * @returns Annotations labelled TL1..TL<deepestFlow>; empty when `deepestFlow`
 *   is 0 or negative. A single line is placed at x = 0.
 */
export const createTreatmentLinesDeepestFlow = (
  deepestFlow: number
): AnnotationProps[] => {
  const treatmentLines: AnnotationProps[] = [];
  // With one line the spacing denominator would be 0 and x = 0 / 0 = NaN;
  // fall back to 1 so the lone label lands on the left edge like line 1 always does.
  const spacing = deepestFlow > 1 ? deepestFlow - 1 : 1;

  for (let i = 0; i < deepestFlow; i++) {
    treatmentLines.push({
      x: i / spacing,
      y: -0.125,
      text: `TL${i + 1}`,
      showarrow: false,
      font: {
        size: 10,
        color: "#91a4bf",
      },
    });
  }

  return treatmentLines;
};
