import * as tf from "@tensorflow/tfjs";
import { RowData } from "@/data";

/** Observed range of one numeric column, used for min-max scaling. */
export interface MinMax {
  /** Smallest observed value. */
  min: number;
  /** Largest observed value. */
  max: number;
}

/**
 * Scaling statistics produced by `prepareData`, needed to re-apply the same
 * feature scaling at inference time and to map normalized predictions back
 * to label units.
 */
export interface NormalizationInfo {
  /**
   * One entry per model input feature, in input order:
   * opened, click bonus, bounce count, unsubscribe penalty, bounce penalty.
   */
  featureMinMax: MinMax[];
  /** Range of the training labels (`targetScore`). */
  labelMinMax: MinMax;
}

/**
 * Returns the smallest and largest value in `values`.
 *
 * Implemented as a single pass rather than `Math.min(...values)`: spreading
 * an array passes every element as a call argument, which throws a
 * `RangeError` once the dataset reaches a few hundred thousand rows.
 *
 * Edge cases (unchanged from the spread-based version):
 * - an empty array yields `{ min: Infinity, max: -Infinity }`;
 * - any `NaN` element yields `{ min: NaN, max: NaN }`.
 *
 * @param values Column of numbers to summarize; not mutated.
 * @returns The observed `{ min, max }`.
 */
export function computeMinMax(values: readonly number[]): MinMax {
  let min = Infinity;
  let max = -Infinity;
  for (const value of values) {
    // Math.min/Math.max propagate NaN; preserve that so bad data is visible.
    if (Number.isNaN(value)) return { min: NaN, max: NaN };
    if (value < min) min = value;
    if (value > max) max = value;
  }
  return { min, max };
}

/**
 * Maps `value` linearly so that `range.min` -> 0 and `range.max` -> 1.
 * Values outside the range map outside [0, 1]. A degenerate range
 * (`min === max`) maps everything to 0 instead of dividing by zero.
 */
export function minMaxScale(value: number, range: MinMax): number {
  const { min, max } = range;
  return max === min ? 0 : (value - min) / (max - min);
}

/**
 * Inverse of min-max scaling: maps 0 -> `range.min` and 1 -> `range.max`
 * linearly.
 */
export function minMaxUnscale(value: number, range: MinMax): number {
  const span = range.max - range.min;
  return range.min + value * span;
}

/**
 * Click-count feature transform.
 *
 * Below 5 clicks the raw count is used. From 5 upward the value grows
 * exponentially: `exp(click + 2)` for 5..10 and `exp(click + 4)` above 10.
 */
export function computeClickBonus(click: number): number {
  if (click >= 5.0) {
    const offset = click > 10.0 ? 4.0 : 2.0;
    return Math.exp(click + offset);
  }
  return click;
}

/**
 * Bounce-count feature transform.
 *
 * Up to 2 bounces the raw count is used. Above that the value grows
 * exponentially: `exp(bounces - 1)` below 10 and `exp(bounces - 2)` from 10
 * upward.
 */
export function computeBouncePenalty(bounces: number): number {
  if (bounces > 2.0) {
    const offset = bounces >= 10.0 ? 2.0 : 1.0;
    return Math.exp(bounces - offset);
  }
  return bounces;
}

/**
 * Unsubscribe-count feature transform.
 *
 * Up to 1 unsubscribe the raw count is used. Above that the value grows
 * exponentially: `exp(unsubscribes - 2)` below 5 and `exp(unsubscribes - 4)`
 * from 5 upward. This mirrors the shape of `computeBouncePenalty`.
 *
 * NOTE(review): the offset change at 5 makes the function drop there
 * (exp(~2.99) just below 5, exp(1) at 5). This is kept as written; confirm it
 * is the intended curve.
 */
export function computeUnsubscribePenalty(unsubscribes: number): number {
  if (unsubscribes <= 1.0) return unsubscribes;
  // Previously this branch lacked `return`, so its value was discarded and
  // every input >= 5 fell through to the exp(unsubscribes - 2) case below.
  if (unsubscribes >= 5.0) return Math.exp(unsubscribes - 4.0);
  return Math.exp(unsubscribes - 2.0);
}

