import { Message, ChatResult, ChatResultStream } from "../../types"
import { CompletionResult, CompletionStreamListener } from "../../types/completion/completion.result"
import { AI, ChatFunction, CreateChatProps, CreateCompletionProps, CreateEmbeddingProps, JsonSchema } from "../ai"
import {
  ExtractionPrefs,
  castPartialStreamChatResult,
  combineChatResult,
  combineCompletionResult,
  extractChatResult,
  extractCompletionResult,
  extractStreamChatResult,
  extractStreamCompletionResult,
  handleStream,
} from "./openai.stream"
import { min, cloneDeep } from "lodash"
import { Configuration, CreateEmbeddingRequest, OpenAIApi } from "openai"
import { IncomingMessage } from "http"
import { CustomTiktokenModel, TokenizerModel, estimateTokens, hashStringToInteger, sanitizeJSON, sleep } from "@cosine/common"
import "dotenv/config"
import { EmbeddingResult } from "../../types"
import { AIModel } from "../../models"
import { AIRateLimitException } from "@cosine/ratelimit"
import { OpenAIChatModel, OpenAIEmbeddingModel } from "./openai.models"
import { CancellationToken } from "@cosine/cancellation"
import { chunkEmbeddings } from "../../utils/preprocess.utils"

/**
 * A list of text fragments, each tagged with a numeric index.
 * NOTE(review): presumably `index` points back at the originating input — confirm with callers.
 */
export type ChunkedEmbeddings = Array<{ text: string; index: number }>

/**
 * `AI` implementation backed by the OpenAI API through the `openai` SDK's `OpenAIApi` client.
 *
 * It keeps a pool of API keys and rotates to the next key on HTTP 429.
 * Transient network and server failures are retried a bounded number of times.
 * When a chat or completion result fails the caller's validators, the model is asked again.
 */
export class OpenAI implements AI {
  /** Key pool; `apiKeys[0]` is the key currently in use. */
  private apiKeys: string[] = []
  /** Behaviour on HTTP 429: "throw" raises `AIRateLimitException`, "wait" sleeps 60s and then retries. */
  public rateLimitPolicy: "wait" | "throw" = "throw"

  /**
   * @param apiKeys Either a comma-separated string of keys or an array of keys. When empty (`""` or `[]`), keys are
   *   read from the comma-separated `OPENAI_API_KEY` environment variable instead.
   * @throws Error("No API keys found") when no non-blank key is available.
   */
  constructor(apiKeys: string | string[] = []) {
    // Copy array input: rolloverAPIKey() rotates this.apiKeys in place and must not reorder the caller's array.
    const provided = typeof apiKeys === "string" ? apiKeys.split(",") : [...apiKeys]
    const candidates = apiKeys.length === 0 ? process.env.OPENAI_API_KEY?.split(",") ?? [] : provided
    // Trim so "key1, key2" works, and drop blanks such as OPENAI_API_KEY="" or a trailing comma.
    this.apiKeys = candidates.map((key) => key.trim()).filter((key) => key.length > 0)
    if (this.apiKeys.length === 0) {
      throw new Error("No API keys found")
    }
  }

  /** Returns the key at the front of the pool. */
  private getAPIKey(): string {
    if (this.apiKeys.length === 0) {
      throw new Error("No API keys found")
    }
    return this.apiKeys[0]
  }

  /** Moves the current key to the back of the pool so the next client uses another key (no-op for a single key). */
  private rolloverAPIKey(): void {
    if (this.apiKeys.length <= 1) {
      return
    }
    const key = this.apiKeys.shift()
    this.apiKeys.push(key!)
  }

  /** Builds a fresh client for the current key, using `OPENAI_ORGANIZATION_ID` as the organization when set. */
  private getClient(): OpenAIApi {
    const apiKey = this.getAPIKey()
    const organization = process.env.OPENAI_ORGANIZATION_ID
    const configuration = new Configuration({ apiKey, organization })
    return new OpenAIApi(configuration)
  }

