import { RateLimitData, RateLimitException, RateLimitHandler, RateLimitMetric } from "@cosine/ratelimit"
import { Telemetry } from "@cosine/telemetry"
import { CancellationToken, CancelledException } from "@cosine/cancellation"
import fetch, { RequestInit } from "node-fetch"

/**
 * Default headers attached to every request made by a `Network` instance.
 * The value is object-spread into the outgoing request's header map, so it is
 * expected to be a plain object of header name to header value.
 */
export type Headers = Record<string, string> | object

/** Construction options for a `Network` client. */
export interface NetworkProps {
  /** Base URL; each request path is appended to it verbatim (no separator is inserted). */
  url: string
  /** Default headers merged into every request. They are applied last, so they win over per-request headers. */
  headers?: Headers
  /** Value sent as the `User-Agent` header (unless `headers` overrides it). */
  agent: string
  /** Value sent as the `X-App-Version` header (unless `headers` overrides it). */
  version: string
  /** Optional rate limit handler (one can also be installed later with `Network.setRateLimitHandler`). */
  rateLimitHandler?: RateLimitHandler
}

/**
 * Base class for HTTP clients that talk to a single API base URL.
 *
 * Subclasses supply the bearer token (`getToken`) and the re-authentication step
 * (`reauthenticate`). `fetch` adds the default headers, links the request to an optional
 * cancellation token, retries transient failures and decodes the response body.
 */
export abstract class Network {
  protected baseURL: string
  protected headers: Headers
  protected agent: string
  protected version: string
  protected rateLimitHandler?: RateLimitHandler

  /**
   * @param props base URL, default headers, agent/version strings and an optional rate limit handler
   */
  constructor(props: NetworkProps) {
    const { url, headers, agent, version, rateLimitHandler } = props
    this.baseURL = url
    this.headers = headers ?? {}
    this.agent = agent
    this.version = version
    // Honour a handler supplied at construction time; it can still be replaced via setRateLimitHandler.
    this.rateLimitHandler = rateLimitHandler
  }

