import { CancellationToken } from "@cosine/cancellation"
import { estimateTokens, TokenizerModel } from "@cosine/common"
import { AIModel } from "../models"
import { splitIntoChunks } from "./chunk.utils"

/**
 * Prepares raw strings for embedding requests in two passes.
 *
 * 1. Token pass: every input whose estimated token count (plus a safety
 *    margin of 1% of `model.maxTokens`) exceeds `model.maxTokens` is broken up
 *    with `splitIntoChunks`. Any resulting piece that is still over the limit
 *    is cut in half once more. Every piece remembers the index of the input it
 *    came from, so one `originalIndex` may appear several times in the output.
 * 2. Batch pass: the pieces are grouped, in order, into batches. A new batch is
 *    started when adding the next piece would push the batch's UTF-8 byte size
 *    past `targetRequestByteSize`, or when the batch already holds 2048 items.
 *
 * `targetRequestByteSize` is a target rather than a hard cap: a single piece
 * that is larger than the target is still emitted, alone in its own batch.
 * Empty batches are never returned.
 *
 * @param inputs - Strings to prepare. Order is preserved in the output.
 * @param model - Model whose `maxTokens` bounds each piece and whose `id`
 *   selects the tokenizer used for estimation.
 * @param targetRequestByteSize - Desired maximum UTF-8 byte size of one batch.
 *   Must be greater than 0.
 * @param cancelToken - Optional token; checked once per input and once per
 *   piece while batching.
 * @returns Batches of `{ processedInput, originalIndex }` items.
 * @throws Error if `targetRequestByteSize` is not positive or if `inputs`
 *   contains a non-string value. Also propagates whatever
 *   `cancelToken.throwIfCancelled()` throws.
 */
export function chunkEmbeddings(
  inputs: string[],
  model: AIModel,
  targetRequestByteSize: number,
  cancelToken?: CancellationToken,
): { processedInput: string; originalIndex: number }[][] {
  type EmbeddingItem = { processedInput: string; originalIndex: number }

  if (targetRequestByteSize <= 0) {
    throw new Error("Target request byte size must be greater than 0")
  }
  if (!inputs.every((el) => typeof el === "string")) {
    throw new Error("Input must be all strings")
  }

  // Cap on items per batch. NOTE(review): presumably mirrors the embedding
  // provider's per-request input limit — confirm against the provider docs.
  const maxItemsPerBatch = 2048
  const tokenizerModel = model.id as TokenizerModel
  // Head-room so that an under-estimate does not push a piece over the limit.
  const tokenBuffer = model.maxTokens / 100
  const isOverTokenLimit = (tokens: number): boolean => tokens + tokenBuffer > model.maxTokens
  const isHighSurrogate = (codeUnit: number): boolean => codeUnit >= 0xd800 && codeUnit <= 0xdbff
  const isLowSurrogate = (codeUnit: number): boolean => codeUnit >= 0xdc00 && codeUnit <= 0xdfff

  // ---- Pass 1: make every piece fit the model's token limit ----------------
  const items: EmbeddingItem[] = []
  for (let i = 0; i < inputs.length; i++) {
    cancelToken?.throwIfCancelled()
    const input = inputs[i]
    const estimatedTokens = estimateTokens(input, tokenizerModel)

    if (!isOverTokenLimit(estimatedTokens)) {
      items.push({ processedInput: input, originalIndex: i })
      continue
    }

    // One more piece than strictly required, to leave slack for the estimate.
    const pieceCount = Math.ceil(estimatedTokens / model.maxTokens) + 1
    const pieceLength = Math.ceil(input.length / pieceCount)

    for (const piece of splitIntoChunks(model, input, pieceLength)) {
      if (!isOverTokenLimit(estimateTokens(piece, tokenizerModel))) {
        items.push({ processedInput: piece, originalIndex: i })
        continue
      }

      // Fallback: halve the piece. The midpoint is a UTF-16 code-unit index,
      // so step back by one if it would land between the two halves of a
      // surrogate pair (which would corrupt that character in both halves).
      let mid = Math.floor(piece.length / 2)
      if (isLowSurrogate(piece.charCodeAt(mid)) && isHighSurrogate(piece.charCodeAt(mid - 1))) {
        mid -= 1
      }
      for (const half of [piece.slice(0, mid), piece.slice(mid)]) {
        // A very short piece can yield an empty half; it carries no content.
        if (half.length > 0) {
          items.push({ processedInput: half, originalIndex: i })
        }
      }
    }
  }

  // ---- Pass 2: group pieces into request-sized batches ---------------------
  const batches: EmbeddingItem[][] = []
  let batch: EmbeddingItem[] = []
  let batchBytes = 0
  for (const item of items) {
    cancelToken?.throwIfCancelled()
    const itemBytes = Buffer.byteLength(item.processedInput, "utf8")
    const wouldOverflow =
      batchBytes + itemBytes > targetRequestByteSize || batch.length >= maxItemsPerBatch

    // Only flush a batch that actually holds something; otherwise an item
    // larger than the target would cause an empty batch to be emitted.
    if (wouldOverflow && batch.length > 0) {
      batches.push(batch)
      batch = []
      batchBytes = 0
    }
    batch.push(item)
    batchBytes += itemBytes
  }
  if (batch.length > 0) {
    batches.push(batch)
  }
  return batches
}