  /**
   * Classifies a failed request and sleeps for an error-specific back-off so the caller can retry.
   *
   * @param error The rejected value from the SDK call (typically an axios error).
   * @param model Model id, used for the rate-limit exception.
   * @param props Original request props, logged when the error is not recognised.
   * @returns `true` (retry) on every path that does not throw.
   * @throws AIRateLimitException on HTTP 429 when `rateLimitPolicy` is "throw" (the key is rotated first).
   * @throws The original `error` when it is not recognised as transient, after logging response details.
   */
  private async handleError(error: any, model: string, props?: CreateEmbeddingProps | CreateChatProps): Promise<boolean> {
    const dnsLookupError = error.code === "ENOTFOUND"
    const addressNotAvailableError = error.code === "EADDRNOTAVAIL"
    const connectionResetError = error.code === "ECONNRESET"
    const pipeError = error.code === "EPIPE"
    const dnsLookupTimedOutError = error.code === "EAI_AGAIN"
    const connectionTimedOutError = error.code === "ETIMEDOUT"
    const connectionRefusedError = error.code === "ECONNREFUSED"
    const noRouteToHostError = error.code === "EHOSTUNREACH"
    const connectionAbortedError = error.code === "ECONNABORTED"
    const sslError = error.code === "EPROTO"
    const aborted = error.code === "ERR_REQUEST_ABORTED"
    const isRateLimited = error.status === 429 || error.response?.status === 429
    const isTimedOut = error.status === 408 || error.response?.status === 408 || connectionTimedOutError || dnsLookupTimedOutError
    const is404 = error.status === 404 || error.response?.status === 404
    const isServerError = error.status >= 500 || error.response?.status >= 500
    // const isBadRequest = error.status === 400 || error.response?.status === 400
    if (dnsLookupError || connectionResetError || addressNotAvailableError) {
      await this.onDNSError()
    } else if (isRateLimited) {
      this.rolloverAPIKey()
      if (this.rateLimitPolicy === "throw") {
        throw new AIRateLimitException(model)
      } else {
        await this.onRateLimit()
      }
    } else if (isTimedOut) {
      await this.onTimeout()
    } else if (is404) {
      await this.on404()
    } else if (aborted) {
      await this.onTimeout()
    } else if (isServerError) {
      await this.onServerError()
    } else if (pipeError || connectionRefusedError || noRouteToHostError || connectionAbortedError || sslError) {
      await this.onTimeout()
    } else {
      // Log specific details from the axios error. `response` is absent for errors that never reached the
      // server, so use optional chaining to avoid masking the real error with a TypeError.
      if (error.response?.status) {
        // eslint-disable-next-line no-console
        console.log("[Error] Status:", error.response.status)
      }
      if (error.response?.data) {
        // Drops already-seen objects so circular structures can be stringified
        const getCircularReplacer = () => {
          const seen = new WeakSet()
          return (_key: string, value: any) => {
            if (typeof value === "object" && value !== null) {
              if (seen.has(value)) {
                return
              }
              seen.add(value)
            }
            return value
          }
        }
        // eslint-disable-next-line no-console
        console.log("[Error] Response:", JSON.stringify(error.response.data, getCircularReplacer(), 2), "Props", JSON.stringify(props, getCircularReplacer(), 2))
      }

      // Log extra data about embeddings props
      if (props !== undefined && "input" in props) {
        this.logEmbeddingErrorDetails(props)
      }

      throw error
    }
    return true
  }

  /** Logs every embedding input and its estimated token count (using the text-embedding-3-small tokenizer). */
  private logEmbeddingErrorDetails(props: CreateEmbeddingProps) {
    // eslint-disable-next-line no-console
    console.log("[Embeddings Info] Input length: ", props.input.length)
    for (let index = 0; index < props.input.length; index++) {
      const obj = props.input[index]
      // eslint-disable-next-line no-console
      console.log(`[Embeddings Info][${index}] Input: ${obj}`)

      try {
        // eslint-disable-next-line no-console
        console.log(`[Embeddings Info][${index}] Input: ${estimateTokens(obj, OpenAIEmbeddingModel.EMBED_TEXT_EMBEDDING_3_SMALL.id as TokenizerModel)}`)
      } catch (err) {
        // eslint-disable-next-line no-console
        console.log(`[Embeddings Info][${index}] Failed to estimate tokens:`, JSON.stringify(err, null, 2))
      }
    }
  }

