import { Component } from "@cosine/code"
import { CustomTiktokenModel, decodeTokens, encodeToTokens, estimateTokens, truncateByNewLine } from "@cosine/common"

import {
  CustomChatMessage,
  CustomChatMessageType,
  ExplainCodeChatMessage,
  SearchChatMessage,
  FileChatMessage,
  FilesystemChatMessage,
  SearchMessageType,
  CodeChatMessage,
  ChatMessage,
} from "../types"
import { AIModel, ChatRole, Message } from "@cosine/ai"
import { estimateContextLength } from ".."

/**
 * Shrinks a context window below `targetTokens` by trimming every non-system message in
 * proportion to its share of the context's token count.
 *
 * A leading system message (if `window[0]` has role `ChatRole.System`) is kept untouched.
 * When truncation is needed, every other message is trimmed and gets a
 * "...[TRUNCATED FOR BREVITY]" marker appended. This includes tiny messages, because each
 * message's share of the overflow rounds up to at least one token.
 *
 * @param window - Messages to fit, optionally starting with a system prompt.
 * @param model - Model whose `id` selects the tokenizer used for all token estimates.
 * @param targetTokens - Budget the window must end up strictly below.
 * @returns `window` itself when it already fits. Otherwise a new array: the system prompt
 *   (same object) followed by shallow copies of the other messages with truncated `content`.
 *   Input messages are not mutated.
 */
export function proportionallyTruncateContextWindow(window: Message[], model: AIModel, targetTokens: number): Message[] {
  const tokenModel = model.id as CustomTiktokenModel
  const systemPrompt: Message | undefined = window[0]?.role === ChatRole.System ? window[0] : undefined
  const context = systemPrompt ? window.slice(1) : window

  // Calculate how far over the limit the window is. The +1 makes an exact fit count as
  // overflow, so the result lands strictly below targetTokens.
  const contextLength = estimateContextLength(context, model)
  const systemLength = systemPrompt ? estimateTokens(systemPrompt.content ?? "", tokenModel) + 3 : 0
  const overflow = contextLength + systemLength - targetTokens + 1
  if (overflow <= 0) return window

  const truncationMarker = "...[TRUNCATED FOR BREVITY]"
  const updated = context.map((m) => {
    const messageLength = estimateTokens(m.content, tokenModel) + 3 + (m.name ? estimateTokens(m.name, tokenModel) : 0)
    const targetPrune = Math.ceil(overflow * (messageLength / contextLength))
    const contentTokens = encodeToTokens(m.content!, tokenModel)
    // Remove the message's share of the overflow, plus ~10 tokens of room for the marker.
    // Clamp at 0: a negative end index would make `slice` count from the END of the array
    // and keep most of the message instead of truncating it.
    const keepCount = Math.max(0, contentTokens.length - targetPrune - 10)
    return { ...m, content: `${decodeTokens(contentTokens.slice(0, keepCount), tokenModel)}${truncationMarker}` }
  })
  return systemPrompt ? [systemPrompt, ...updated] : updated
}

/**
 * Shrinks a context window below `targetTokens` by forgetting the oldest non-system messages.
 *
 * A leading system message (if `window[0]` has role `ChatRole.System`) is always kept. After
 * it, the shortest prefix of messages whose combined length covers the overflow is dropped.
 * If dropping every non-system message still cannot cover the overflow, all of them are
 * dropped, because that is the closest this function can get to the budget.
 *
 * @param window - Messages to fit, optionally starting with a system prompt.
 * @param model - Model whose `id` selects the tokenizer used for all token estimates.
 * @param targetTokens - Budget the window must end up strictly below.
 * @returns `window` itself when it already fits. Otherwise a new array made of the system
 *   prompt (if any) followed by the surviving messages, all of them the original objects.
 */
export function forgetfullyTruncateContextWindow(window: Message[], model: AIModel, targetTokens: number): Message[] {
  const tokenModel = model.id as CustomTiktokenModel
  const systemPrompt: Message | undefined = window[0]?.role === ChatRole.System ? window[0] : undefined
  const context = systemPrompt ? window.slice(1) : window

  // Calculate how far over the limit the window is. The +1 makes an exact fit count as
  // overflow, so the result lands strictly below targetTokens.
  const systemLength = systemPrompt ? estimateTokens(systemPrompt.content ?? "", tokenModel) + 3 : 0
  const overflow = estimateContextLength(context, model) + systemLength - targetTokens + 1
  if (overflow <= 0) return window

  const messageLengths = context.map((m) => estimateTokens(m.content, tokenModel) + 3 + (m.name ? estimateTokens(m.name, tokenModel) : 0))

  // Walk forward until the dropped messages free at least `overflow` tokens. Message `i` is the
  // last one that must go, so keep from `i + 1`. If the loop never gets there, drop everything.
  let freed = 0
  let keepFrom = messageLengths.length
  for (const [i, length] of messageLengths.entries()) {
    freed += length
    if (freed >= overflow) {
      keepFrom = i + 1
      break
    }
  }
  const updated = context.slice(keepFrom)
  return systemPrompt ? [systemPrompt, ...updated] : updated
}

