import { assertNonNil } from "@deltagreen/utils"
import { Prisma, PrismaClient } from "@prisma/client"

/** Constructor options for `ReplicaManager`: the pool of replica clients it hands out. */
export type ReplicaManagerOptions = { replicas: PrismaClient[] }

/**
 * Owns a pool of read-replica Prisma clients.
 *
 * It can connect or disconnect the whole pool at once, and it picks one replica
 * per call to `pickReplica`.
 */
export class ReplicaManager {
  private readonly replicaClients: PrismaClient[]

  constructor(options: ReplicaManagerOptions) {
    this.replicaClients = options.replicas
  }

  /**
   * Calls `$connect()` on every replica concurrently.
   * Resolves once all of them have connected and rejects if any of them rejects.
   */
  async connectAll() {
    const pending = this.replicaClients.map((replica) => replica.$connect())
    await Promise.all(pending)
  }

  /**
   * Calls `$disconnect()` on every replica concurrently.
   * Resolves once all of them have disconnected and rejects if any of them rejects.
   */
  async disconnectAll() {
    const pending = this.replicaClients.map((replica) => replica.$disconnect())
    await Promise.all(pending)
  }

  /**
   * Returns a replica chosen uniformly at random from the pool.
   *
   * @throws via `assertNonNil` when the pool is empty (no replica at the drawn index).
   */
  pickReplica(): PrismaClient {
    const pool = this.replicaClients
    const index = Math.floor(Math.random() * pool.length)
    const chosen = pool[index]
    assertNonNil(chosen, "No replica clients available")
    return chosen
  }
}

/** Options accepted by `readReplicas`. */
export type ReplicasOptions = {
  /** Replica clients that read queries are routed to; `readReplicas` rejects an empty list. */
  replicas: PrismaClient[]
}

/**
 * Prisma operation names that `readReplicas` serves from a read replica when they run
 * outside a transaction. Every other operation goes to the primary.
 *
 * The list is frozen (and typed `readonly`) because it is shared routing policy.
 * Mutating it at runtime would silently change which queries hit the primary.
 */
const readOperations: readonly string[] = Object.freeze([
  "findFirst",
  "findFirstOrThrow",
  "findMany",
  "findUnique",
  "findUniqueOrThrow",
  "groupBy",
  "aggregate",
  "count",
  "findRaw",
  "aggregateRaw",
])

/**
 * Returns `true` when `context` is an interactive-transaction client rather than a regular client.
 *
 * The extension assumes that a regular client exposes a callable `$transaction` and a transaction
 * client does not. A transaction client is already bound to the primary connection.
 */
function isTransactionClient(context: object): boolean {
  return !("$transaction" in context && typeof context.$transaction === "function")
}

/**
 * Prisma client extension that sends read queries to read replicas and everything else to the primary.
 *
 * Query routing (`$allOperations`):
 * - any operation running inside a transaction goes to the primary via `query(args)`;
 * - operations listed in `readOperations` run directly on a replica picked by
 *   `ReplicaManager.pickReplica()`; `query(args)` is not called for them;
 * - all other operations go to the primary via `query(args)`.
 *
 * Client members added by the extension:
 * - `$primary()` returns the un-replicated base client. Inside a transaction it returns the current
 *   transaction client, which is already on the primary.
 * - `$replica()` returns a randomly picked replica client and throws inside a transaction.
 * - `$connect()` / `$disconnect()` act on the primary and all replicas concurrently.
 *
 * @param options - `replicas`: the replica clients to route reads to.
 * @throws Error "At least one replica must be specified" if `options.replicas` is empty. This is
 *   checked inside the extension factory, i.e. when the extension is applied to a client.
 */
export const readReplicas = (options: ReplicasOptions) =>
  Prisma.defineExtension((client) => {
    if (options.replicas.length === 0) {
      throw new Error(`At least one replica must be specified`)
    }

    const replicaManager = new ReplicaManager({ replicas: options.replicas })

    return client.$extends({
      client: {
        /** The primary client; inside a transaction, the (primary-bound) transaction client itself. */
        $primary<T extends object>(this: T): Omit<T, "$primary" | "$replica"> {
          const context = Prisma.getExtensionContext(this)
          if (isTransactionClient(context)) {
            return context
          }

          return client as unknown as Omit<T, "$primary" | "$replica">
        },

        /** A randomly chosen replica client. Replicas cannot take part in a primary transaction. */
        $replica<T extends object>(this: T): Omit<T, "$primary" | "$replica"> {
          const context = Prisma.getExtensionContext(this)
          if (isTransactionClient(context)) {
            throw new Error(`Cannot use $replica inside of a transaction`)
          }

          return replicaManager.pickReplica() as unknown as Omit<T, "$primary" | "$replica">
        },

        async $connect() {
          await Promise.all([client.$connect(), replicaManager.connectAll()])
        },

        async $disconnect() {
          await Promise.all([client.$disconnect(), replicaManager.disconnectAll()])
        },
      },

      query: {
        $allOperations({
          args,
          model,
          operation,
          query,
          // @ts-expect-error - __internalParams is not part of the public API
          __internalParams: { transaction },
        }) {
          // Everything inside a transaction must stay on the transaction's (primary) connection.
          if (transaction) {
            return query(args)
          }

          if (readOperations.includes(operation)) {
            const replica = replicaManager.pickReplica()
            if (model) {
              // @ts-expect-error - dynamic model/operation indexing on PrismaClient is not statically typed
              return replica[model][operation](args)
            }

            // @ts-expect-error - dynamic operation indexing on PrismaClient is not statically typed
            return replica[operation](args)
          }

          return query(args)
        },
      },
    })
  })
