import { Matrix } from 'ml-matrix'

class MatrixUtils {
  /**
   * Computes `|itemVectorMatrix · stateVector|^1.3 * weight` element-wise.
   *
   * ml-matrix instance ops such as `abs`/`pow`/`mul` modify the matrix they
   * are called on; here they only touch the new matrix returned by `mmul`,
   * so neither argument is modified.
   *
   * @param {Matrix} itemVectorMatrix - one item vector per row.
   * @param {Matrix} stateVector - column vector (rows = item vector length).
   * @param {number} weight - scalar applied to every score.
   * @returns {Matrix} column matrix holding one score per item row.
   */
  matrixMultiplication(itemVectorMatrix, stateVector, weight) {
    return itemVectorMatrix.mmul(stateVector).abs().pow(1.3).mul(weight)
  }

  /**
   * Returns the indices that would sort `arr` in ascending numeric order
   * (like numpy's `argsort`). Equal values keep their original relative
   * order because `Array.prototype.sort` is stable.
   *
   * The comparator is required: the default `sort()` compares elements as
   * strings, which mis-orders e.g. 10 vs 9, negative numbers and
   * exponent-notation floats such as 1e-7.
   *
   * @param {number[]} arr - values to rank.
   * @returns {number[]} indices into `arr`, smallest value first.
   */
  argSort(arr) {
    return arr
      .map((value, index) => [value, index])
      .sort((a, b) => a[0] - b[0])
      .map(([, index]) => index)
  }

  /**
   * Builds a `rows` x 1 column of uniform noise in [-0.5, 0.5) scaled by
   * `rows / 75`, which is added to a state vector before scoring. Every
   * per-category scorer uses this helper so that freshly computed and
   * cached scores share one noise distribution and stay comparable.
   *
   * @param {number} rows - length of the state vector being perturbed.
   * @returns {Matrix} a new `rows` x 1 matrix.
   */
  createStateVectorNoise(rows) {
    return Matrix.rand(rows, 1)
      .sub(0.5)
      .div((5 * 15) / rows)
  }

  /**
   * Scores every item of one medium category against that category's state
   * vector plus fresh noise from `createStateVectorNoise`.
   *
   * @param {number[][]} itemVectors - one row per item.
   * @param {number[]} stateVector - the category's state vector.
   * @param {number} weight - the category's weight.
   * @returns {number[]} one score per item, in row order.
   */
  scoreMediumCategory(itemVectors, stateVector, weight) {
    const itemVectorMatrix = new Matrix(itemVectors)
    const stateVectorMatrix = Matrix.columnVector(stateVector)
    // `add` mutates stateVectorMatrix, which is a fresh copy of stateVector.
    stateVectorMatrix.add(this.createStateVectorNoise(stateVectorMatrix.rows))
    return this.matrixMultiplication(
      itemVectorMatrix,
      stateVectorMatrix,
      weight,
    ).to1DArray()
  }

  /**
   * Returns the keys of `weightsByMediumCategories` that still have item
   * vectors, logging each category that has none.
   *
   * @param {Object<string, number>} weightsByMediumCategories
   * @param {Object<string, number[][]>} itemVectorsByMediumCategories
   * @returns {string[]} category keys, in `Object.keys` order.
   */
  getActiveMediumCategories(
    weightsByMediumCategories,
    itemVectorsByMediumCategories,
  ) {
    const activeMediumCategories = []
    Object.keys(weightsByMediumCategories).forEach((mediumCategory) => {
      if (itemVectorsByMediumCategories[mediumCategory]) {
        activeMediumCategories.push(mediumCategory)
      } else {
        console.log(`medium category ${mediumCategory} is all used up..`)
      }
    })
    return activeMediumCategories
  }

