import re
from logging import Logger
from typing import Any, Dict, List, Type

import numpy as np
import pandas as pd
from pulp import PULP_CBC_CMD, LpMinimize, LpProblem, LpVariable, lpSum
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.utils.acceleras_definitions import SensitivitySearch


class OptimizationSolution(BaseModel):
    """Outcome of a mixed-precision search plus summary statistics about its ops."""

    class Config:
        # pd.DataFrame is not a pydantic-native type and must be explicitly allowed.
        arbitrary_types_allowed = True

    solution: pd.DataFrame  # the chosen (layer, precision) rows
    max_ops: int  # total ops of the all-max-precision model
    min_ops: int  # total ops of the all-min-precision model
    limit_ops: int  # ops limit the search was run with
    solution_ops: int  # total ops of `solution`
    percent: float  # solution ops relative to the all-base-precision model
    ratio_a8w4: float  # share of solution ops spent in a8_w4_a16
    ratio_a8w8: float  # share of solution ops spent in a8_w8_a16
    ratio_a16w16: float  # share of solution ops spent in a16_w16_a16

    def to_dict(self):
        """Return the model as a plain dict whose ``solution`` is a list of row dicts."""
        as_dict = self.dict()
        as_dict.update(solution=self.solution.to_dict(orient="records"))
        return as_dict


class OptimizerBase:
    """Base class for all mixed-precision search steps.

    A search step receives a sensitivity table through :meth:`load_problem` and
    implements :meth:`run`. The table must have exactly the columns listed in
    ``columns`` (same names, same order); each row describes one
    (layer, precision) option.
    """

    columns = ["layer_name", "precision", "snr_val", "ops", "metric"]
    max_precision = "a16_w16_a16"
    mim_precision = "a8_w4_a16"  # lowest precision; attribute name kept as-is for compatibility
    base_precision = "a8_w8_a16"
    max_ops: int
    min_ops: int
    base_ops: int

    def __init__(self, logger: Logger, jobs: int = 1):
        self.logger = logger
        self.jobs = jobs

    @staticmethod
    def _ops_for(frame: pd.DataFrame, precision: str):
        """Sum the ``ops`` of the rows of ``frame`` whose ``precision`` equals ``precision``."""
        return frame.loc[frame["precision"] == precision, "ops"].sum()

    def load_problem(self, problem: Any):
        """Store a copy of ``problem`` and pre-compute the reference op totals.

        Args:
            problem: DataFrame whose columns are exactly ``self.columns``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if the columns differ from ``self.columns`` (names or order).
        """
        # Validate before touching any column. Comparing an Index against a list of a
        # different length raises an unhelpful "Lengths must match" error, and the
        # aggregations below would fail first on a missing column anyway.
        actual = list(problem.columns)
        if actual != self.columns:
            raise ValueError(f"Columns don't match the required columns: expected {self.columns}, got {actual}")
        self.problem = problem.copy()
        self.max_ops = self._ops_for(self.problem, self.max_precision)
        self.min_ops = self._ops_for(self.problem, self.mim_precision)
        self.base_ops = self._ops_for(self.problem, self.base_precision)
        return self

    def create_solution(self, solution: pd.DataFrame, limit: int):
        """Wrap a chosen assignment in an OptimizationSolution with ops statistics.

        ``percent`` is the solution's ops relative to the all-``base_precision``
        model. Each ``ratio_*`` field is one precision's share of the solution's ops.
        """
        solution_ops = solution["ops"].sum()
        return OptimizationSolution(
            solution=solution,
            max_ops=self.max_ops,
            min_ops=self.min_ops,
            limit_ops=limit,
            solution_ops=solution_ops,
            percent=solution_ops / self.base_ops,
            ratio_a8w4=self._ops_for(solution, self.mim_precision) / solution_ops,
            ratio_a8w8=self._ops_for(solution, self.base_precision) / solution_ops,
            ratio_a16w16=self._ops_for(solution, self.max_precision) / solution_ops,
        )

    def run(self, *args):
        """Execute the search step; overridden by subclasses (the base version returns None)."""


class SegmentsCreator(OptimizerBase):
    """Turn relative compression ratios into absolute ops limits.

    Each ratio is multiplied by ``base_ops`` (total ops of the all-``base_precision``
    model) and truncated to ``int``. Limits come back ordered from the largest
    ratio to the smallest.
    """

    problem: pd.DataFrame

    def run(self, segments: List[float]):
        """Return one integer ops limit per ratio in ``segments`` (input is not modified)."""
        limits = []
        for ratio in sorted(segments, reverse=True):
            limits.append(int(self.base_ops * ratio))
        return limits


