import { useState } from "react";
import { dataset } from "@/data";
import {
  createModel,
  prepareData,
  trainModel,
  manualPredict,
} from "@/model/training";
import * as tf from "@tensorflow/tfjs";
import { ScoreTable } from "@/components/score-table/ScoreTable";
import { PredictionForm } from "@/components/prediction-form/PredictionForm";

/**
 * Page-level component that trains the score model on `dataset` in the
 * browser and then shows the results.
 *
 * Clicking "Train Model" prepares the data, builds and trains a model, copies
 * the learned weights/biases into plain arrays, and runs `manualPredict` with
 * those arrays. After a successful run the component renders a `ScoreTable`
 * with the predictions and a `PredictionForm` that receives the extracted
 * parameters and the normalization info.
 *
 * While training is in progress the button is replaced by a loading message.
 * If any training step fails, the error is logged and the button comes back;
 * state from an earlier successful run is left untouched.
 */
export function ModelTraining() {
  const [training, setTraining] = useState(false);
  const [predictions, setPredictions] = useState<number[]>([]);
  const [trainedModel, setTrainedModel] = useState<tf.Sequential | null>(null);
  const [weights, setWeights] = useState<number[][][]>([]);
  const [biases, setBiases] = useState<number[][]>([]);
  // NOTE(review): typed as `any` because the shape returned by `prepareData`
  // is not visible here; tighten once the prop type of PredictionForm is known.
  const [normalizationInfo, setNormalizationInfo] = useState<any>(null);

  /**
   * Runs one full train -> extract -> predict cycle and publishes the results
   * to component state. Never rejects: failures are logged so the click
   * handler does not produce an unhandled promise rejection.
   */
  const handleTrainModel = async () => {
    setTraining(true);

    try {
      const { inputTensor, labelTensor, normalizationInfo } =
        prepareData(dataset);

      try {
        const model = createModel();
        try {
          await trainModel(model, inputTensor, labelTensor);
        } catch (error) {
          // Training failed before any weights were read; release the model
          // so its variables are not leaked, then let the outer handler log.
          model.dispose();
          throw error;
        }

        const extractedWeights: number[][][] = [];
        const extractedBiases: number[][] = [];
        // Assumes getWeights() alternates kernel, bias, kernel, bias, ...
        // (i.e. every layer built by createModel has a bias) — TODO confirm.
        model.getWeights().forEach((tensor, index) => {
          if (index % 2 === 0) {
            extractedWeights.push(tensor.arraySync() as number[][]);
          } else {
            extractedBiases.push(tensor.arraySync() as number[]);
          }
          // NOTE(review): getWeights() appears to return the model's own
          // variables, so disposing them presumably leaves `model` unusable
          // for inference. It is only used below as a "has trained" flag —
          // confirm before ever calling predict() on `trainedModel`.
          tensor.dispose();
        });

        const manualPreds = manualPredict(
          dataset,
          extractedWeights,
          extractedBiases,
          normalizationInfo
        );

        setWeights(extractedWeights);
        setBiases(extractedBiases);
        setNormalizationInfo(normalizationInfo);
        setPredictions(manualPreds);
        setTrainedModel(model);
      } finally {
        // Always release the training tensors, even when a step above throws.
        inputTensor.dispose();
        labelTensor.dispose();
      }
    } catch (error) {
      console.error("Model training failed:", error);
    } finally {
      // Always leave the loading state so the button is rendered again.
      setTraining(false);
    }
  };

  return (
    <div className="score-app__container">
      <h1 className="score-app__heading">Predicting Target Score</h1>
      {training && <div className="score-app__loading">Training model...</div>}
      {!training && (
        <button className="button" onClick={handleTrainModel}>
          Train Model
        </button>
      )}
      {!training && trainedModel && predictions.length > 0 && (
        <>
          <ScoreTable dataset={dataset} predictions={predictions} />
          <PredictionForm
            weights={weights}
            biases={biases}
            normalizationInfo={normalizationInfo}
          />
        </>
      )}
    </div>
  );
}