  /**
   * Send a request to `baseURL + url` and decode the response.
   *
   * Retry policy (every retry re-enters this method):
   * - an error whose `code` is a key of `fetchErrors`: retried while `attempts < 3`, after `onFetchError`
   * - 401 on the first attempt: `reauthenticate()` is awaited, then one retry
   * - 429: `onRateLimit` either throws or waits, then the request is retried without consuming an attempt
   * - 504 on the first attempt: one immediate retry
   * - any other `Error` on the first attempt: one immediate retry; afterwards it is reported and rethrown
   *
   * @param url path appended to the base URL
   * @param request node-fetch request options; its headers are merged with the defaults
   * @param cancellationToken when cancelled, aborts the in-flight request
   * @param stream when true, the raw response body is returned instead of a decoded value
   * @param attempts number of retries already performed; callers normally leave the default
   * @returns parsed JSON for `application/json`, text for `text/plain`, the response body for
   *   `application/octet-stream` or when `stream` is set, otherwise `undefined`
   * @throws CancelledException when the request is aborted or the token is already cancelled
   * @throws RateLimitException for message/IP rate limits (see `onRateLimit`)
   * @throws Error for non-OK responses (message is the response body or status text) and for
   *   network errors that survive the retries
   */
  public async fetch<T>(url: string, request: RequestInit, cancellationToken?: CancellationToken, stream = false, attempts = 0): Promise<T> {
    const fullURL = new URL(`${this.baseURL}${url}`)
    const token = await this.getToken()
    if (token) {
      request = {
        ...request,
        headers: { ...request.headers, Authorization: `Bearer ${token}` },
      }
    }

    // Link network abort controller to cancellation token
    const abortController = new AbortController()
    const { signal } = abortController
    const onCancel = () => abortController.abort()
    cancellationToken?.addOnCancelledCallback(onCancel)

    request = {
      ...request,
      headers: {
        ...request.headers,
        "User-Agent": this.agent,
        "X-App-Version": this.version,
        // Spread last: the instance-level defaults override everything above.
        ...this.headers,
      },
      signal,
    }
    cancellationToken?.throwIfCancelled()
    // Rejections are turned into values so that errors and responses share the handling below.
    const response = await fetch(fullURL, request).catch((err) => err)
    // Only a successful stream keeps the link, so that cancelling still aborts the body while the
    // caller reads it. Failed stream attempts are retried or thrown, and must not leave a callback
    // bound to a dead AbortController on the token.
    if (!stream || !response.ok) {
      cancellationToken?.removeOnCancelledCallback(onCancel)
    }

    // Abort errors are caused by cancellation tokens, so convert the error
    // or rethrow if already a cancellation error (prevents it getting swallowed into a `new Error()`)
    if (response.name === "AbortError" || response instanceof CancelledException) {
      throw new CancelledException(cancellationToken?.uuid ?? null)
    }
    if (fetchErrors[response.code] && attempts < 3) {
      await onFetchError(response)
      return await this.fetch(url, request, cancellationToken, stream, attempts + 1)
    } else if (response.status === 401 && attempts === 0) {
      await this.reauthenticate()
      return await this.fetch(url, request, cancellationToken, stream, attempts + 1)
    } else if (response.status === 429) {
      await this.onRateLimit(response)
      return await this.fetch(url, request, cancellationToken, stream, attempts)
    } else if (response.status === 504 && attempts === 0) {
      return await this.fetch(url, request, cancellationToken, stream, attempts + 1)
    } else if (Object.values(RateLimitEndpoint).includes(url as RateLimitEndpoint)) {
      // If the response isn't specifically a 429, and we hit a known rate limited endpoint
      // then we can assume we are no longer rate limited
      this.rateLimitHandler?.handleRateLimit(false)
    }
    if (response instanceof Error && attempts === 0) {
      return await this.fetch(url, request, cancellationToken, stream, attempts + 1)
    } else if (response instanceof Error) {
      Telemetry.captureException(response)
      throw response
    }
    const contentType = response.headers?.get("content-type")
    if (!response.ok) {
      // Use the response body as the error message when it is JSON or plain text.
      const describeFailure = () => {
        if (contentType?.includes("application/json")) {
          return response.json().then((r) => JSON.stringify(r))
        }
        if (contentType?.includes("text/plain")) {
          return response.text()
        }
        // `||` rather than `??`: an empty status text would otherwise produce an empty message.
        return response.code ?? (response.statusText || "Unknown Error")
      }
      const message = await describeFailure()
      Telemetry.addBreadcrumb({ message: `Network Error: ${url}`, data: { url, attempts, error: message }, type: "error", level: "error", category: "network" })
      // Report and throw the same Error instance so telemetry receives a real exception.
      const error = new Error(message)
      Telemetry.captureException(error)
      throw error
    }
    if (contentType?.includes("application/json") && !stream) {
      return (await response.json()) as T
    }
    if (contentType?.includes("text/plain") && !stream) {
      return (await response.text()) as T
    }
    if (contentType?.includes("application/octet-stream") || stream) {
      return response.body as unknown as T
    }
    return undefined as T
  }

  /**
   * Handle rate limit errors appropriately.
   * In some cases we want to throw an error (Reached message limit)
   * Other cases we will just wait and try again
   *
   * The response body is read as JSON. Message and IP limits throw a `RateLimitException`
   * carrying the rate limit headers; any other limit notifies the handler and waits.
   *
   * @param response the response from the server
   * @throws RateLimitException when the error type is a message or IP rate limit
   */
  private async onRateLimit(response: any) {
    const data = await response.json()
    if (data.error?.type === RateLimitMetric.Message || data.error?.type === RateLimitMetric.Ip) {
      // NOTE(review): header values are strings (or null when absent) at runtime; the `as number`
      // casts do not convert them. Confirm what consumers of RateLimitData expect before changing.
      const rlData: RateLimitData = {
        max: response.headers.get("X-RateLimit-Limit") as number,
        remaining: response.headers.get("X-RateLimit-Remaining") as number,
        retryAfter: response.headers.get("Retry-After") as number,
        resetTimestamp: response.headers.get("X-RateLimit-Reset") as number,
      }
      throw new RateLimitException(data.error, rlData)
    }

    this.rateLimitHandler?.handleRateLimit(true)
    await this.waitForRateLimit()
  }

