import { AI, AIModel, ChatRole, CreateEmbeddingProps, OpenAIEmbeddingModel, ContentValidator, ChatFunction, Message } from "@cosine/ai"
import { CustomTiktokenModel, estimateFunctionTokens, estimateTokens, TokenizerModel, truncateByNewLine } from "@cosine/common"
import { summarizePrompt } from "./prompts/summarize.prompt"
import { VectorDatabase } from "@cosine/database"
import YAML from "yaml"
import { ChatMessage, CodeChatMessage, ContextChatMessage, FileChatMessage, PromptChatMessage, SearchContext, TextChatMessage, truncateChatMessage } from "."
import { CancellationToken } from "@cosine/cancellation"
import { v4 } from "uuid"

/**
 * Options for `collateContiguousContext`, which keeps the history in order and summarizes it
 * when the conversation no longer fits the model's token budget.
 */
export type ContiguousContextProps = {
  /** System prompt; counted toward the token budget. */
  systemPrompt: ChatMessage
  /** Prior conversation history; this is what gets summarized when the budget is exceeded. */
  messages: ChatMessage[]
  /** Messages for the current turn, placed after the history in the collated result. */
  newMessages: ChatMessage[]
  /** Model the context is collated for; its `maxTokens` sets the budget. */
  model: AIModel
  /** Optional cancellation token. */
  cancelToken?: CancellationToken
  /** Client used to send the summarization requests. */
  ai: AI
  /** Model used to produce the summaries; its `maxTokens` bounds each summarized chunk. */
  summarizationModel: AIModel
  /** NOTE(review): not read by the functions in this file — confirm whether it is still needed. */
  summarizationFormat: "prose" | "bullets"
  /** Bounds how many times summarization is retried before falling back to truncating history. */
  maxAttempts: number
}

/**
 * Options for `collateDiscontiguousContext`, which retrieves past messages by embedding
 * similarity instead of keeping the history contiguous.
 */
export type DiscontiguousContextProps = {
  /** System prompt; counted toward the token budget. */
  systemPrompt: ChatMessage
  /** Prior conversation history. */
  messages: ChatMessage[]
  /** Messages for the current turn; these are embedded and stored in `db`. */
  newMessages: ChatMessage[]
  /** Model the context is collated for; its `maxTokens` sets the budget. */
  model: AIModel
  /** Optional cancellation token. */
  cancelToken?: CancellationToken
  /** Client used to embed the new messages. */
  ai: AI
  /** Vector store that is queried for similar past messages and receives the new messages. */
  db: VectorDatabase<ChatMessage>
}

/**
 * Options for `collateForgetfulContext`, which drops or truncates the oldest messages until
 * the context fits the token budget.
 */
export type ForgetfulContextProps = {
  /** System prompt; counted toward the token budget. */
  systemPrompt: ChatMessage
  /** Prior conversation history; dropped oldest-first when over budget. */
  messages: ChatMessage[]
  /** Messages for the current turn; reduced only once the history is exhausted. */
  newMessages: ChatMessage[]
  /** Model the context is collated for. */
  model: AIModel
  /** Optional cancellation token. */
  cancelToken?: CancellationToken
  /** Budget to use instead of `model.maxTokens`; a falsy value (including 0) means "unset". */
  targetTokens?: number
  /**
   * Tokens reserved for the completion. When set (truthy), the budget is
   * `(targetTokens || model.maxTokens) - completionTokens` instead of 75% of it.
   */
  completionTokens?: number
  /** Function definitions whose estimated token cost is counted toward the budget. */
  functions?: ChatFunction[]
}

/**
 * Builds the context for a turn while keeping the history contiguous.
 *
 * When the system prompt, history and new messages together exceed 75% of `model.maxTokens`, the
 * history is split into chunks (see `chunkMessagesForSummary`), each chunk is summarized with
 * `summarizationModel`, and the summaries replace the original messages. Chunks whose summary
 * fails keep their original messages, so no part of the conversation is silently lost.
 *
 * If the result is still too long, summarization is retried on the new history. Once `attempts`
 * exceeds `maxAttempts`, or if every chunk fails to summarize, it falls back to
 * `collateForgetfulContext`.
 *
 * @param props - See `ContiguousContextProps`. `summarizationFormat` is not used here.
 * @param attempts - Internal recursion counter; callers should leave it at its default.
 * @returns The (possibly summarized) history followed by `newMessages`. The system prompt is
 * counted toward the budget but is not part of the result.
 * @throws Whatever `cancelToken.throwIfCancelled()` throws once cancellation has been requested.
 */