  /**
   * Sends a chat completion request and returns the extracted result.
   *
   * - Failed requests go through `handleError`. They are retried while `available_network_attempts` (the number of
   *   requests still allowed, 2 when unset) lasts, after which `undefined` is returned.
   * - When both `stream` and `streamCallback` are set, the streamed result is returned without validation.
   * - Otherwise the result is validated against `validator`, `functionValidator` and the declared functions.
   *   - An invalid result is retried with a temperature 0.3 higher (capped at 1) while `availableAttempts` remains.
   *   - On the final attempt (and when `response_format` is unset) the model switches: CHAT_GPT_4_0613 /
   *     CHAT_GPT_4_32K_0613 become CHAT_GPT_TURBO_16K_0613, and anything else becomes CHAT_GPT_4_0613.
   *   - With no attempts left, an invalid result is replaced by a fixed "unable to process" message.
   */
  public async createChat(props: CreateChatProps, cancelToken?: CancellationToken, streamCallback?: ChatResultStream): Promise<ChatResult | undefined> {
    const {
      availableAttempts,
      includeStop,
      completion_prefix,
      available_network_attempts,
      max_tokens,
      stream,
      logit_bias,
      frequency_penalty,
      presence_penalty,
      stop,
      temperature,
      top_p,
      functions,
      function_call,
      tools,
      tool_choice,
      shrink,
      model,
      response_format,
      seed,
      deterministic,
    } = props
    const messages: Message[] = shrink ? fitChatInTokens(props.messages, model, tools?.map((t) => t.function) ?? functions ?? null, max_tokens ?? null) : props.messages
    if (availableAttempts && availableAttempts <= 0) {
      return undefined
    }
    // 0 is the exhausted state, so it must not be tested by truthiness (that re-armed the counter forever).
    if (available_network_attempts != null && available_network_attempts <= 0) {
      return undefined
    }
    const messagesHash = deterministic ? hashStringToInteger(JSON.stringify(messages)) : undefined
    const temp = deterministic ? 0 : temperature ?? 0
    const api = this.getClient()
    const req = {
      model: model.id,
      messages: messages
        .filter(({ content }) => !!content) // Filter out empty messages to prevent 400s
        .map(({ content, role, name }) => ({
          content,
          role,
          name,
        })),
      max_tokens,
      logit_bias,
      frequency_penalty,
      presence_penalty,
      stop,
      temperature: temp,
      stream,
      top_p,
      functions,
      function_call,
      tools,
      tool_choice,
      response_format,
      logprobs: true,
      seed: deterministic ? messagesHash : seed,
    }

    const signal = this.getAbortSignal(cancelToken)
    const response = await api.createChatCompletion(req, { responseType: stream ? "stream" : undefined, signal: signal }).catch((err) => {
      return err
    })
    cancelToken?.throwIfCancelled()
    if (response instanceof Error) {
      const retry = await this.handleError(response, model.id, props)
      if (retry) {
        return this.createChat(
          {
            ...props,
            // Unset means a default budget of 2 requests: this one plus one retry.
            available_network_attempts: (available_network_attempts ?? 2) - 1,
          },
          cancelToken,
          streamCallback,
        )
      }
    }
    let extractionPrefs: ExtractionPrefs | undefined = undefined
    if (props.stop && includeStop) {
      extractionPrefs = { includeStop, stop: props.stop }
    }
    if (completion_prefix) {
      extractionPrefs = extractionPrefs ? { ...extractionPrefs, prefix: completion_prefix } : { includeStop: false, stop: [], prefix: completion_prefix }
    }

    let result: ChatResult | undefined
    if (stream && streamCallback) {
      return await handleStream(response.data as IncomingMessage, model.id, extractStreamChatResult, combineChatResult, streamCallback).then((result) =>
        castPartialStreamChatResult(result),
      )
    } else {
      result = extractChatResult(response.data, extractionPrefs)
    }

    // A function that can check the validity of a result
    const isValid = (request: CreateChatProps, result: ChatResult): boolean => {
      const validData: boolean = !!result.message.content || !!result.function_call || !!result.message.tool_calls
      const validContent = validateContent(request, result)
      const validFunctions = validateFunctions(request, result)
      return validContent && validFunctions && validData
    }

    // Sometimes the result can contain the function_call in the message.content field,
    // rather than where it should be. So lets check if the content is valid JSON
    const contentAsJson = sanitizeJSON(result?.message.content ?? "")
    if (result && contentAsJson && !result?.function_call) {
      if (isContentIsValidFunctionCall(result, contentAsJson, (clone) => isValid(props, clone))) {
        // Check if the content is a valid function call
        result = replaceFunctionCall(result, contentAsJson)
      } else if (isContentIsValidFunctionCall(result, contentAsJson.function_call, (clone) => isValid(props, clone))) {
        // Check if the content.function_call is a valid function call
        result = replaceFunctionCall(result, contentAsJson.function_call)
      }
      // If the parsed JSON is not a valid function call, then do nothing
    }
    if (result && availableAttempts && !isValid(props, result)) {
      // An unset temperature counts as 0 (as in the request above); `temperature! + 0.3` produced NaN here.
      const newTemp = min([(temperature ?? 0) + 0.3, 1])
      let retryModel = model
      let retryMsgs = props.messages
      // On our last attempt, switch the model that is being used
      if (availableAttempts === 1 && !props.response_format) {
        if (model.id === OpenAIChatModel.CHAT_GPT_4_0613.id || model.id === OpenAIChatModel.CHAT_GPT_4_32K_0613.id) {
          retryModel = OpenAIChatModel.CHAT_GPT_TURBO_16K_0613
        } else {
          retryModel = OpenAIChatModel.CHAT_GPT_4_0613
        }
        retryMsgs = fitChatInTokens(props.messages, retryModel, tools?.map((t) => t.function) ?? functions ?? null, max_tokens ?? null)
      }
      return await this.createChat(
        {
          ...props,
          availableAttempts: availableAttempts - 1,
          temperature: newTemp,
          deterministic: undefined,
          model: retryModel,
          messages: retryMsgs,
        },
        cancelToken,
        streamCallback,
      )
    } else if (result && !isValid(props, result)) {
      // If the request is still not valid, then return a basic repsonse
      result.message.content = "I am unable to process your last request. Try rephrasing."
      result.function_call = undefined
    }
    return result
  }

