import './config';

import { ExecutionContext, Global, Injectable, Module } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import {
  InjectThrottlerOptions,
  InjectThrottlerStorage,
  ThrottlerGuard,
  ThrottlerModule,
  type ThrottlerModuleOptions,
  ThrottlerOptions,
  ThrottlerOptionsFactory,
  ThrottlerStorageService,
} from '@nestjs/throttler';
import type { Request } from 'express';

import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request';
import type { ThrottlerType } from './config';
import { THROTTLER_PROTECTED, Throttlers } from './decorators';

/**
 * Storage backend used by the throttler to count hits per key.
 *
 * It subclasses `ThrottlerStorageService` without any changes. It is
 * provided and exported by `RateLimiterModule` and passed to the throttler
 * options in `CustomOptionsFactory`. Presumably the subclass exists so that
 * this module owns an injectable class for the storage — confirm if other
 * modules rely on injecting it directly.
 */
@Injectable()
export class ThrottlerStorage extends ThrottlerStorageService {}

/**
 * Builds the `@nestjs/throttler` module options from the runtime
 * `AFFiNE.throttler` config.
 *
 * Each `[name, config]` entry becomes one named throttler. Every throttler
 * shares the injected {@link ThrottlerStorage} instance.
 */
@Injectable()
class CustomOptionsFactory implements ThrottlerOptionsFactory {
  constructor(private readonly storage: ThrottlerStorage) {}

  createThrottlerOptions(): ThrottlerModuleOptions {
    const namedThrottlers = Object.entries(AFFiNE.throttler).map(
      // `name` comes first so a `name` field inside the config still wins.
      ([name, config]) => ({ name, ...config })
    );

    return {
      throttlers: namedThrottlers,
      storage: this.storage,
    };
  }
}

/**
 * Rate-limiting guard layered on top of `ThrottlerGuard`.
 *
 * - Logged-in users (`req.user` set) skip handlers that have no specific
 *   throttler attached (see {@link CloudThrottlerGuard.canActivate}).
 * - Other requests are counted against the throttler named by the
 *   `THROTTLER_PROTECTED` metadata, or `'default'` when none is named.
 * - A handler whose `limit`/`ttl` differ from its throttler's configured
 *   values gets its own counter, keyed by class and handler name.
 * - In development (`config.node.dev`) the limit is raised to
 *   `Number.MAX_SAFE_INTEGER`, which effectively disables throttling.
 */
@Injectable()
export class CloudThrottlerGuard extends ThrottlerGuard {
  constructor(
    @InjectThrottlerOptions() options: ThrottlerModuleOptions,
    @InjectThrottlerStorage() storageService: ThrottlerStorage,
    reflector: Reflector,
    private readonly config: Config
  ) {
    super(options, storageService, reflector);
  }

  /**
   * Resolves the request/response pair with the shared
   * `getRequestResponseFromContext` helper.
   *
   * The result is cast to `any` to satisfy the base class signature.
   */
  override getRequestResponse(context: ExecutionContext) {
    return getRequestResponseFromContext(context) as any;
  }

  /**
   * Returns the key that identifies the client a request is counted against.
   *
   * Preference order:
   * 1. the session id,
   * 2. the `CF-Connecting-IP` header,
   * 3. `req.ip`.
   *
   * The value is prefixed with `throttler:` so the keys are recognisable in
   * the store.
   *
   * The `CF-ray` header is deliberately not used. Cloudflare assigns a
   * distinct ray id to every request, so keying on it would give each
   * request a fresh bucket and silently disable rate limiting.
   *
   * NOTE(review): `CF-Connecting-IP` is only trustworthy if all traffic
   * reaches the server through Cloudflare; otherwise clients can set it
   * themselves — confirm the deployment.
   */
  override getTracker(req: Request): Promise<string> {
    const client =
      req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.ip;
    return Promise.resolve(`throttler:${client}`);
  }

  /**
   * Builds the storage key as `<tracker>;<throttler>`.
   *
   * Trackers marked with the `;custom` suffix (set in `handleRequest` for
   * handlers with their own limit/ttl) are also scoped to
   * `<ClassName>.<handlerName>`. This keeps such handlers from sharing a
   * counter with anything else.
   */
  override generateKey(
    context: ExecutionContext,
    tracker: string,
    throttler: string
  ) {
    if (tracker.endsWith(';custom')) {
      return `${tracker};${throttler}:${context.getClass().name}.${context.getHandler().name}`;
    }

    return `${tracker};${throttler}`;
  }