class ParteoSolver(OptimizerBase):
    """Choose a mixed-precision assignment by walking an iterative Pareto curve.

    Starts from the all-``max_precision`` assignment. Candidate (layer, precision)
    rows are applied in ascending ``snr_val`` order, each replacing the current row
    of its layer, until the total ops fall below the limit.
    """

    def run(self, limit: int):
        """Return an OptimizationSolution for the assignment reached by the walk.

        The walk stops at the first step whose total ops are below ``limit``. If no
        step gets there, the assignment left after applying every candidate is used.
        """
        candidates = self.problem.sort_values(["snr_val"])
        current = self.problem.query("precision == @self.max_precision").copy()
        # NOTE(review): candidates include max-precision rows too, so a layer that was
        # already lowered can be switched back to max precision; confirm this is intended.
        for _, row in candidates.iterrows():
            current.loc[current["layer_name"] == row["layer_name"], :] = row.values
            total_ops = current["ops"].sum()
            if total_ops < limit:
                self.logger.info("Current solution has: %s < %s limit", total_ops, limit)
                break
        # Build the result from the modified assignment. Previously the untouched
        # all-max-precision copy was passed here, discarding every replacement above.
        return self.create_solution(current, limit)


class OptimizationProblem(BaseModel):
    """Index-based form of the mixed-precision selection problem.

    ``i`` indexes precision options (names in ``encodings``) and ``j`` indexes
    layers (names in ``layers``).
    """

    n_ij: dict = Field(..., description="Mapping (option i, layer j) -> noise value of choosing option i for layer j")
    B: int = Field(..., description="Total ops budget shared by all layers")
    c_ij: dict = Field(..., description="Mapping (option i, layer j) -> ops cost of choosing option i for layer j")
    encodings: Dict[int, str] = Field(description="Mapping from option index to precision/encoding name")
    layers: Dict[int, str] = Field(description="Mapping from layer index to layer name")