  /**
   * Sends a (legacy) text completion request and returns the extracted result.
   *
   * Network failures are retried like `createChat` (bounded by `available_network_attempts`, 2 when unset).
   * When both `stream` and `streamCallback` are set, the streamed result is returned directly.
   * Otherwise a result rejected by `validator` is retried with a temperature 0.3 higher (capped at 1) while
   * `availableAttempts` remains.
   */
  public async createCompletion(props: CreateCompletionProps, cancelToken?: CancellationToken, streamCallback?: CompletionStreamListener): Promise<CompletionResult | undefined> {
    const { stream, temperature } = props
    const { model, validator, availableAttempts, includeStop, completion_prefix, available_network_attempts, deterministic, seed, prompt, ...rest } = props
    if (availableAttempts && availableAttempts <= 0) {
      return undefined
    }
    // 0 is the exhausted state, so it must not be tested by truthiness.
    if (available_network_attempts != null && available_network_attempts <= 0) {
      return undefined
    }
    const messagesHash = deterministic ? hashStringToInteger(prompt) : undefined
    const temp = deterministic ? 0 : temperature ?? 0
    const api = this.getClient()
    const signal = this.getAbortSignal(cancelToken)
    const apiProps = { model, prompt, ...rest, temperature: temp, seed: deterministic ? messagesHash : seed }
    const response = await api.createCompletion(apiProps, { responseType: stream ? "stream" : undefined, signal }).catch(async (err) => {
      return err
    })

    cancelToken?.throwIfCancelled()
    if (response instanceof Error) {
      const retry = await this.handleError(response, model)
      if (retry) {
        return this.createCompletion(
          {
            ...props,
            // Unset means a default budget of 2 requests: this one plus one retry.
            available_network_attempts: (available_network_attempts ?? 2) - 1,
          },
          cancelToken,
          streamCallback,
        )
      }
    }
    let extractionPrefs: ExtractionPrefs | undefined = undefined
    if (props.stop && includeStop) {
      extractionPrefs = { includeStop, stop: props.stop }
    }
    if (completion_prefix) {
      extractionPrefs = extractionPrefs ? { ...extractionPrefs, prefix: completion_prefix } : { includeStop: false, stop: [], prefix: completion_prefix }
    }
    if (stream && streamCallback) {
      return await handleStream(response.data as IncomingMessage, model, extractStreamCompletionResult, combineCompletionResult, streamCallback)
    }
    const result = extractCompletionResult(response.data, extractionPrefs)
    if (result && availableAttempts && validator && !validator(result.text)) {
      // An unset temperature counts as 0, matching the request above.
      const newTemp = min([(temperature ?? 0) + 0.3, 1])
      return await this.createCompletion(
        {
          ...props,
          availableAttempts: availableAttempts - 1,
          temperature: newTemp,
        },
        cancelToken,
        streamCallback,
      )
    }
    return result
  }