/**
 * Turns raw rows into normalized training tensors plus the statistics needed
 * to reproduce the normalization later.
 *
 * Features, in column order (this order must match `manualPredict` and the
 * model's 5-wide input):
 *   0. opened
 *   1. computeClickBonus(clicked)
 *   2. bounces (raw count)
 *   3. computeUnsubscribePenalty(unsubscribes)
 *   4. computeBouncePenalty(bounces)
 *
 * Each feature's min/max is taken over the *transformed* values, and those
 * same transformed values are what gets scaled. Previously columns 1 and 3
 * scaled the raw counts against min/max computed from the transformed
 * values, so training inputs disagreed with both their own statistics and
 * with `manualPredict`.
 *
 * @param data Training rows; must be non-empty.
 * @returns `inputTensor` ([rows, 5]), `labelTensor` ([rows, 1]) of
 *   normalized `targetScore`, and the `normalizationInfo` used.
 * @throws Error if `data` is empty (no statistics can be computed).
 */
export function prepareData(data: RowData[]) {
  if (data.length === 0) {
    throw new Error("prepareData: data must contain at least one row");
  }

  const openedVals = data.map((row) => row.opened);
  const clickBonusVals = data.map((row) => computeClickBonus(row.clicked));
  const bounceCountVals = data.map((row) => row.bounces);
  const unsubPenaltyVals = data.map((row) =>
    computeUnsubscribePenalty(row.unsubscribes)
  );
  const bouncePenaltyVals = data.map((row) => computeBouncePenalty(row.bounces));
  const labelVals = data.map((row) => row.targetScore);

  const openedMinMax = computeMinMax(openedVals);
  const clickedMinMax = computeMinMax(clickBonusVals);
  const bouncesMinMax = computeMinMax(bounceCountVals);
  const unsubMinMax = computeMinMax(unsubPenaltyVals);
  const bounceMinMax = computeMinMax(bouncePenaltyVals);
  const labelMinMax = computeMinMax(labelVals);

  // Apply exactly the same per-feature transforms as manualPredict, so the
  // network is trained on the inputs it will see at inference time.
  const normalizedInputs = data.map((row) => [
    minMaxScale(row.opened, openedMinMax),
    minMaxScale(computeClickBonus(row.clicked), clickedMinMax),
    minMaxScale(row.bounces, bouncesMinMax),
    minMaxScale(computeUnsubscribePenalty(row.unsubscribes), unsubMinMax),
    minMaxScale(computeBouncePenalty(row.bounces), bounceMinMax),
  ]);

  const normalizedLabels = labelVals.map((val) => [
    minMaxScale(val, labelMinMax),
  ]);

  const inputTensor = tf.tensor2d(normalizedInputs);
  const labelTensor = tf.tensor2d(normalizedLabels);

  const normalizationInfo: NormalizationInfo = {
    featureMinMax: [
      openedMinMax,
      clickedMinMax,
      bouncesMinMax,
      unsubMinMax,
      bounceMinMax,
    ],
    labelMinMax,
  };

  return { inputTensor, labelTensor, normalizationInfo };
}

/**
 * Builds and compiles the regression network.
 *
 * Topology: 5 input features -> dense ReLU layers of 64, 32, 16 and 8 units
 * -> a single output unit with no activation. Compiled with the Adam
 * optimizer and mean-squared-error loss, reporting `mse` as a metric.
 */
export function createModel(): tf.Sequential {
  const net = tf.sequential();

  // The first dense layer also declares the input width (5 features).
  net.add(tf.layers.dense({ inputShape: [5], units: 64, activation: "relu" }));

  // Progressively narrower ReLU layers.
  for (const units of [32, 16, 8]) {
    net.add(tf.layers.dense({ units, activation: "relu" }));
  }

  // Single-unit regression head (no activation specified).
  net.add(tf.layers.dense({ units: 1 }));

  net.compile({
    loss: tf.losses.meanSquaredError,
    metrics: ["mse"],
    optimizer: tf.train.adam(),
  });

  return net;
}