  /**
   * Concatenates item ids, scores and category mappings of the given
   * categories (in order) and ranks the concatenated scores ascending.
   *
   * @returns {Array} `[scoresByMediumCategories, itemIds, sortedMappings,
   *   ordering]` where `itemIds` is the unsorted concatenation of item ids,
   *   `sortedMappings` are the category mappings ordered by score (each
   *   copied and tagged with `idx`, its index inside its own category), and
   *   `ordering` is the argsort of the concatenated scores.
   */
  collectMatchResults(
    mediumCategories,
    scoresByMediumCategories,
    itemIdsByMediumCategories,
    categoryMappingsByMediumCategories,
  ) {
    const totalRemainingItemIds = []
    const totalRemainingItemScores = []
    const totalRemainingItemCategoryMappings = []
    mediumCategories.forEach((mediumCategory) => {
      totalRemainingItemIds.push(...itemIdsByMediumCategories[mediumCategory])
      totalRemainingItemScores.push(...scoresByMediumCategories[mediumCategory])
      totalRemainingItemCategoryMappings.push(
        ...categoryMappingsByMediumCategories[mediumCategory].map(
          (categoryMapping, idx) => ({ ...categoryMapping, idx }),
        ),
      )
    })
    const totalItemOrdering = this.argSort(totalRemainingItemScores)

    return [
      scoresByMediumCategories,
      totalRemainingItemIds,
      totalItemOrdering.map(
        (itemIdx) => totalRemainingItemCategoryMappings[itemIdx],
      ),
      totalItemOrdering,
    ]
  }

  /**
   * Scores every medium category (keys of `weightsByMediumCategories`) that
   * still has item vectors, then ranks all their items together.
   *
   * @param {Object<string, number[][]>} itemVectorsByMediumCategories
   * @param {Object<string, number[]>} stateVectorsByMediumCategories
   * @param {Object<string, number>} weightsByMediumCategories
   * @param {Object<string, Array>} itemIdsByMediumCategories
   * @param {Object<string, object[]>} categoryMappingsByMediumCategories
   * @returns {Array} see `collectMatchResults`; the first element is a new
   *   object of per-category scores.
   */
  calculateItemMatchResultsByMediumCategories(
    itemVectorsByMediumCategories,
    stateVectorsByMediumCategories,
    weightsByMediumCategories,
    itemIdsByMediumCategories,
    categoryMappingsByMediumCategories,
  ) {
    const scoresByMediumCategories = {}
    const activeMediumCategories = this.getActiveMediumCategories(
      weightsByMediumCategories,
      itemVectorsByMediumCategories,
    )
    activeMediumCategories.forEach((mediumCategory) => {
      scoresByMediumCategories[mediumCategory] = this.scoreMediumCategory(
        itemVectorsByMediumCategories[mediumCategory],
        stateVectorsByMediumCategories[mediumCategory],
        weightsByMediumCategories[mediumCategory],
      )
    })

    return this.collectMatchResults(
      activeMediumCategories,
      scoresByMediumCategories,
      itemIdsByMediumCategories,
      categoryMappingsByMediumCategories,
    )
  }

  /**
   * Like `calculateItemMatchResultsByMediumCategories`, but re-scores only
   * `mediumCategoryId` and reuses the cached scores of the other categories.
   * Categories with no cached scores are scored as well. The passed-in
   * `scoresByMediumCategories` object is updated in place and returned.
   *
   * @param {Object<string, number[][]>} itemVectorsByMediumCategories
   * @param {Object<string, number[]>} stateVectorsByMediumCategories
   * @param {Object<string, number>} weightsByMediumCategories
   * @param {Object<string, Array>} itemIdsByMediumCategories
   * @param {Object<string, object[]>} categoryMappingsByMediumCategories
   * @param {Object<string, number[]>} scoresByMediumCategories - cached
   *   scores; mutated.
   * @param {string|number} mediumCategoryId - category to re-score.
   * @returns {Array} see `collectMatchResults`.
   */
  recalculateItemMatchResultsByMediumCategories(
    itemVectorsByMediumCategories,
    stateVectorsByMediumCategories,
    weightsByMediumCategories,
    itemIdsByMediumCategories,
    categoryMappingsByMediumCategories,
    scoresByMediumCategories,
    mediumCategoryId,
  ) {
    // Object.keys yields strings; compare against the id's string form so a
    // numeric id still selects its category.
    const targetMediumCategory = String(mediumCategoryId)
    const activeMediumCategories = this.getActiveMediumCategories(
      weightsByMediumCategories,
      itemVectorsByMediumCategories,
    )
    activeMediumCategories.forEach((mediumCategory) => {
      if (
        mediumCategory === targetMediumCategory ||
        !scoresByMediumCategories[mediumCategory]
      ) {
        scoresByMediumCategories[mediumCategory] = this.scoreMediumCategory(
          itemVectorsByMediumCategories[mediumCategory],
          stateVectorsByMediumCategories[mediumCategory],
          weightsByMediumCategories[mediumCategory],
        )
      }
    })

    return this.collectMatchResults(
      activeMediumCategories,
      scoresByMediumCategories,
      itemIdsByMediumCategories,
      categoryMappingsByMediumCategories,
    )
  }