/**
 * Shrinks the bulky payload of `msg` in place so the message costs fewer tokens.
 *
 * Only `CustomChatMessage` instances are touched, and only these types:
 * - ExplainCode: the text of every component in `data.components`
 * - Search: delegated to `truncateSearchChatMessage`
 * - File: `data.fileContents`
 * - Filesystem: `data`
 *
 * @param msg - Message to truncate. It is mutated in place and nothing is returned.
 * @param truncateRatio - Ratio in [0, 1] passed through to `truncateByNewLine` for every
 *   payload type above.
 * @throws Error when `truncateRatio` is not a number in [0, 1], including NaN.
 */
export function truncateChatMessage(msg: ChatMessage, truncateRatio: number): void {
  // Negated form so that NaN, for which every comparison is false, is rejected too.
  if (!(truncateRatio >= 0 && truncateRatio <= 1)) {
    throw new Error("Truncate ratio must be between 0 and 1")
  }
  if (!(msg instanceof CustomChatMessage)) {
    return // No real data to truncate
  }

  // Each case is braced so its `const` is scoped to that case alone (no-case-declarations).
  switch (msg.type) {
    case CustomChatMessageType.ExplainCode: {
      const explainMsg = msg as ExplainCodeChatMessage
      explainMsg.data.components = explainMsg.data.components.map((component) => truncateComponent(component, truncateRatio))
      break
    }
    case CustomChatMessageType.Search:
      truncateSearchChatMessage(msg as SearchChatMessage, truncateRatio)
      break
    case CustomChatMessageType.File: {
      const fileMsg = msg as FileChatMessage
      fileMsg.data.fileContents = truncateByNewLine(fileMsg.data.fileContents, truncateRatio)
      break
    }
    case CustomChatMessageType.Filesystem: {
      const fsMsg = msg as FilesystemChatMessage
      fsMsg.data = truncateByNewLine(fsMsg.data, truncateRatio)
      break
    }
    /**
     * Other message types are not truncated, because they are generally small,
     * hard to truncate, or must be preserved.
     */
  }
}

/**
 * Truncates `component.text` in place via `truncateByNewLine`.
 *
 * @param component - Component whose `text` is replaced. It is mutated, not copied.
 * @param truncateRatio - Ratio in [0, 1] forwarded to `truncateByNewLine`.
 * @returns The same component object, so callers can use it in `map` or reassign it.
 * @throws Error when `truncateRatio` is not a number in [0, 1], including NaN.
 */
function truncateComponent(component: Component, truncateRatio: number): Component {
  // Negated form so that NaN, for which every comparison is false, is rejected too.
  if (!(truncateRatio >= 0 && truncateRatio <= 1)) {
    throw new Error("Truncate ratio must be between 0 and 1")
  }
  component.text = truncateByNewLine(component.text, truncateRatio)
  return component
}

/**
 * Truncates the payload of a search message in place.
 *
 * Only code search results are truncated: the `Component` at `data[1]` is passed through
 * `truncateComponent` and written back. Other search types are left untouched.
 *
 * @param msg - Search message to truncate. It is mutated in place.
 * @param truncateRatio - Ratio in [0, 1] forwarded to `truncateComponent`.
 * @throws Error when `truncateRatio` is not a number in [0, 1], including NaN.
 */
function truncateSearchChatMessage(msg: SearchChatMessage, truncateRatio: number): void {
  // Negated form so that NaN, for which every comparison is false, is rejected too.
  if (!(truncateRatio >= 0 && truncateRatio <= 1)) {
    throw new Error("Truncate ratio must be between 0 and 1")
  }

  switch (msg.searchType) {
    case SearchMessageType.Code: {
      const codeMsg = msg as CodeChatMessage
      codeMsg.data[1] = truncateComponent(codeMsg.data[1], truncateRatio)
      break
    }
    /**
     * Other search types are not truncated, because they are generally small,
     * hard to truncate, must be preserved, or are not yet implemented.
     */
  }
}