  /**
   * Embeds `input`, which must be all strings or all `{ text, reference }` objects.
   *
   * Entries returned with a null embedding are re-requested once. For object input, each result carries the
   * `reference` of the input at the same position.
   *
   * @throws Error("Input must be all strings or all objects") for mixed input.
   * @returns The embeddings, or `undefined` once `available_network_attempts` is exhausted.
   */
  public async createEmbedding(props: CreateEmbeddingProps, cancelToken?: CancellationToken): Promise<EmbeddingResult | undefined> {
    const { input, model, available_network_attempts } = props
    const allStrings = (input as any[]).every((el) => typeof el === "string")
    const allObjects = (input as any[]).every((el) => typeof el === "object")
    if (!allStrings && !allObjects) {
      throw new Error("Input must be all strings or all objects")
    }
    // 0 is the exhausted state, so it must not be tested by truthiness.
    if (available_network_attempts != null && available_network_attempts <= 0) {
      return undefined
    }
    const api = this.getClient()
    const strings: string[] = input.map((el: string | { text: string; reference: number }) => (typeof el === "string" ? (el as string) : el.text))
    const embeddingProps: CreateEmbeddingRequest = {
      model: model,
      input: strings,
    }

    const signal = this.getAbortSignal(cancelToken)
    const response = await api.createEmbedding(embeddingProps, { signal }).catch((err) => {
      return err
    })

    cancelToken?.throwIfCancelled()
    if (response instanceof Error) {
      const retry = await this.handleError(response, model, props)
      if (retry) {
        return this.createEmbedding(
          {
            ...props,
            // Unset means a default budget of 2 requests: this one plus one retry.
            available_network_attempts: (available_network_attempts ?? 2) - 1,
          },
          cancelToken,
        )
      }
    }
    const data = response.data
    const allEmbeddings = data.data
    // Positions (in input order) whose entry, or whose vector, came back null.
    // Uses the array position rather than el.index because a null entry has no index to read.
    const nullPositions: number[] = []
    allEmbeddings.forEach((el: any, position: number) => {
      if (el == null || el.embedding == null) {
        nullPositions.push(position)
      }
    })
    if (nullPositions.length) {
      // Re-embed from `strings`, which holds the text for both string and { text } inputs
      const reembedded = await api.createEmbedding({
        model: model,
        input: nullPositions.map((position) => strings[position]),
      })
      const replacements = reembedded?.data?.data
      if (replacements && replacements.length === nullPositions.length) {
        for (const replacement of replacements) {
          // replacement.index is relative to the re-embed request; map it back to the original position
          const position = nullPositions[replacement.index]
          if (position !== undefined) {
            allEmbeddings[position] = { ...replacement, index: position }
          }
        }
      }
    }
    // Embeddings
    const embeddings = allEmbeddings.map((el, i) => {
      return {
        index: el.index,
        embedding: el.embedding,
        reference: allStrings ? undefined : (input as { text: string; reference: number }[])[i].reference,
      }
    })
    return {
      embeddings,
      model: data.model,
      usage: data.usage,
    }
  }

  /** Delegates to `chunkEmbeddings` to split `inputs` into request batches targeting `targetRequestByteSize`. */
  public async preprocessEmbeddings(
    inputs: string[],
    model: AIModel,
    targetRequestByteSize: number,
    cancelToken?: CancellationToken,
  ): Promise<{ processedInput: string; originalIndex: number }[][]> {
    return chunkEmbeddings(inputs, model, targetRequestByteSize, cancelToken)
  }