export async function collateContiguousContext(props: ContiguousContextProps, attempts = 0): Promise<ChatMessage[]> {
  const { newMessages, messages, systemPrompt, model, ai, summarizationModel, maxAttempts, cancelToken } = props
  const forget = () => collateForgetfulContext({ newMessages, messages, systemPrompt, model, cancelToken })
  if (attempts > maxAttempts) {
    return forget()
  }
  const budget = model.maxTokens * 0.75
  if (estimateContextLength([systemPrompt, ...messages, ...newMessages], model) <= budget) {
    return [...messages, ...newMessages]
  }

  // Don't start (potentially many) summarization requests for a turn that is already cancelled.
  cancelToken?.throwIfCancelled()
  const chunks = chunkMessagesForSummary(messages, summarizationModel)
  const summaries = await Promise.all(chunks.map((chunk) => summarize(ai, chunk, summarizationModel)))
  cancelToken?.throwIfCancelled()

  // Also covers "no chunks at all" (empty history): nothing was summarized, so just truncate.
  if (summaries.every((summary) => summary === undefined)) {
    return forget()
  }
  // A failed chunk keeps its original messages instead of vanishing from the conversation; a
  // further attempt (or the forgetful fallback) can still shrink it.
  const newContext = summaries.flatMap((summary, i) => summary ?? chunks[i])
  if (estimateContextLength([systemPrompt, ...newContext, ...newMessages], model) > budget) {
    return collateContiguousContext({ ...props, messages: newContext }, attempts + 1)
  }
  return [...newContext, ...newMessages]
}

/**
 * Parses a summarization response. The expected format is a non-empty YAML list whose entries each
 * have exactly one key, `User` or `Assistant`, mapped to a string.
 *
 * The same function backs both the chat validator and the final conversion, so a response can never
 * be accepted by one and rejected (or mis-read) by the other.
 *
 * @returns The entries as `[speaker, text]` pairs, or `undefined` if `text` is not a string or does
 * not match the format.
 */
function parseSummaryTranscript(text: unknown): ["User" | "Assistant", string][] | undefined {
  if (typeof text !== "string") {
    return undefined
  }
  let parsed: unknown
  try {
    parsed = YAML.parse(text)
  } catch {
    return undefined
  }
  // An empty transcript would silently replace the whole chunk with nothing.
  if (!Array.isArray(parsed) || parsed.length === 0) {
    return undefined
  }
  const entries: ["User" | "Assistant", string][] = []
  for (const item of parsed) {
    if (typeof item !== "object" || item === null) {
      return undefined
    }
    const keys = Object.keys(item)
    if (keys.length !== 1) {
      return undefined
    }
    const speaker = keys[0]
    const value: unknown = (item as Record<string, unknown>)[speaker]
    // Non-string values (numbers, booleans, null) would otherwise end up as message content.
    if ((speaker !== "User" && speaker !== "Assistant") || typeof value !== "string") {
      return undefined
    }
    entries.push([speaker, value])
  }
  return entries
}

/**
 * Asks `summarizationModel` to condense `messages` into a shorter User/Assistant transcript.
 *
 * The conversation is rendered as `role: content` lines, and the model is asked (with up to 3
 * attempts) for a reply that `parseSummaryTranscript` accepts.
 *
 * @returns One message per transcript entry, each with a fresh id, or `undefined` when no
 * acceptable summary was produced.
 */
async function summarize(ai: AI, messages: ChatMessage[], summarizationModel: AIModel): Promise<ChatMessage[] | undefined> {
  const conversationMessage = new PromptChatMessage(messages.map((el) => `${el.role}: ${el.content}`).join("\n"))
  const starterMessage: ChatMessage = new TextChatMessage(v4(), "- SummaryTranscript:")
  const chatMessages = [summarizePrompt(), conversationMessage, starterMessage]
  const validator: ContentValidator = (response) => parseSummaryTranscript(response) !== undefined
  const summary = await ai.createChat({ model: summarizationModel, messages: chatMessages, validator, availableAttempts: 3 })
  // Re-check with the same rules rather than trusting that whatever createChat returned was validated.
  const entries = parseSummaryTranscript(summary?.message?.content)
  return entries?.map(([speaker, content]) => ({
    id: v4(),
    role: speaker === "User" ? ChatRole.User : ChatRole.Assistant,
    content,
  }))
}

/**
 * Splits `messages` into consecutive, non-empty chunks for summarization. A new chunk is started
 * whenever adding the next message would push the current chunk's estimate past 75% of
 * `summarizationModel.maxTokens`.
 *
 * A single message that is larger than the limit on its own still gets its own chunk; it is never
 * split, and no empty chunk is emitted ahead of it.
 *
 * @returns The chunks in original order, or `[]` when `messages` is empty.
 */