/**
 * Fits `model` on the given tensors and logs the final-epoch MSE.
 *
 * @param model Compiled model to train in place.
 * @param inputTensor Normalized feature matrix.
 * @param labelTensor Normalized labels, one row per input row.
 * @param epochs Number of training epochs; defaults to the previous
 *   hard-coded value of 500.
 * @returns The `tf.History` from `model.fit`. The original version discarded
 *   it; callers that ignore the return value are unaffected.
 */
export async function trainModel(
  model: tf.Sequential,
  inputTensor: tf.Tensor2D,
  labelTensor: tf.Tensor2D,
  epochs = 500
): Promise<tf.History> {
  const history = await model.fit(inputTensor, labelTensor, {
    epochs,
    shuffle: true,
    verbose: 0,
  });

  // The `mse` key is only present when the model was compiled with an "mse"
  // metric (as createModel does); otherwise this logs `undefined`.
  console.log("Final MSE:", history.history.mse?.slice(-1)[0]);

  return history;
}

/**
 * Runs `model` forward on `inputTensor` and returns its (normalized) output.
 *
 * `model.predict` is typed as `Tensor | Tensor[]`. The result is treated as a
 * single tensor, which assumes a single-output model such as the one built by
 * `createModel`.
 */
export function predict(
  model: tf.Sequential,
  inputTensor: tf.Tensor2D
): tf.Tensor {
  const output = model.predict(inputTensor);
  return output as tf.Tensor;
}

/**
 * Reads `predictions` back to the CPU (synchronously, via `dataSync`) and
 * maps every normalized value back to label units using `labelMinMax`.
 *
 * @returns A flat array with one entry per element of `predictions`.
 */
export function unscalePredictions(
  predictions: tf.Tensor,
  labelMinMax: MinMax
): number[] {
  const normalized = predictions.dataSync();
  const unscaled: number[] = [];
  for (const value of normalized) {
    unscaled.push(minMaxUnscale(value, labelMinMax));
  }
  return unscaled;
}

/**
 * Rectified linear unit: negative inputs become 0 and everything else passes
 * through unchanged. `NaN` propagates, and `-0` becomes `+0`.
 */
export function relu(x: number): number {
  const floor = 0;
  return Math.max(x, floor);
}

/**
 * Multiplies a row vector by a matrix and returns `vec · mat`.
 *
 * `mat` is indexed `mat[row][col]` with one row per element of `vec`, so the
 * result has one entry per column (`mat[0].length`). This is the layout of a
 * dense-layer kernel (`[inputUnits][outputUnits]`), hence the loop over output
 * columns that sums down the rows.
 *
 * @throws Error if `mat` has no rows or `vec.length !== mat.length`.
 *   Previously a length mismatch silently produced `NaN` entries and an
 *   empty matrix crashed with an opaque `TypeError`.
 */
export function matMulVec(mat: number[][], vec: number[]): number[] {
  if (mat.length === 0) {
    throw new Error("matMulVec: matrix must have at least one row");
  }
  if (vec.length !== mat.length) {
    throw new Error(
      `matMulVec: vector length ${vec.length} does not match matrix row count ${mat.length}`
    );
  }

  const cols = mat[0].length;
  const out = new Array<number>(cols);
  for (let col = 0; col < cols; col++) {
    let sum = 0;
    for (let row = 0; row < mat.length; row++) {
      sum += mat[row][col] * vec[row];
    }
    out[col] = sum;
  }
  return out;
}

/**
 * Element-wise `vec + bias`.
 *
 * @throws Error if the lengths differ. Previously a short `bias` produced
 *   `NaN` entries and a long one was silently truncated.
 */