  /**
   * Ranks items against several medium categories at once. Each item's score
   * is the sum over the keys c of `weightsObj` of
   * `weightsObj[c] * (item · (stateVectors[c] + noise))^2`.
   *
   * NOTE(review): the noise here is uniform in [0, 1), neither centred nor
   * scaled by vector length like the per-category scorers — confirm intended.
   *
   * @param {number[][]} itemVectors - one row per item.
   * @param {Object<string, number[]>} stateVectors - state vector per category.
   * @param {Object<string, number>} weightsObj - weight per category; its keys
   *   select which state vectors are used.
   * @returns {number[]} item indices, lowest score first.
   */
  calculateItemMatchResults(itemVectors, stateVectors, weightsObj) {
    console.log(`about to create matrix instance`)
    const itemVectorMatrix = new Matrix(itemVectors)
    console.log(`created matrix instance`)
    const mediumCategories = Object.keys(weightsObj)
    // One column per selected category after the transpose.
    const selectedStateVectors = new Matrix(
      mediumCategories.map((mediumCategory) => stateVectors[mediumCategory]),
    ).transpose()
    const weightsVector = Matrix.columnVector(
      mediumCategories.map((mediumCategory) => weightsObj[mediumCategory]),
    )
    const randomMatrix = Matrix.rand(
      selectedStateVectors.rows,
      selectedStateVectors.columns,
    )
    console.log(
      `selectedStateVectors: ${selectedStateVectors.rows}, ${selectedStateVectors.columns}`,
    )
    return this.argSort(
      itemVectorMatrix
        .mmul(selectedStateVectors.add(randomMatrix))
        .pow(2)
        .mmul(weightsVector)
        .to1DArray(),
    )
  }

  /**
   * Ranks items by `|item · (stateVector + noise)|`, using the same noise as
   * the per-category scorers (`createStateVectorNoise`).
   *
   * @param {number[][]} itemVectors - one row per item.
   * @param {number[]} stateVector
   * @returns {number[]} item indices, lowest score first.
   */
  calculateItemMatchResultsForSmallCategories(itemVectors, stateVector) {
    console.log(
      `calculate item match for small cat initiated, itemvectors length: ${itemVectors.length}`,
    )
    const itemVectorMatrix = new Matrix(itemVectors)
    const randomMatrix = this.createStateVectorNoise(stateVector.length)
    const matchResult = itemVectorMatrix
      .mmul(Matrix.columnVector(stateVector).add(randomMatrix))
      .abs()
    return this.argSort(matchResult.to1DArray())
  }

  /**
   * Returns `stateVector + itemVector` normalised to unit Frobenius norm, with
   * every element rounded to 2 significant digits. Inputs are not modified.
   * If the sum is the zero vector, zeros are returned instead of NaN.
   *
   * @param {number[]} stateVector
   * @param {number[]} itemVector - same length as `stateVector`.
   * @returns {number[]}
   */
  calculateNewStateVector(stateVector, itemVector) {
    // `add` updates the fresh column-vector copy in place.
    const summed = Matrix.columnVector(stateVector).add(
      Matrix.columnVector(itemVector),
    )
    const norm = summed.norm()
    // 0 / 0 would fill the state with NaN and poison every later score.
    const normalized = norm === 0 ? summed : summed.div(norm)
    return normalized
      .to1DArray()
      .map((element) => Number.parseFloat(element.toPrecision(2)))
  }

  /**
   * Shuffles `array` in place (Fisher–Yates, using `Math.random`).
   *
   * @param {Array} array - mutated.
   * @returns {Array} the same array instance.
   */
  shuffleArray(array) {
    for (let i = array.length - 1; i > 0; i -= 1) {
      const j = Math.floor(Math.random() * (i + 1))
      const temp = array[i]
      array[i] = array[j]
      array[j] = temp
    }
    return array
  }
}

// Shared, stateless helper instance used by importers of this module.
const matrixUtils = new MatrixUtils()

export default matrixUtils