function chunkMessagesForSummary(messages: ChatMessage[], summarizationModel: AIModel): ChatMessage[][] {
  const limit = summarizationModel.maxTokens * 0.75
  // estimateContextLength is a fixed overhead plus a per-message cost, so a running total avoids
  // re-estimating the whole chunk for every message.
  const overhead = estimateContextLength([], summarizationModel)
  const chunks: ChatMessage[][] = []
  let currentChunk: ChatMessage[] = []
  let currentChunkLength = overhead
  for (const message of messages) {
    const messageLength = estimateContextLength([message], summarizationModel) - overhead
    // Only close a chunk that actually has messages in it.
    if (currentChunk.length > 0 && currentChunkLength + messageLength > limit) {
      chunks.push(currentChunk)
      currentChunk = []
      currentChunkLength = overhead
    }
    currentChunk.push(message)
    currentChunkLength += messageLength
  }
  if (currentChunk.length > 0) {
    chunks.push(currentChunk)
  }
  return chunks
}

/**
 * Builds the context for a turn by retrieving relevant history instead of keeping it contiguous.
 *
 * The flow is:
 * 1. Each of `newMessages` is embedded.
 * 2. If the system prompt, history and new messages fit within 75% of `model.maxTokens`, everything
 *    is kept.
 * 3. Otherwise `db` is queried with the first new message's embedding (up to 16 results). As many of
 *    the returned messages as fit the budget are kept, in the order they appear in `messages`.
 * 4. If nothing similar is stored, the history is truncated with `collateForgetfulContext` instead.
 *
 * In all of these cases the new messages are written to `db` (after the query, so they are not
 * retrieved as their own neighbours), and the writes are awaited.
 *
 * If embedding fails or yields fewer embeddings than new messages, it falls back to
 * `collateForgetfulContext` without touching `db`.
 *
 * @returns Selected history followed by `newMessages`; the system prompt is counted toward the
 * budget but not included, matching the other collate functions.
 * @throws Whatever `cancelToken.throwIfCancelled()` throws, or any error from `db.query`/`db.write`.
 */
export async function collateDiscontiguousContext(props: DiscontiguousContextProps): Promise<ChatMessage[]> {
  const { newMessages, messages, systemPrompt, db, ai, model, cancelToken } = props
  const forget = () => collateForgetfulContext({ newMessages, messages, systemPrompt, model, cancelToken })
  const embeddingProps: CreateEmbeddingProps = {
    model: OpenAIEmbeddingModel.EMBED_TEXT_EMBEDDING_3_SMALL.id,
    input: newMessages.map((msg) => msg.content!),
  }
  const result = await ai.createEmbedding(embeddingProps)
  // Every new message needs its own embedding to be stored below.
  if (!result || result.embeddings.length === 0 || result.embeddings.length < newMessages.length) {
    // TODO: Currently fallback to forgetful implementation, is this a good idea?
    return forget()
  }
  // NOTE(review): the previous code guarded on `result.embeddings` but indexed `result[i].embedding`;
  // this assumes each entry of `result.embeddings` exposes `.embedding` — confirm against the
  // `@cosine/ai` createEmbedding result type.
  const embeddings = result.embeddings.map((el) => el.embedding)
  const remember = async () => {
    await Promise.all(newMessages.map((msg, i) => db.write([embeddings[i]], msg)))
  }
  cancelToken?.throwIfCancelled()

  const budget = model.maxTokens * 0.75
  if (estimateContextLength([systemPrompt, ...messages, ...newMessages], model) <= budget) {
    await remember()
    return [...messages, ...newMessages]
  }

  // Query before writing so the new messages are not returned as their own nearest neighbours.
  const similarVectors = await db.query(embeddings[0], 16)
  const similarMessages = similarVectors.map((el) => el.item)
  if (similarMessages.length === 0) {
    // Nothing to retrieve: store the new messages (before the forgetful pass can alter the array),
    // then trim the recent history to fit instead of returning an over-budget context.
    await remember()
    return forget()
  }

  // Keep as many of the similar messages as fit next to the system prompt and new messages.
  let currentCount = estimateContextLength([systemPrompt, ...newMessages], model)
  const fittableMessages: ChatMessage[] = []
  for (const message of similarMessages) {
    const newCount = currentCount + estimateTokens(message.content, model.id as TokenizerModel)
    if (newCount > budget) {
      break
    }
    fittableMessages.push(message)
    currentCount = newCount
  }

  // Restore conversation order. Lookup is by object identity; messages not found in `messages`
  // sort first. NOTE(review): if `db` returns copies rather than the same objects, this is a no-op.
  const position = new Map<ChatMessage, number>(messages.map((msg, i) => [msg, i]))
  fittableMessages.sort((a, b) => (position.get(a) ?? -1) - (position.get(b) ?? -1))

  await remember()
  return [...fittableMessages, ...newMessages]
}

