import {
  GenerationType,
  InferenceParameters,
  InferenceStyleParameters,
} from "../../../../types";
import { FORM_FIELDS_2D, ReferenceImageType } from "../../constants";
import {
  formatReferenceImageForAPI,
  getImageObject,
  getPercentageFromWeight,
  getWeightFromPercentage,
  guidanceScaleFromPromptStrength,
  guidanceScaleToPromptStrength,
} from "../../utils/utils";

/**
 * Field names used by the upscale form.
 *
 * Declared `as const` so every value keeps its string-literal type; the keys of
 * `FormTypeUpscale` are computed from these values.
 */
export const FORM_FIELDS_UPSCALE = {
  STYLE: "style",
  PROMPT_LANGUAGE: "promptLanguage",
  PROMPT_TEXT: "promptText",
  NEGATIVE_PROMPT_TEXT: "negativePromptText",
  SEED: "seed",
  PROMPT_STRENGTH: "promptStrength",
  RESEMBLANCE: "resemblance",
  CREATIVITY: "creativity",
  TARGET_FILE: "targetFile",
  UPSCALE_FACTOR: "upscaleFactor",
  DENOISING_STEPS: "denoisingSteps",
} as const;

/** The constant names (e.g. `"STYLE"`) declared on `FORM_FIELDS_UPSCALE`. */
type UpscaleFieldKeys = keyof typeof FORM_FIELDS_UPSCALE;

/** Union of every field-name string (e.g. `"style" | "seed" | ...`) in `FORM_FIELDS_UPSCALE`. */
export type UpscaleFieldValuesType =
  (typeof FORM_FIELDS_UPSCALE)[UpscaleFieldKeys];

/**
 * State of the upscale form, keyed by the field names in `FORM_FIELDS_UPSCALE`.
 *
 * Style, target file, prompt language and seed start out as `null` in
 * `DEFAULT_UPSCALE_FORM_VALUES`, so they are typed as nullable here to match.
 * Resemblance, creativity and prompt strength hold form-side values that the
 * formatting helpers in this module convert before building API parameters.
 */
export interface FormTypeUpscale {
  [FORM_FIELDS_UPSCALE.STYLE]: InferenceStyleParameters | null;
  [FORM_FIELDS_UPSCALE.TARGET_FILE]: ReferenceImageType | null;
  [FORM_FIELDS_UPSCALE.UPSCALE_FACTOR]: number;
  [FORM_FIELDS_UPSCALE.PROMPT_LANGUAGE]: string | null;
  [FORM_FIELDS_UPSCALE.PROMPT_TEXT]?: string;
  [FORM_FIELDS_UPSCALE.NEGATIVE_PROMPT_TEXT]?: string;
  [FORM_FIELDS_UPSCALE.RESEMBLANCE]: number;
  [FORM_FIELDS_UPSCALE.CREATIVITY]: number;
  [FORM_FIELDS_UPSCALE.SEED]?: number | null;
  [FORM_FIELDS_UPSCALE.PROMPT_STRENGTH]: number;
  [FORM_FIELDS_UPSCALE.DENOISING_STEPS]: number;
}

/**
 * Initial values for the upscale form.
 *
 * Shared module-level object: use `getDefaultUpscaleFormValues()` when an
 * independent, mutable copy is needed.
 */
export const DEFAULT_UPSCALE_FORM_VALUES: FormTypeUpscale = {
  // Nothing selected yet.
  [FORM_FIELDS_UPSCALE.STYLE]: null,
  [FORM_FIELDS_UPSCALE.TARGET_FILE]: null,
  [FORM_FIELDS_UPSCALE.UPSCALE_FACTOR]: 2,
  [FORM_FIELDS_UPSCALE.PROMPT_LANGUAGE]: null,
  [FORM_FIELDS_UPSCALE.PROMPT_TEXT]: "",
  [FORM_FIELDS_UPSCALE.NEGATIVE_PROMPT_TEXT]: "",
  // Form-side values; converted with getWeightFromPercentage when formatted for the API.
  [FORM_FIELDS_UPSCALE.RESEMBLANCE]: 60,
  [FORM_FIELDS_UPSCALE.CREATIVITY]: 35,
  // No seed chosen; the API formatter substitutes -1.
  [FORM_FIELDS_UPSCALE.SEED]: null,
  // Converted with guidanceScaleFromPromptStrength when formatted for the API.
  [FORM_FIELDS_UPSCALE.PROMPT_STRENGTH]: 50,
  [FORM_FIELDS_UPSCALE.DENOISING_STEPS]: 10,
};

/**
 * Returns a fresh deep copy of `DEFAULT_UPSCALE_FORM_VALUES`, so callers can
 * mutate the result without touching the shared defaults.
 *
 * The copy is made with a JSON round-trip, which is sufficient because every
 * default is a string, number or `null`.
 */
export const getDefaultUpscaleFormValues = () => {
  const serializedDefaults = JSON.stringify(DEFAULT_UPSCALE_FORM_VALUES);
  return JSON.parse(serializedDefaults);
};

/**
 * Builds the `InferenceParameters` payload for an `UPSCALE` generation from the
 * upscale form state.
 *
 * `defaultParameters` is spread first, so every field set here overrides it.
 * Creativity, resemblance and style weight are converted with
 * `getWeightFromPercentage`. Resemblance, guidance scale and denoising steps are
 * only set when the converted creativity is greater than zero.
 *
 * @param inputForm - Current upscale form values.
 * @param defaultParameters - Base parameters to extend.
 * @returns Parameters for a single-image (`batchSize: 1`) upscale request.
 */