  /**
   * Returns an AbortSignal that is aborted when `cancelToken` is cancelled.
   * NOTE(review): the callback is never unregistered, so a long-lived token accumulates one callback per request.
   */
  private getAbortSignal(cancelToken?: CancellationToken): AbortSignal {
    const abortController = new AbortController()
    const { signal } = abortController
    const onCancel = () => abortController.abort()
    cancelToken?.addOnCancelledCallback(onCancel)
    return signal
  }

  /** Back-off after DNS / connection-reset / address errors: 7.5s. */
  private async onDNSError() {
    await sleep(7_500)
  }

  /** Back-off after timeouts, aborts and other transient socket errors: 1s. */
  private async onTimeout() {
    await sleep(1_000)
  }

  /** Back-off after HTTP 404: 0.5s. */
  private async on404() {
    await sleep(500)
  }

  /** Back-off after HTTP 5xx: 7.5s. */
  private async onServerError() {
    await sleep(7_500)
  }

  /** Back-off after HTTP 429 when `rateLimitPolicy` is "wait": 60s. */
  private async onRateLimit() {
    await sleep(60_000)
  }
}

/**
 * Trims a list of completion strings so they, the function definitions and the reserved completion tokens fit in
 * `model.maxTokens`. The first entry is always kept; the oldest entries after it are dropped first.
 */
export const fitCompletionsInTokens = (completions: string[], model: AIModel, functions: ChatFunction[] | null, completionTokens: number | null) => {
  const asText = (completion: string) => completion
  return reduceByChronology(completions, asText, model, functions, completionTokens)
}

/**
 * Trims a chat history so its message contents, the function definitions and the reserved completion tokens fit in
 * `model.maxTokens`. The first (system) message is kept; the oldest messages after it are dropped first.
 */
export const fitChatInTokens = (messages: Message[], model: AIModel, functions: ChatFunction[] | null, completionTokens: number | null) => {
  const contentOf = (message: Message) => message.content
  return reduceByChronology(messages, contentOf, model, functions, completionTokens)
}

/**
 * Drops the largest messages (by estimated tokens) until the messages, the function definitions (+65 token buffer)
 * and the reserved completion tokens fit in `model.maxTokens`.
 *
 * Among equally large messages the later one is dropped first. The caller's array is never mutated; it is returned
 * as-is when it already fits.
 *
 * @param completionTokens Tokens to reserve for the completion; defaults to 3% of `model.maxTokens` when null.
 * @throws Error("Cannot reduce message further") when even an empty list does not fit.
 */
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const reduceByLargest = <T>(messages: T[], extractor: (message: T) => string | undefined, model: AIModel, functions: ChatFunction[] | null, completionTokens: number | null): T[] => {
  const estimateModel = model.id as CustomTiktokenModel
  const funcTokens = estimateTokens(functions ?? "", estimateModel) + 65 // We underestimate the function tokens, so add a buffer
  const reservedTokens = completionTokens ?? Math.ceil((model.maxTokens * 3) / 100)

  // Tokenize each message once; removing a message only subtracts its own count
  const remaining = messages.map((message) => ({ message, tokens: estimateTokens(extractor(message), estimateModel) }))
  let total = remaining.reduce((acc, entry) => acc + entry.tokens, 0)
  const fits = () => total + funcTokens + reservedTokens <= model.maxTokens

  if (fits()) {
    return messages
  }

  while (!fits()) {
    if (remaining.length === 0) {
      throw new Error("Cannot reduce message further")
    }
    // `>=` picks the last of equally large messages, matching the previous stable-sort-then-take-last behaviour
    let largest = 0
    for (let i = 1; i < remaining.length; i++) {
      if (remaining[i].tokens >= remaining[largest].tokens) {
        largest = i
      }
    }
    total -= remaining[largest].tokens
    remaining.splice(largest, 1)
  }
  return remaining.map((entry) => entry.message)
}