/**
 * Fits the conversation into a token budget by forgetting the oldest material first.
 *
 * The budget is `targetTokens || model.maxTokens`, minus `completionTokens` when that is set, or 75%
 * of it otherwise. The estimate covers the system prompt, `messages`, `newMessages` and the
 * definitions in `functions`.
 *
 * While the estimate is over budget, one step is applied per iteration:
 * 1. If history remains, its oldest entry is dropped.
 * 2. Once the history is empty, `newMessages` is reduced via `shrinkNewMessages`. If that step
 *    truncates nothing, or the truncation does not lower the estimate, the first new message is
 *    dropped instead, so every iteration makes progress.
 * 3. Once both lists are empty, it stops even if still over budget.
 *
 * The caller's arrays are not modified (copies are reduced). However, a code message chosen for
 * truncation is handed to `truncateChatMessage`, which presumably edits it in place —
 * NOTE(review): confirm.
 *
 * @param props - See `ForgetfulContextProps`.
 * @returns The remaining history followed by the remaining new messages; the system prompt is
 * counted toward the budget but not included.
 * @throws Whatever `cancelToken.throwIfCancelled()` throws once cancellation has been requested.
 */
export function collateForgetfulContext(props: ForgetfulContextProps): ChatMessage[] {
  const { model, systemPrompt, cancelToken } = props
  // Work on copies so the caller's arrays are not shifted or overwritten.
  const messages = [...props.messages]
  const newMessages = [...props.newMessages]
  const tokenizer = model.id as TokenizerModel
  // The function definitions don't change inside the loop, so estimate them once.
  const functionTokens = estimateFunctionTokens(props.functions || [], tokenizer)
  const measure = () => estimateContextLength([systemPrompt, ...messages, ...newMessages], model) + functionTokens
  const maxTokens = props.targetTokens || model.maxTokens
  const desiredLength = props.completionTokens ? maxTokens - props.completionTokens : maxTokens * 0.75

  let totalContextLength = measure()
  while (totalContextLength > desiredLength) {
    cancelToken?.throwIfCancelled()

    if (messages.length > 0) {
      // Forget the oldest history message first.
      messages.shift()
    } else if (newMessages.length === 0) {
      break
    } else {
      const ratio = Math.max(desiredLength / totalContextLength - 0.1, 0.4)
      const truncated = shrinkNewMessages(newMessages, ratio, tokenizer)
      // A truncation that no longer shrinks anything would otherwise spin forever.
      if (!truncated || measure() >= totalContextLength) {
        newMessages.shift()
      }
    }
    totalContextLength = measure()
  }
  return [...messages, ...newMessages]
}

/**
 * Applies one reduction step to `newMessages`, modifying the array in place and passing `ratio` to
 * the truncation helper.
 *
 * The step depends on the first message:
 * - under 150 tokens: nothing is truncated (not worth it);
 * - a file message (has `data.fileContents`): it is replaced with a copy whose contents are cut by
 *   `truncateByNewLine`;
 * - otherwise: the largest `CodeChatMessage` (earliest on ties) is passed to `truncateChatMessage`.
 *
 * @returns `true` if a message was truncated; `false` if the caller should drop the first message
 * instead (it is small, or no strategy applies).
 */
function shrinkNewMessages(newMessages: ChatMessage[], ratio: number, tokenizer: TokenizerModel): boolean {
  const first = newMessages[0]
  if (estimateTokens(first.content, tokenizer) < 150) {
    return false
  }
  if ((first as FileChatMessage).data?.fileContents !== undefined) {
    const fileMessage = first as FileChatMessage
    newMessages[0] = new FileChatMessage({
      ...fileMessage.data,
      fileContents: truncateByNewLine(fileMessage.data.fileContents, ratio),
    })
    return true
  }
  const [longest] = newMessages
    .filter((m) => m instanceof CodeChatMessage)
    .sort((a, b) => estimateTokens(b.content!, tokenizer) - estimateTokens(a.content!, tokenizer))
  if (longest === undefined) {
    // TODO: We might need to handle this case better per message type
    return false
  }
  // Truncate the chosen object itself; looking it up again by content could select a different
  // message that merely has the same text.
  truncateChatMessage(longest, ratio)
  return true
}