export const formatUpscaleFormValuesForAPI = (
  inputForm: FormTypeUpscale,
  defaultParameters: InferenceParameters,
): InferenceParameters => {
  const seed = inputForm[FORM_FIELDS_UPSCALE.SEED];
  const parameters = {
    ...defaultParameters,
    generationType: "UPSCALE" as GenerationType,
    batchSize: 1,
    upscaleRatio: inputForm[FORM_FIELDS_UPSCALE.UPSCALE_FACTOR],
    prompt: inputForm[FORM_FIELDS_UPSCALE.PROMPT_TEXT],
    negativePrompt: inputForm[FORM_FIELDS_UPSCALE.NEGATIVE_PROMPT_TEXT],
    promptLanguage: inputForm[FORM_FIELDS_UPSCALE.PROMPT_LANGUAGE],
    // Fall back to -1 when no usable seed is set (null/undefined/NaN). Unlike
    // `seed || -1`, this keeps an explicit seed of 0.
    seed: typeof seed === "number" && Number.isFinite(seed) ? seed : -1,
    creativity: getWeightFromPercentage(
      inputForm[FORM_FIELDS_UPSCALE.CREATIVITY],
    ),
    files: [],
    styles: [],
  };

  const targetFile = inputForm[FORM_FIELDS_UPSCALE.TARGET_FILE];
  if (targetFile) {
    parameters.files.push(formatReferenceImageForAPI(targetFile));
  }

  // Read the style through this form's own field map (not the 2D form's).
  const style = inputForm[FORM_FIELDS_UPSCALE.STYLE];
  if (style) {
    parameters.styles.push({
      id: style.id,
      weight: getWeightFromPercentage(style.weight),
    });
  }

  // The remaining diffusion controls are only sent when creativity is non-zero.
  if (parameters.creativity > 0) {
    parameters.resemblance = getWeightFromPercentage(
      inputForm[FORM_FIELDS_UPSCALE.RESEMBLANCE],
    );
    parameters.guidanceScale = guidanceScaleFromPromptStrength(
      inputForm[FORM_FIELDS_UPSCALE.PROMPT_STRENGTH],
    );
    parameters.numInferenceSteps =
      inputForm[FORM_FIELDS_UPSCALE.DENOISING_STEPS];
  }

  return parameters;
};

/**
 * Converts stored `InferenceParameters` back into (partial) upscale form state.
 *
 * This is the inverse of `formatUpscaleFormValuesForAPI`. Weights are converted
 * back to form-side values with `getPercentageFromWeight`, and guidance scale is
 * converted back with `guidanceScaleToPromptStrength`. A parameter that is
 * `undefined` is left out of the result, so the caller's existing value is kept.
 *
 * Only the first file and the first style are used. The first file's image is
 * loaded via `getImageObject` to read its dimensions, so this function rejects
 * if that load rejects.
 *
 * @param params - Parameters from a previous inference; may be null/undefined.
 * @returns The form fields that could be derived from `params`.
 */
export const formatInferenceParamsForFormTypeUpscale = async (
  params: InferenceParameters,
): Promise<Partial<FormTypeUpscale>> => {
  const outputForm: Partial<FormTypeUpscale> = {};

  if (params) {
    if (params.files?.length) {
      const file = params.files[0];
      const imageObject = await getImageObject(file.url);
      outputForm[FORM_FIELDS_UPSCALE.TARGET_FILE] = {
        src: file.url,
        width: imageObject.width,
        height: imageObject.height,
      };
    }
    if (params.styles?.length) {
      const style = params.styles[0];
      // The form holds the style weight as a form-side value, and
      // formatUpscaleFormValuesForAPI converts it with getWeightFromPercentage.
      // Convert it back here, as resemblance/creativity are below. Without this,
      // the weight would be converted twice on resubmit.
      outputForm[FORM_FIELDS_UPSCALE.STYLE] =
        typeof style.weight === "number"
          ? { ...style, weight: getPercentageFromWeight(style.weight) }
          : style;
    }
    if (typeof params.upscaleRatio !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.UPSCALE_FACTOR] = params.upscaleRatio;
    }
    if (typeof params.prompt !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.PROMPT_TEXT] = params.prompt;
    }
    if (typeof params.negativePrompt !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.NEGATIVE_PROMPT_TEXT] =
        params.negativePrompt;
    }
    if (typeof params.promptLanguage !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.PROMPT_LANGUAGE] = params.promptLanguage;
    }
    if (typeof params.resemblance !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.RESEMBLANCE] = getPercentageFromWeight(
        params.resemblance,
      );
    }
    if (typeof params.creativity !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.CREATIVITY] = getPercentageFromWeight(
        params.creativity,
      );
    }
    if (typeof params.guidanceScale !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.PROMPT_STRENGTH] =
        guidanceScaleToPromptStrength(params.guidanceScale);
    }
    if (typeof params.numInferenceSteps !== "undefined") {
      outputForm[FORM_FIELDS_UPSCALE.DENOISING_STEPS] =
        params.numInferenceSteps;
    }
  }
  return outputForm;
};