export function addBias(vec: number[], bias: number[]): number[] {
  if (bias.length !== vec.length) {
    throw new Error(
      `addBias: bias length ${bias.length} does not match vector length ${vec.length}`
    );
  }
  return vec.map((v, i) => v + bias[i]);
}

/**
 * Runs a dense network forward by hand and returns the first output unit.
 *
 * `weights` is a flat list alternating kernel and bias for each layer:
 * `[W1, b1, W2, b2, ..., Wn, bn]`. Each `W` is `[inputUnits][outputUnits]`
 * (see `matMulVec`) and each `b` has `outputUnits` entries. ReLU is applied
 * after every layer except the last.
 *
 * Previously the network was hard-coded to exactly three layers (six
 * entries). Extra entries were silently ignored, which is wrong for
 * createModel's 5-layer network, and fewer entries crashed. For a six-entry
 * list the result is unchanged.
 *
 * NOTE(review): presumably `weights` comes from `model.getWeights()`
 * converted to nested arrays; confirm with callers.
 *
 * @throws Error if `weights` is empty or has an odd number of entries.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- kept for caller compatibility
export function forwardPass(input: number[], weights: any[]): number {
  if (weights.length === 0 || weights.length % 2 !== 0) {
    throw new Error(
      `forwardPass: expected [W, b] pairs, got ${weights.length} weight entries`
    );
  }

  const layerCount = weights.length / 2;
  let activations = input;
  for (let layer = 0; layer < layerCount; layer++) {
    const kernel: number[][] = weights[2 * layer];
    const bias: number[] = weights[2 * layer + 1];
    activations = addBias(matMulVec(kernel, activations), bias);
    // Output layer stays linear; hidden layers use ReLU.
    if (layer < layerCount - 1) activations = activations.map(relu);
  }

  return activations[0];
}

/**
 * Predicts scores for `data` by running the network by hand (no tfjs), using
 * the normalization captured at training time.
 *
 * Feature order and transforms must match `prepareData`:
 * opened, computeClickBonus(clicked), bounces,
 * computeUnsubscribePenalty(unsubscribes), computeBouncePenalty(bounces).
 *
 * @param data Rows to score.
 * @param weights Per-layer kernels, `weights[i]` is `[inputUnits][outputUnits]`.
 * @param biases Per-layer bias vectors; `biases[i]` pairs with `weights[i]`.
 * @param normalizationInfo Statistics returned by `prepareData`.
 * @param maxScore Upper cap applied to each unscaled prediction (default
 *   100, the previous hard-coded value). No lower bound is applied.
 * @returns One prediction per row, in label units.
 * @throws Error if there are no layers or `weights`/`biases` lengths differ.
 */
export function manualPredict(
  data: RowData[],
  weights: number[][][],
  biases: number[][],
  normalizationInfo: NormalizationInfo,
  maxScore = 100
): number[] {
  // Validate once up front instead of failing deep inside addBias per row.
  if (weights.length === 0 || weights.length !== biases.length) {
    throw new Error(
      `manualPredict: expected one bias per weight matrix (got ${weights.length} weights, ${biases.length} biases)`
    );
  }

  const { featureMinMax, labelMinMax } = normalizationInfo;
  const outputLayer = weights.length - 1;

  return data.map((row) => {
    let activations = [
      minMaxScale(row.opened, featureMinMax[0]),
      minMaxScale(computeClickBonus(row.clicked), featureMinMax[1]),
      minMaxScale(row.bounces, featureMinMax[2]),
      minMaxScale(
        computeUnsubscribePenalty(row.unsubscribes),
        featureMinMax[3]
      ),
      minMaxScale(computeBouncePenalty(row.bounces), featureMinMax[4]),
    ];

    for (let i = 0; i < weights.length; i++) {
      activations = addBias(matMulVec(weights[i], activations), biases[i]);
      // Output layer stays linear; hidden layers use ReLU.
      if (i < outputLayer) activations = activations.map(relu);
    }

    const unscaledPrediction = minMaxUnscale(activations[0], labelMinMax);
    return Math.min(unscaledPrediction, maxScore);
  });
}