/**
 * Folds the given search contexts into a single context message. The contexts are framed by a
 * "Development Environment Information" header and an end-of-section marker.
 *
 * @param contexts - Search results whose `getContext()` text is concatenated in order.
 * @returns A one-element tuple. When `contexts` is empty it holds a "No user context found" placeholder.
 */
export function collateUserContext(contexts: SearchContext[]): [ContextChatMessage] {
  if (contexts.length > 0) {
    const body = contexts.map((ctx) => new ContextChatMessage(ctx.getContext()))
    const header = new ContextChatMessage(
      "# Development Environment Information\n The following information is an overview of the current development environment. Use this information to help answer the question if it is relevant.\n",
    )
    const footer = new ContextChatMessage("*End of Development Environment Information*\n---")
    let combined = ""
    for (const part of [header, ...body, footer]) {
      combined += part.content
    }
    return [new ContextChatMessage(combined)]
  }
  return [new ContextChatMessage("No user context found")]
}

/**
 * Assembles `[sys, ...context, ...history, prompt]`, dropping the oldest history messages as needed
 * to make room for the user context.
 *
 * The context is left out, and the full history kept, in two cases:
 * - there is no context;
 * - the system prompt, context, latest history message, prompt and a 5% margin already reach
 *   `model.maxTokens`.
 * The assumption is that the user is most likely continuing their conversation, so history matters
 * more than context.
 *
 * NOTE(review): the latest history message is reserved up front *and* counted again when the history
 * is reduced, so the history budget is smaller than necessary by that message's size — confirm this
 * is intended.
 * NOTE(review): `sys`, the latest history message and `prompt` are passed to `estimateTokens` as
 * whole messages, whereas `estimateContextLength` passes `content`; confirm `estimateTokens` accepts
 * message objects.
 */
export function substituteUserContext(sys: ChatMessage, context: ContextChatMessage[], history: ChatMessage[], prompt: ChatMessage, model: AIModel): ChatMessage[] {
  if (!context) {
    return [sys, ...history, prompt]
  }

  const tokenizer = model.id as CustomTiktokenModel
  const reservedSys = estimateTokens(sys, tokenizer)
  const reservedContext = estimateContextLength(context, model)
  // An empty history has no latest message to reserve room for (previously `undefined` was estimated).
  const reservedLastHistory = history.length > 0 ? estimateTokens(history[history.length - 1], tokenizer) : 0
  const reservedPrompt = estimateTokens(prompt, tokenizer)

  const totalReserved = reservedSys + reservedContext + reservedLastHistory + reservedPrompt + Math.ceil(model.maxTokens * 0.05)
  // If adding in the user's context would exceed the max tokens, then just return the history and prompt
  if (totalReserved >= model.maxTokens) {
    return [sys, ...history, prompt]
  }

  const reducedHistory = reduceHistory(history, model.maxTokens, totalReserved, model)
  return [sys, ...context, ...reducedHistory, prompt]
}

/**
 * Drops messages from the front of `msgs` (oldest first) until the remainder's estimated size is at
 * most `maxTokens - reservedTokens`, or nothing is left.
 *
 * @returns `msgs` itself when it already fits, otherwise a new, shorter array (possibly empty), or
 * `[]` for a missing input. The input array is never modified.
 */
function reduceHistory(msgs: ChatMessage[], maxTokens: number, reservedTokens: number, model: AIModel): ChatMessage[] {
  if (!msgs) {
    return []
  }
  const budget = maxTokens - reservedTokens
  let tokens = estimateContextLength(msgs, model)
  if (tokens <= budget) {
    return msgs
  }
  // estimateContextLength is a fixed overhead plus a per-message cost, so dropping a message removes
  // exactly that message's cost; no need to re-estimate the remaining tail each time.
  const overhead = estimateContextLength([], model)
  let start = 0
  // Stop once everything is dropped: the fixed overhead alone can exceed a very small budget, which
  // would otherwise never terminate.
  while (start < msgs.length && tokens > budget) {
    tokens -= estimateContextLength([msgs[start]], model) - overhead
    start++
  }
  return msgs.slice(start)
}

/**
 * Estimates how many tokens `messages` occupy in a chat request for `model`.
 *
 * Each message costs a fixed 3 tokens plus the tokens of its `name` (when non-empty) and its
 * `content`. The list as a whole adds another 3 tokens, so an empty list estimates to 3.
 */
export function estimateContextLength(messages: Message[], model: AIModel): number {
  const tokenizer = model.id as CustomTiktokenModel
  let total = 3
  for (const message of messages) {
    total += 3
    if (message.name) {
      total += estimateTokens(message.name, tokenizer)
    }
    total += estimateTokens(message.content!, tokenizer)
  }
  return total
}