/**
 * Drops the oldest messages after the first one (assumed to be the system prompt) until the messages, the function
 * definitions (+65 token buffer) and the reserved completion tokens fit in `model.maxTokens`.
 *
 * Each message is tokenized once, which avoids re-tokenizing the whole history on every dropped message.
 * The input array is returned unchanged when it already fits; otherwise a new array is returned.
 *
 * @param completionTokens Tokens to reserve for the completion; defaults to 3% of `model.maxTokens` when null.
 * @throws Error("Cannot reduce message further") when at most one message remains and it still does not fit.
 */
const reduceByChronology = <T>(messages: T[], extractor: (message: T) => string | undefined, model: AIModel, functions: ChatFunction[] | null, completionTokens: number | null): T[] => {
  const estimateModel = model.id as CustomTiktokenModel
  const funcTokens = estimateTokens(functions ?? "", estimateModel) + 65 // We underestimate the function tokens, so add a buffer
  const reservedTokens = completionTokens ?? Math.ceil((model.maxTokens * 3) / 100)

  const messageTokens = messages.map((message) => estimateTokens(extractor(message), estimateModel))
  let total = messageTokens.reduce((acc, tokens) => acc + tokens, 0)
  const fits = () => total + funcTokens + reservedTokens <= model.maxTokens

  // If the number of tokens is less than the max tokens, return the messages
  if (fits()) {
    return messages
  }

  // messages[1 .. firstKept) have been dropped; messages[0] (the system prompt) is always kept
  let firstKept = 1
  while (!fits()) {
    const keptCount = messages.length - firstKept + 1
    if (keptCount <= 1) {
      throw new Error("Cannot reduce message further")
    }
    total -= messageTokens[firstKept]
    firstKept++
  }
  return [messages[0], ...messages.slice(firstKept)]
}

/**
 * Checks whether `functionCall` (typically JSON parsed out of the message content) would make `result` valid.
 *
 * The candidate must carry a truthy `name`. `isValid` is then run on a copy of `result` whose function call is
 * replaced by the candidate; `result` itself is not modified.
 */
function isContentIsValidFunctionCall(result: ChatResult, functionCall: any, isValid: (clone: ChatResult) => boolean): boolean {
  const hasName = Boolean(functionCall?.name)
  if (!hasName) {
    return false
  }
  const candidate = replaceFunctionCall(result, functionCall)
  return isValid(candidate)
}

/**
 * Returns a deep copy of `result` whose `function_call` is `functionCall` and whose message content is cleared.
 *
 * A leading "functions." on the name is stripped. It is applied to a shallow copy of `functionCall`, so the caller's
 * object is not mutated, and a non-string name no longer throws.
 */
function replaceFunctionCall(result: ChatResult, functionCall: any): ChatResult {
  const prefix = "functions."
  const name = functionCall?.name
  // In some instances the function name is prepended by "functions." - remove it without touching the caller's object
  const normalized = typeof name === "string" && name.startsWith(prefix) ? { ...functionCall, name: name.slice(prefix.length) } : functionCall

  // Clone the result and replace the function call with the new one
  const clone = cloneDeep(result)
  clone.function_call = normalized
  clone.message.content = undefined
  return clone
}

/**
 * Applies the request's optional `validator` to the response's message content (an empty string when absent).
 * Responses are accepted when no validator is configured.
 */
const validateContent = (request: CreateChatProps, response: ChatResult): boolean => {
  const content = response.message.content ?? ""
  return request.validator ? request.validator(content) : true
}

/**
 * Checks that a chat result's function call is consistent with the request's `functions` / `function_call` settings.
 *
 * The caller's `functionValidator` runs first. The called function must then be declared, must match a forced
 * `function_call` name, and must supply every required argument (via `validateRequired`). Every argument must be a
 * declared property whose JSON Schema `type` matches its value.
 *
 * @returns `true` when the response is acceptable for the request.
 */