  /**
   * Counts the request against the selected throttler and rejects it once
   * the limit is exceeded.
   *
   * It sets the `-Limit`, `-Remaining` and `-Reset` headers under
   * `headerPrefix`. On rejection it sets `Retry-After` and calls
   * `throwThrottlingException`.
   *
   * @returns `true` when the request is allowed (or not subject to this
   *   throttler).
   */
  override async handleRequest(
    context: ExecutionContext,
    limit: number,
    ttl: number,
    throttlerOptions: ThrottlerOptions
  ) {
    // Unspecified handlers fall back to 'default', so unauthenticated visits
    // always hit the default throttler. Authenticated users already bypass
    // unprotected APIs in `canActivate`.
    const throttler = this.getSpecifiedThrottler(context) ?? 'default';

    // Only the selected throttler counts this request; others let it through.
    if (throttlerOptions.name !== throttler) {
      return true;
    }

    const { req, res } = this.getRequestResponse(context);

    const ignoreUserAgents =
      throttlerOptions.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents;
    const userAgent = req.headers['user-agent'];
    if (
      userAgent &&
      Array.isArray(ignoreUserAgents) &&
      ignoreUserAgents.some(pattern => pattern.test(userAgent))
    ) {
      return true;
    }

    let tracker = await this.getTracker(req);

    const isDev = this.config.node.dev;
    // Development gets an effectively unlimited budget.
    const effectiveLimit = isDev ? Number.MAX_SAFE_INTEGER : limit;
    if (
      !isDev &&
      (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl)
    ) {
      // APIs with a custom limit or ttl are counted standalone; the suffix
      // makes `generateKey` scope the key to this handler.
      tracker += ';custom';
    }

    // `throttler === throttlerOptions.name` here, thanks to the check above.
    const key = this.generateKey(context, tracker, throttler);
    const { timeToExpire, totalHits } = await this.storageService.increment(
      key,
      ttl
    );

    if (totalHits > effectiveLimit) {
      res.header('Retry-After', timeToExpire.toString());
      await this.throwThrottlingException(context, {
        limit: effectiveLimit,
        ttl,
        key,
        tracker,
        totalHits,
        timeToExpire,
      });
    }

    res.header(`${this.headerPrefix}-Limit`, effectiveLimit.toString());
    res.header(
      `${this.headerPrefix}-Remaining`,
      (effectiveLimit - totalHits).toString()
    );
    res.header(`${this.headerPrefix}-Reset`, timeToExpire.toString());
    return true;
  }

  /**
   * Lets logged-in users through handlers without a specific throttler.
   * Everything else is delegated to the base guard.
   */
  override async canActivate(context: ExecutionContext): Promise<boolean> {
    const { req } = this.getRequestResponse(context);

    const throttler = this.getSpecifiedThrottler(context);

    // If the user is logged in, bypass non-protected handlers.
    if (!throttler && req.user) {
      return true;
    }

    return super.canActivate(context);
  }

  /**
   * Reads the `THROTTLER_PROTECTED` metadata from the handler, falling back
   * to the class.
   *
   * `'authenticated'` is mapped to `undefined`. Such handlers are therefore
   * treated as unprotected: logged-in users bypass them and everyone else
   * hits the default throttler.
   */
  getSpecifiedThrottler(context: ExecutionContext): ThrottlerType | undefined {
    const throttler = this.reflector.getAllAndOverride<Throttlers | undefined>(
      THROTTLER_PROTECTED,
      [context.getHandler(), context.getClass()]
    );

    return throttler === 'authenticated' ? undefined : throttler;
  }
}

/** Providers that the rate-limiter module both registers and exports. */
const RATE_LIMITER_PROVIDERS = [ThrottlerStorage, CloudThrottlerGuard];

/**
 * Global module that configures `@nestjs/throttler` through
 * `CustomOptionsFactory`.
 *
 * Because of `@Global()`, `ThrottlerStorage` and `CloudThrottlerGuard` are
 * available to every module without being imported explicitly.
 */
@Global()
@Module({
  imports: [ThrottlerModule.forRootAsync({ useClass: CustomOptionsFactory })],
  providers: [...RATE_LIMITER_PROVIDERS],
  exports: [...RATE_LIMITER_PROVIDERS],
})
export class RateLimiterModule {}

export * from './decorators';