class LinearSolver(OptimizerBase):
    """
    Search the Mix Precision Solution by solving the integer linear programming problem:

    min(sum(noise))
    s.t
    ops(solution) <= limit
    exactly one precision option per layer

    Every (precision option i, layer j) pair present in the problem table becomes a
    binary decision variable ``x[(i, j)]``.
    """

    opt_problem: OptimizationProblem
    i_options: int
    j_layers: int
    prob = None
    x = None

    def run(self, limit: int):
        """Solve for the lowest-noise assignment whose total ops do not exceed ``limit``.

        ``self.problem`` is restored afterwards, so the same solver can be run
        again (e.g. for another limit) on the original SNR values.
        """
        original_problem = self.problem
        try:
            noise_problem = self.change_to_norm_noise(original_problem)
            self.transform_df_to_problem(noise_problem, limit)
            self.setup_problem()
            results = self.solve()
        finally:
            # change_to_norm_noise replaces self.problem with the normalized noise table,
            # and solve() reads from there. Without restoring it, a second run() would
            # normalize already-normalized values, which reverses the ranking of options.
            self.problem = original_problem
        return self.create_solution(results, limit)

    def change_to_norm_noise(self, sensitivity_list: pd.DataFrame) -> pd.DataFrame:
        """Convert ``snr_val`` into a non-negative relative noise score.

        Steps:
        - ``noise = 10 ** (-snr / 10)`` (the input is presumably SNR in dB).
        - Divide by the minimum noise.
        - Map through ``10 * ln(.)``.

        This yields ``(max_snr - snr) * ln(10)``: 0 for the best option, larger for
        noisier ones. The result is also stored in ``self.problem``.
        """
        values = sensitivity_list.copy()
        noise = 1 / (10 ** (values["snr_val"] / 10))
        noise = noise / noise.min()
        # Natural log rather than log10: only a constant factor, so the argmin is unchanged.
        values["snr_val"] = 10 * np.log(noise)
        self.problem = values
        return values

    def transform_df_to_problem(self, problem, layer_budget) -> OptimizationProblem:
        """Index ``problem`` into an OptimizationProblem and store it on ``self``.

        Options and layers are numbered in order of first appearance. If a
        (precision, layer) pair occurs more than once, its first row is used.

        Args:
            problem: table with ``precision``, ``layer_name``, ``snr_val`` and ``ops`` columns.
            layer_budget: total ops budget ``B``.

        Returns:
            The created OptimizationProblem (also stored as ``self.opt_problem``).
        """
        encodings = dict(enumerate(problem["precision"].unique()))
        layers = dict(enumerate(problem["layer_name"].unique()))
        option_index = {encoding: i for i, encoding in encodings.items()}
        layer_index = {layer: j for j, layer in layers.items()}

        # Single pass over the table instead of re-filtering the whole frame for
        # every (precision, layer) combination.
        first_rows = problem.drop_duplicates(subset=["precision", "layer_name"], keep="first")
        n_ij = {}
        c_ij = {}
        for encoding, layer, noise, ops in zip(
            first_rows["precision"],
            first_rows["layer_name"],
            first_rows["snr_val"],
            first_rows["ops"],
        ):
            key = (option_index[encoding], layer_index[layer])
            n_ij[key] = noise
            c_ij[key] = ops

        opt_problem = OptimizationProblem(n_ij=n_ij, B=layer_budget, c_ij=c_ij, encodings=encodings, layers=layers)
        self.opt_problem = opt_problem
        self.i_options = len(opt_problem.encodings)
        self.j_layers = len(opt_problem.layers)
        return opt_problem

    def setup_problem(self):
        """Build the binary ILP from ``self.opt_problem`` into ``self.prob`` and ``self.x``.

        Only (option, layer) pairs present in the table get a variable, so a layer
        that lacks some precision simply cannot choose it. Indexing the full
        i x j grid raised KeyError in that case.
        """
        keys = sorted(self.opt_problem.n_ij)  # (i, j) in the same order as nested i/j loops
        self.prob = LpProblem("Minimize Noise", LpMinimize)
        self.x = LpVariable.dicts("x", keys, 0, 1, cat="Binary")

        # Objective: total noise of the chosen options.
        self.prob += lpSum(self.opt_problem.n_ij[key] * self.x[key] for key in keys)

        # Exactly one option per layer.
        layer_vars = {}
        for i, j in keys:
            layer_vars.setdefault(j, []).append(self.x[(i, j)])
        for j in sorted(layer_vars):
            self.prob += lpSum(layer_vars[j]) == 1

        # Total ops budget shared by all layers.
        self.prob += lpSum(self.opt_problem.c_ij[key] * self.x[key] for key in keys) <= self.opt_problem.B

    def solve(self):
        """Run CBC and return the selected (layer, precision) rows.

        Returns:
            DataFrame with ``self.columns``. Its values are taken from
            ``self.problem`` at solve time, which during ``run`` is the
            normalized-noise table.
        """
        from pulp import LpStatus, LpStatusOptimal

        self.prob.solve(PULP_CBC_CMD(msg=0))
        if self.prob.status != LpStatusOptimal:
            # For example, the limit may be below the cheapest possible assignment;
            # the variable values are then not a valid optimum.
            self.logger.warning(
                "Mixed-precision ILP finished with status %s; the selected precisions may not respect the limit",
                LpStatus[self.prob.status],
            )

        # First row per (layer, precision), matching the former per-layer query lookup.
        first_rows = self.problem.drop_duplicates(subset=["layer_name", "precision"], keep="first")
        details = {
            (layer, precision): (snr_val, ops, metric)
            for layer, precision, snr_val, ops, metric in zip(
                first_rows["layer_name"],
                first_rows["precision"],
                first_rows["snr_val"],
                first_rows["ops"],
                first_rows["metric"],
            )
        }

        selected = []
        for (i, j), var in self.x.items():
            # Solvers report binaries as floats (possibly 0.9999999), so don't test == 1.
            if var.varValue is None or var.varValue < 0.5:
                continue
            layer_name = self.opt_problem.layers[j]
            precision = self.opt_problem.encodings[i]
            snr_val, ops, metric = details[(layer_name, precision)]
            selected.append(
                {
                    "layer_name": layer_name,
                    "precision": precision,
                    "snr_val": snr_val,
                    "ops": ops,
                    "metric": metric,
                },
            )

        # Explicit columns keep the schema even when nothing was selected.
        return pd.DataFrame(selected, columns=self.columns)


def mix_presion_solver_factory(searcher: SensitivitySearch) -> Type[OptimizerBase]:
    """Return the solver class (not an instance) for a sensitivity search strategy.

    Raises:
        KeyError: if ``searcher`` has no registered solver.
    """
    solvers = {SensitivitySearch.LINEAR: LinearSolver, SensitivitySearch.PARETO: ParteoSolver}
    try:
        return solvers[searcher]
    except KeyError:
        # Keep KeyError for existing callers, but say which values are supported.
        raise KeyError(f"Unsupported sensitivity search {searcher!r}; expected one of {list(solvers)}") from None