const validateFunctions = (request: CreateChatProps, response: ChatResult): boolean => {
  if (request.functionValidator && !request.functionValidator(response.function_call)) {
    return false
  }
  // If no functions provided but trying to call one, return false
  if (!request.functions && response.function_call) {
    return false
  }
  // If there are no functions, return true
  if (!request.functions) {
    return true
  }
  // If function specified as none, but gives function anyway, return false
  if (request.function_call === "none" && response.function_call) {
    return false
  }
  // If function specified as none, there must be no function, so return true
  if (request.function_call === "none") {
    return true
  }
  // If function specified as auto, but gives no function, return true
  if (request.function_call === "auto" && !response.function_call) {
    return true
  }
  // If function not specified, and no function given, return true
  if (!request.function_call && !response.function_call) {
    return true
  }
  // If no function given, return false
  if (!response.function_call) {
    return false
  }
  // If name not in specified functions list, return false
  const name = response.function_call.name
  const nameValid = request.functions.map((f) => f.name).includes(name)
  if (!nameValid) {
    return false
  }
  // If function specified and name specified, but name doesn't match, return false
  if (request.function_call && request.function_call !== "auto" && request.function_call.name !== response.function_call.name) {
    return false
  }
  const requestFunction = request.functions.find((f) => f.name === name)
  // If request function not found, return false
  if (!requestFunction) {
    return false
  }
  const responseArgs = response.function_call.arguments
  const parameters = requestFunction.parameters
  if (!validateRequired(parameters, responseArgs)) {
    return false
  }
  if (!parameters.properties) {
    return true
  }
  const properties = parameters.properties
  // Every argument must be a declared property...
  if (!Object.keys(responseArgs).every((key) => properties[key])) {
    return false
  }
  // ...whose JSON Schema type matches the value (handles "integer", "array", "null", unions and untyped schemas)
  return Object.entries(responseArgs).every(([key, value]) => matchesJsonType(properties[key].type, value))
}

/**
 * Recursively checks that `obj` has every property listed in `required`, following nested object properties and array
 * `items` (a single schema, or a list where any schema may match). Array items and non-object values are
 * also checked against the schema's `type`.
 *
 * Nested properties are only descended into when their value is a non-null object; leaf property types are checked by
 * the caller.
 */
export function validateRequired(schema: JsonSchema, obj: any): boolean {
  if (schema.type === "array" && schema.items) {
    if (!Array.isArray(obj)) {
      return false
    }
    // For every items in the object check that it matches a schema
    for (const item of obj) {
      // Schema.items can be an array of schemas
      if (Array.isArray(schema.items)) {
        let isValid = false

        // Check if the item matches any of the schemas
        for (const itemSchema of schema.items) {
          if (validateRequired(itemSchema, item)) {
            isValid = true
            break
          }
        }

        // If the item doesn't match any of the schemas, the object is invalid
        if (!isValid) {
          return false
        }
      } else {
        // Otherwise `items` defines just one schema
        if (!validateRequired(schema.items, item)) {
          return false
        }
      }
    }
  } else if (schema.type === "object" && typeof obj === "object" && obj !== null) {
    // Check that obj contains all required properties
    if (schema.required) {
      for (const prop of schema.required) {
        if (!(prop in obj)) {
          return false
        }
      }
    }

    // Recursively validate properties
    if (schema.properties) {
      for (const prop in schema.properties) {
        if (prop in obj && typeof obj[prop] === "object" && obj[prop] !== null) {
          if (!validateRequired(schema.properties[prop], obj[prop])) {
            return false
          }
        }
      }
    }
  } else if (!matchesJsonType(schema.type, obj)) {
    // The value does not satisfy the schema's type (a plain `typeof` check rejected "integer" and item-less arrays)
    return false
  }

  return true
}

/**
 * Returns whether `value` satisfies a JSON Schema `type` keyword.
 *
 * - No type: any value is accepted (JSON Schema imposes no type constraint, e.g. enum/anyOf schemas).
 * - A list of types: accepted when any member matches.
 * - "array": `Array.isArray`.
 * - "integer": whole numbers (`Number.isInteger`).
 * - "null": only `null`.
 * - Anything else compares against `typeof`. For "object" this deliberately keeps the previous lenient behaviour,
 *   so arrays and null pass.
 */
function matchesJsonType(type: unknown, value: unknown): boolean {
  if (type === undefined || type === null) {
    return true
  }
  if (Array.isArray(type)) {
    return type.some((member) => matchesJsonType(member, value))
  }
  switch (type) {
    case "array":
      return Array.isArray(value)
    case "integer":
      return Number.isInteger(value)
    case "null":
      return value === null
    default:
      return typeof value === type
  }
}