  /** Pause for a fixed 10 seconds before a rate limited request is retried. */
  private async waitForRateLimit() {
    await sleep(10_000)
  }

  /** Install or replace the rate limit handler; returns the handler that was passed in. */
  public setRateLimitHandler = (handler: RateLimitHandler) => (this.rateLimitHandler = handler)

  /** Resolve the bearer token for the next request, or `undefined` to send no `Authorization` header. */
  protected abstract getToken(): Promise<string | undefined>

  /** Invoked once after a 401 on the first attempt, before that request is retried. */
  protected abstract reauthenticate(): Promise<void>
}

/**
 * Resolve after a `setTimeout` of `ms` milliseconds.
 *
 * @param ms delay passed to `setTimeout`
 * @returns a promise that resolves (with no value) once the timer fires
 */
const sleep = (ms: number): Promise<void> => new Promise<void>((resolve) => setTimeout(resolve, ms))

/**
 * Request paths that are known to be rate limited. A response other than 429 from one of
 * these paths is taken as a sign that the rate limit no longer applies.
 */
const RateLimitEndpoint = {
  Chat: "/chat",
  Completions: "/completions",
  Embeddings: "/embeddings",
  EmbeddingsCustom: "/embeddings/custom",
} as const

/** Union of the rate limited request paths. */
type RateLimitEndpoint = (typeof RateLimitEndpoint)[keyof typeof RateLimitEndpoint]

/**
 * Error codes that are treated as transient network failures. A failed request whose
 * `code` is a key of this table is retried (while fewer than three attempts have been made)
 * after the delay chosen by `onFetchError`. Each key maps to itself so the table can be used
 * both for membership checks and as a source of named constants.
 */
const fetchErrors = {
  // DNS lookup failed
  ENOTFOUND: "ENOTFOUND",
  // Connection reset
  ECONNRESET: "ECONNRESET",
  // Address not available
  EADDRNOTAVAIL: "EADDRNOTAVAIL",
  // Broken pipe
  EPIPE: "EPIPE",
  // Connection refused
  ECONNREFUSED: "ECONNREFUSED",
  // No route to host
  EHOSTUNREACH: "EHOSTUNREACH",
  // Connection aborted
  ECONNABORTED: "ECONNABORTED",
  // Protocol error (handled as an SSL error by onFetchError)
  EPROTO: "EPROTO",
  // DNS lookup timed out
  EAI_AGAIN: "EAI_AGAIN",
  // Connection timed out
  ETIMEDOUT: "ETIMEDOUT",
}

/**
 * Back off before a failed request is retried, with the delay chosen from the error's `code`.
 *
 * - DNS lookup failure, connection reset, address not available: 7.5 seconds
 * - broken pipe, connection refused, no route to host, connection aborted, protocol/SSL error,
 *   DNS lookup timeout, connection timeout: 1 second
 * - any other code: returns without waiting
 *
 * @param error the failure produced by the request; only its `code` property is read
 */
const onFetchError = async (error: any) => {
  switch (error.code) {
    case fetchErrors.ENOTFOUND:
    case fetchErrors.ECONNRESET:
    case fetchErrors.EADDRNOTAVAIL:
      await sleep(7_500)
      break
    case fetchErrors.EPIPE:
    case fetchErrors.ECONNREFUSED:
    case fetchErrors.EHOSTUNREACH:
    case fetchErrors.ECONNABORTED:
    case fetchErrors.EPROTO:
    case fetchErrors.EAI_AGAIN:
    case fetchErrors.ETIMEDOUT:
      await sleep(1_000)
      break
    default:
      // Unrecognised code: no delay.
      break
  }
}
