"""
Copyright 2025, the CVXPY authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

SOC dimension reduction to 3D cones.

This module provides the SOCDim3 reduction which converts arbitrary-dimensional
Second-Order Cone (SOC) constraints to equivalent systems of 3D SOC constraints.
This enables solvers that only support 3D SOC to handle arbitrary dimensional
SOC constraints.

The decomposition uses a binary tree structure:
- Dimension 1: Convert to NonNeg constraint (||[]|| <= t becomes t >= 0)
- Dimension 2: Convert to NonNeg constraints (|x| <= t becomes t-x >= 0, t+x >= 0)
- Dimension 3: Pass through unchanged
- Dimension 4: Chain of two 3D cones (special case for efficiency)
- Dimension n > 4: Binary split into balanced tree of 3D cones

Example:
    A dimension 5 cone ||(x1, x2, x3, x4)|| <= t becomes::

        ||(x1, x2)|| <= s1  (3D cone)
        ||(x3, x4)|| <= s2  (3D cone)
        ||(s1, s2)|| <= t   (3D cone)

    where s1, s2 are auxiliary variables, each additionally constrained to be
    nonnegative (s1 >= 0, s2 >= 0).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

if TYPE_CHECKING:
    from cvxpy.problems.problem import Problem

import numpy as np

from cvxpy.atoms.affine.reshape import reshape
from cvxpy.atoms.affine.vstack import vstack
from cvxpy.constraints.nonpos import NonNeg
from cvxpy.constraints.second_order import SOC
from cvxpy.expressions import cvxtypes
from cvxpy.expressions.expression import Expression
from cvxpy.expressions.variable import Variable
from cvxpy.reductions.cone2cone.cone_tree import (
    LeafNode,
    SpecialNode,
    SplitNode,
    TreeNode,
    get_all_cone_ids,
)
from cvxpy.reductions.reduction import Reduction
from cvxpy.reductions.solution import Solution

# =============================================================================
# Inverse Data Dataclasses
# =============================================================================

@dataclass
class SOCTreeData:
    """Per-constraint record of how one original SOC constraint was decomposed.

    Stored as a value of ``SOCDim3InverseData.soc_trees``.

    Attributes
    ----------
    trees : List[TreeNode]
        One decomposition tree per individual cone, in cone order (a single
        tree for a plain SOC, several for an elementwise SOC).
    num_cones : int
        How many individual cones the original constraint contains.
    original_dim : int
        Dimension of each individual cone.
    axis : int
        The ``axis`` attribute of the original SOC constraint (default 0).
    """

    trees: List[TreeNode]
    num_cones: int
    original_dim: int
    axis: int = field(default=0)


@dataclass
class SOCDim3InverseData:
    """Bookkeeping produced by ``SOCDim3.apply`` for recovering original SOC duals.

    Attributes
    ----------
    soc_trees : Dict[int, SOCTreeData]
        Keyed by the ID of each original SOC constraint; each value describes
        how that constraint was decomposed. Empty by default.
    old_constraints : List
        The constraint list of the problem the reduction was applied to.
        Empty by default.
    """

    # Fresh containers per instance (never shared between instances).
    soc_trees: Dict[int, SOCTreeData] = field(default_factory=lambda: {})
    old_constraints: List = field(default_factory=lambda: [])


# =============================================================================
# Helper Functions
# =============================================================================

def _to_scalar_shape(expr: Expression) -> Expression:
    """Give ``expr`` the one-element vector shape ``(1,)``.

    The bounds passed to the ``SOC`` constructors in this module are routed
    through this helper so that each has shape ``(1,)``.
    """
    one_element = (1,)
    return reshape(expr, one_element, order='F')


def _get_flat_dual(dual_value: Optional[Any]) -> Optional[np.ndarray]:
    """Convert a dual value to flat array format.

    Dual values can come in two formats:
    - Flat array: [lambda, mu1, mu2, ...] (from solver)
    - Split list: [lambda_array, mu_array] (after save_dual_value)

    This function normalizes to flat array format.
    """
    if dual_value is None:
        return None

    # Check if it's a list with two elements (split format from save_dual_value)
    if isinstance(dual_value, list) and len(dual_value) == 2:
        t_val, x_val = dual_value
        if t_val is None or x_val is None:
            return None
        t_arr = np.atleast_1d(t_val)
        x_arr = np.atleast_1d(x_val).flatten(order='F')
        return np.concatenate([t_arr[:1], x_arr])

    # Otherwise assume it's already a flat array
    return np.atleast_1d(dual_value)


# =============================================================================
# Decomposition Functions
# =============================================================================

def _decompose_soc_single(
    t_expr: Expression,
    x_expr: Expression,
    soc3_out: List[SOC],
    nonneg_out: List[NonNeg]
) -> TreeNode:
    """Decompose a single ||x|| <= t constraint into tree of exactly 3D cones.

    Generated constraints are appended to ``soc3_out`` / ``nonneg_out`` in
    creation order, and new auxiliary ``Variable`` objects are created along
    the way. For a split, the left half is decomposed before the right half.
    The split's root cone is emitted only after both halves.

    Parameters
    ----------
    t_expr : Expression
        The scalar bound expression (should have size 1).
    x_expr : Expression
        The vector argument (1D expression). Its size n selects the case:
        n == 0, 1, 2 and 3 are handled directly; n >= 4 is split in two.
    soc3_out : List[SOC]
        Output list to append generated 3D SOC constraints to.
    nonneg_out : List[NonNeg]
        Output list to append generated NonNeg constraints to. This includes
        the ``s >= 0`` constraints on auxiliary variables. Their IDs are not
        recorded in the returned tree.

    Returns
    -------
    TreeNode
        Tree structure for dual variable reconstruction.

    Raises
    ------
    ValueError
        If ``t_expr`` or ``x_expr`` is None.
    """
    if t_expr is None or x_expr is None:
        raise ValueError("t_expr and x_expr cannot be None")

    n = x_expr.size  # Number of elements in x

    # Dimension 1: ||[]|| <= t  ->  t >= 0 (degenerate case)
    if n == 0:
        c = NonNeg(t_expr)
        nonneg_out.append(c)
        return SpecialNode(
            node_type='nonneg_dim1',
            cone_ids=(c.id,),
            var_indices=()
        )

    # Dimension 2: |x| <= t  ->  t - x >= 0, t + x >= 0 (no SOC needed)
    if n == 1:
        x_flat = x_expr.flatten(order='F')
        c1 = NonNeg(t_expr - x_flat)  # t - x >= 0
        c2 = NonNeg(t_expr + x_flat)  # t + x >= 0
        nonneg_out.extend([c1, c2])
        # The dual-reconstruction helpers rely on cone_ids[0] being the
        # (t - x) constraint and cone_ids[1] the (t + x) one.
        return SpecialNode(
            node_type='nonneg_dim2',
            cone_ids=(c1.id, c2.id),
            var_indices=(0,)
        )

    # Dimension 3: Already valid, pass through unchanged
    if n == 2:
        cone = SOC(_to_scalar_shape(t_expr), x_expr.flatten(order='F'))
        soc3_out.append(cone)
        return LeafNode(cone_id=cone.id, var_indices=(0, 1))

    # Dimension 4: Special chain structure to avoid unbalanced 1+2 split
    if n == 3:
        # Auxiliary bound for (x1, x2). The explicit s >= 0 is redundant with
        # cone1, which already forces s >= 0.
        s = Variable()
        nonneg_out.append(NonNeg(s))
        x_left = x_expr[:2]
        x_last = x_expr[2]

        # First cone: ||(x1, x2)|| <= s
        cone1 = SOC(_to_scalar_shape(s), x_left.flatten(order='F'))
        soc3_out.append(cone1)

        # Second cone: ||(s, x3)|| <= t
        root_x = vstack([s, x_last])
        cone2 = SOC(_to_scalar_shape(t_expr), root_x.flatten(order='F'))
        soc3_out.append(cone2)

        # The metadata keys are what the dual-reconstruction helpers look up.
        return SpecialNode(
            node_type='chain_dim4',
            cone_ids=(cone1.id, cone2.id),
            var_indices=(0, 1, 2),
            metadata={'leaf_cone_id': cone1.id, 'root_cone_id': cone2.id}
        )

    # n >= 4 (cone dimension >= 5): Standard binary split
    # The left half takes the extra element when n is odd. For n >= 4 both
    # halves have at least 2 elements, so the recursion never hits n < 2.
    mid = (n + 1) // 2
    x_left = x_expr[:mid]
    x_right = x_expr[mid:]

    # Create auxiliary variables
    s1 = Variable()
    s2 = Variable()
    nonneg_out.append(NonNeg(s1))
    nonneg_out.append(NonNeg(s2))

    # Recursively decompose each half
    left_tree = _decompose_soc_single(s1, x_left.flatten(order='F'), soc3_out, nonneg_out)
    right_tree = _decompose_soc_single(s2, x_right.flatten(order='F'), soc3_out, nonneg_out)

    # Root constraint: ||(s1, s2)|| <= t (exactly 3D)
    root_x = vstack([s1, s2])
    root_cone = SOC(_to_scalar_shape(t_expr), root_x.flatten(order='F'))
    soc3_out.append(root_cone)

    return SplitNode(
        cone_id=root_cone.id,
        left=left_tree,
        right=right_tree
    )


# =============================================================================
# Dual Reconstruction Functions
# =============================================================================

def _collect_x_duals_into_array(
    tree: TreeNode,
    dual_vars: Dict[int, Any],
    out: np.ndarray,
    offset: int
) -> Optional[int]:
    """Collect x-component duals from leaf cones into pre-allocated array.

    Walks ``tree`` left to right and writes the dual entries belonging to the
    original cone's x-components into ``out``, starting at ``offset``.

    Parameters
    ----------
    tree : TreeNode
        Decomposition tree produced by ``_decompose_soc_single``.
    dual_vars : Dict[int, Any]
        Solver duals keyed by constraint ID.
    out : np.ndarray
        Destination array; entries ``out[offset:returned_offset]`` are written.
    offset : int
        Index in ``out`` at which to start writing.

    Returns
    -------
    int or None
        The offset just past the last entry written. Returns None if a
        required dual is missing or malformed, or if the entries would not
        fit in ``out``. Returning None instead of skipping means callers
        never read entries of ``out`` that were left unwritten.
    """
    capacity = len(out)

    if isinstance(tree, SpecialNode):
        if tree.node_type == 'nonneg_dim1':
            # ||[]|| <= t has no x-components.
            return offset
        if tree.node_type == 'nonneg_dim2':
            if len(tree.cone_ids) != 2 or offset + 1 > capacity:
                return None
            alpha_raw = dual_vars.get(tree.cone_ids[0])  # dual of t - x >= 0
            beta_raw = dual_vars.get(tree.cone_ids[1])   # dual of t + x >= 0
            if alpha_raw is None or beta_raw is None:
                return None
            alpha = float(np.atleast_1d(alpha_raw).flat[0])
            beta = float(np.atleast_1d(beta_raw).flat[0])
            # alpha*(t - x) + beta*(t + x) has x-coefficient beta - alpha.
            out[offset] = beta - alpha
            return offset + 1
        if tree.node_type == 'chain_dim4':
            leaf_id = tree.metadata.get('leaf_cone_id')
            root_id = tree.metadata.get('root_cone_id')
            if leaf_id is None or root_id is None:
                return None

            leaf_dual = _get_flat_dual(dual_vars.get(leaf_id))
            root_dual = _get_flat_dual(dual_vars.get(root_id))
            if leaf_dual is None or root_dual is None:
                return None
            # Both cones are 3D; a shorter dual cannot supply the entries.
            if len(leaf_dual) < 3 or len(root_dual) < 3 or offset + 3 > capacity:
                return None

            # Leaf cone ||(x1, x2)|| <= s: entries 1 and 2 belong to x1, x2.
            out[offset] = leaf_dual[1]
            out[offset + 1] = leaf_dual[2]
            # Root cone ||(s, x3)|| <= t: entry 1 belongs to the auxiliary s,
            # entry 2 to x3.
            out[offset + 2] = root_dual[2]
            return offset + 3
        # Unrecognised special node: it can only be skipped safely if it owns
        # no x-components.
        return offset if not tree.var_indices else None

    if isinstance(tree, LeafNode):
        dual = _get_flat_dual(dual_vars.get(tree.cone_id))
        if dual is None or len(dual) < 2:
            return None

        x_dual = dual[1:]
        end = offset + len(x_dual)
        if end > capacity:
            return None
        out[offset:end] = x_dual
        return end

    if isinstance(tree, SplitNode):
        # The root cone's x-duals belong to auxiliary variables; only the
        # subtrees contribute original x-components.
        new_offset = _collect_x_duals_into_array(tree.left, dual_vars, out, offset)
        if new_offset is None:
            return None
        return _collect_x_duals_into_array(tree.right, dual_vars, out, new_offset)

    return None


def _get_root_t_dual(tree: TreeNode, dual_vars: Dict[int, Any]) -> Optional[float]:
    """Return lambda, the t-component of the dual, for the root of ``tree``.

    The value depends on the node kind:
    - NonNeg-based nodes: read from the NonNeg duals. For ``nonneg_dim2``
      it is the sum of both NonNeg duals.
    - Chains: the first entry of the chain's root cone dual.
    - Leaves and splits: the first entry of the node's own cone dual.

    Returns None when a needed dual is missing or the node is unrecognised.
    """
    def leading_entry(cone_id: Any) -> Optional[float]:
        # First entry of a cone's dual in flat form, or None if unavailable.
        flat = _get_flat_dual(dual_vars.get(cone_id))
        if flat is None or len(flat) < 1:
            return None
        return flat[0]

    def as_scalar(raw: Any) -> float:
        return float(np.atleast_1d(raw).flat[0])

    if isinstance(tree, SpecialNode):
        kind = tree.node_type
        if kind == 'nonneg_dim1':
            raw = dual_vars.get(tree.cone_ids[0]) if len(tree.cone_ids) >= 1 else None
            return None if raw is None else as_scalar(raw)
        if kind == 'nonneg_dim2':
            if len(tree.cone_ids) != 2:
                return None
            alpha_raw = dual_vars.get(tree.cone_ids[0])
            beta_raw = dual_vars.get(tree.cone_ids[1])
            if alpha_raw is None or beta_raw is None:
                return None
            return as_scalar(alpha_raw) + as_scalar(beta_raw)
        if kind == 'chain_dim4':
            root_id = tree.metadata.get('root_cone_id')
            return None if root_id is None else leading_entry(root_id)
        return None

    if isinstance(tree, (LeafNode, SplitNode)):
        return leading_entry(tree.cone_id)

    return None


def _get_original_dim(tree: TreeNode) -> int:
    """Return the dimension of the cone encoded by ``tree``.

    Known special nodes have fixed dimensions. Other special nodes count
    their ``var_indices`` plus one. A leaf is a 3D cone. A split joins two
    subtrees that each count their own bound, so one is subtracted.
    Unrecognised nodes yield 0.
    """
    if isinstance(tree, SpecialNode):
        fixed_dims = {'nonneg_dim1': 1, 'nonneg_dim2': 2, 'chain_dim4': 4}
        known = fixed_dims.get(tree.node_type)
        return known if known is not None else len(tree.var_indices) + 1

    if isinstance(tree, LeafNode):
        return 3

    if isinstance(tree, SplitNode):
        left_dim = _get_original_dim(tree.left)
        right_dim = _get_original_dim(tree.right)
        # Both children include their own bound; the merged cone has one.
        return left_dim + right_dim - 1

    return 0


def _reconstruct_soc_dual(
    tree: TreeNode,
    dual_vars: Dict[int, Any]
) -> Optional[np.ndarray]:
    """Reconstruct the original SOC dual from decomposed cone duals.

    Parameters
    ----------
    tree : TreeNode
        Decomposition tree of one original cone ||x|| <= t.
    dual_vars : Dict[int, Any]
        Solver duals keyed by constraint ID.

    Returns
    -------
    np.ndarray or None
        Flat array ``[lambda, mu_1, ..., mu_n]`` for the original cone.
        Returns None if any required dual is unavailable, or if the tree's
        duals do not fill exactly ``n`` x-entries.
    """
    t_dual = _get_root_t_dual(tree, dual_vars)
    if t_dual is None:
        return None

    x_size = _get_original_dim(tree) - 1
    if x_size < 0:
        # Defensive: _get_original_dim returns 0 for unrecognised trees.
        return None
    if x_size == 0:
        return np.array([t_dual])

    x_duals = np.empty(x_size)
    final_offset = _collect_x_duals_into_array(tree, dual_vars, x_duals, 0)

    # np.empty leaves entries uninitialised. Only accept the result when
    # every x-entry was written (this also rejects a None offset).
    if final_offset != x_size:
        return None

    return np.concatenate([[t_dual], x_duals])


def _get_all_tree_cone_ids(tree: TreeNode) -> Set[int]:
    """Get all constraint IDs from a tree, including SpecialNode cases.

    Recurses through split nodes itself so that SpecialNode subtrees nested
    below a split are collected too. One example is the ``chain_dim4`` left
    half of a dimension-6 cone. This avoids relying on ``get_all_cone_ids``
    to understand SpecialNode. Node types not handled here fall back to
    ``get_all_cone_ids``.
    """
    if isinstance(tree, SpecialNode):
        return set(tree.cone_ids)
    if isinstance(tree, LeafNode):
        return {tree.cone_id}
    if isinstance(tree, SplitNode):
        ids = {tree.cone_id}
        ids |= _get_all_tree_cone_ids(tree.left)
        ids |= _get_all_tree_cone_ids(tree.right)
        return ids
    return get_all_cone_ids(tree)


# =============================================================================
# Main Reduction Class
# =============================================================================

class SOCDim3(Reduction):
    """Convert n-dimensional SOC constraints to dimension-3 SOC constraints.

    ``apply`` replaces every ``SOC`` constraint with the 3D ``SOC`` and
    ``NonNeg`` constraints produced by ``_decompose_soc_single``, one
    decomposition per individual cone. All other constraints pass through
    unchanged. ``invert`` rebuilds the original SOC duals from the duals of
    the generated constraints.
    """

    def accepts(self, problem: Problem) -> bool:
        """Check if this reduction accepts the given problem.

        Always True: problems without SOC constraints are returned unchanged
        by ``apply``.
        """
        return True

    def apply(self, problem: Problem) -> Tuple[Problem, SOCDim3InverseData]:
        """Apply SOCDim3 decomposition to all SOC constraints.

        Each SOC constraint is replaced, at its position in the constraint
        list, by its generated 3D SOC constraints followed by its generated
        NonNeg constraints. If the problem has no SOC constraints it is
        returned unchanged.

        Returns
        -------
        Tuple[Problem, SOCDim3InverseData]
            The transformed problem and the data ``invert`` needs.

        Raises
        ------
        ValueError
            If the number of bounds in an SOC's ``t`` does not match the
            number of cones in its ``X``.
        """
        inverse_data = SOCDim3InverseData(
            soc_trees={},
            old_constraints=problem.constraints
        )

        if not any(isinstance(c, SOC) for c in problem.constraints):
            return problem, inverse_data

        new_constraints: List = []

        for con in problem.constraints:
            if not isinstance(con, SOC):
                new_constraints.append(con)
                continue

            t = con.args[0]
            X = con.args[1]
            axis = con.axis

            # Normalize so that X_reshaped holds one cone per column and
            # t_reshaped holds one bound per cone.
            if axis == 1:
                X = X.T

            if len(X.shape) <= 1:
                X_reshaped = reshape(X, (X.size, 1), order='F')
                t_reshaped = _to_scalar_shape(t)
            else:
                X_reshaped = X
                t_reshaped = t

            num_cones = t_reshaped.size

            if X_reshaped.shape[1] != num_cones and num_cones > 1:
                raise ValueError(
                    f"Dimension mismatch: t has {num_cones} elements but "
                    f"X has {X_reshaped.shape[1]} columns"
                )

            all_trees: List[TreeNode] = []
            soc3_out: List[SOC] = []
            nonneg_out: List[NonNeg] = []

            for i in range(num_cones):
                # Cone i is ||X_reshaped[:, i]|| <= t_reshaped[i].
                tree = _decompose_soc_single(
                    t_reshaped[i], X_reshaped[:, i], soc3_out, nonneg_out
                )
                all_trees.append(tree)

            new_constraints.extend(soc3_out)
            new_constraints.extend(nonneg_out)

            cone_sizes = con.cone_sizes()
            inverse_data.soc_trees[con.id] = SOCTreeData(
                trees=all_trees,
                num_cones=num_cones,
                original_dim=cone_sizes[0] if cone_sizes else 3,
                axis=axis
            )

        new_problem = cvxtypes.problem()(problem.objective, new_constraints)
        return new_problem, inverse_data

    def invert(
        self,
        solution: Solution,
        inverse_data: SOCDim3InverseData
    ) -> Solution:
        """Reconstruct solution with original SOC dual variables.

        Duals of constraints not generated by the decomposition are copied
        through. For every decomposed SOC, each cone's dual
        ``[lambda_i, mu_i]`` is rebuilt from its 3D cones. The per-cone
        vectors are then concatenated cone by cone into one flat array. If
        any cone's dual cannot be rebuilt, no dual is reported for that
        constraint.
        """
        pvars = solution.primal_vars.copy() if solution.primal_vars else {}

        if not solution.dual_vars:
            return Solution(solution.status, solution.opt_val, pvars, {}, solution.attr)

        # IDs of the constraints recorded in the decomposition trees.
        decomposed_cone_ids: Set[int] = set()
        for tree_data in inverse_data.soc_trees.values():
            for tree in tree_data.trees:
                decomposed_cone_ids |= _get_all_tree_cone_ids(tree)

        # Copy duals of pass-through constraints. Duals of the auxiliary
        # NonNeg constraints (s >= 0) also pass through here: their IDs are
        # not recorded in the trees, and they match no original constraint.
        dvars: Dict[int, Any] = {
            cid: dual
            for cid, dual in solution.dual_vars.items()
            if cid not in decomposed_cone_ids
        }

        for orig_id, tree_data in inverse_data.soc_trees.items():
            per_cone: List[np.ndarray] = []
            for tree in tree_data.trees:
                reconstructed = _reconstruct_soc_dual(tree, solution.dual_vars)
                if reconstructed is None:
                    break
                per_cone.append(reconstructed)

            # Report a dual only if every cone was rebuilt (and there is at
            # least one). Concatenating the per-cone vectors [t_i, x_i] gives
            # the same layout as flattening the (num_cones, 1 + n) array whose
            # rows are those vectors.
            if per_cone and len(per_cone) == len(tree_data.trees):
                dvars[orig_id] = np.concatenate(per_cone)

        return Solution(solution.status, solution.opt_val, pvars, dvars, solution.attr)
