r"""Finders
===========
"""

import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any

from jinja2 import Template
from jsonschema.protocols import Validator
from jsonschema.validators import validator_for
from lsprotocol.types import (
    Diagnostic,
    DiagnosticSeverity,
    Location,
    Position,
    Range,
    TextEdit,
)
from tree_sitter import Node, Query, QueryCursor, Tree

from . import UNI, Finder
from .schema import Trie


@dataclass
class ErrorFinder(Finder):
    r"""Report syntax errors found by tree-sitter.

    A node is reported when it is flagged with ``has_error`` while none of
    its direct children are, so an error is reported at the innermost flagged
    node rather than at each of its flagged ancestors.
    """

    message: str = "{{uni.text}}: error"
    severity: DiagnosticSeverity = DiagnosticSeverity.Error

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` is the innermost node carrying an error.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        node = uni.node
        if not node.has_error:
            return False
        # a flagged child would be reported on its own; skip this ancestor
        return all(not child.has_error for child in node.children)


@dataclass
class MissingFinder(Finder):
    r"""Report nodes that tree-sitter inserted because they were missing.

    A node is reported when it is flagged with ``is_missing`` while none of
    its direct children are.
    """

    message: str = "{{uni.text}}: missing"
    severity: DiagnosticSeverity = DiagnosticSeverity.Error

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` is the innermost missing node.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        node = uni.node
        if not node.is_missing:
            return False
        # defer to a missing child instead of reporting this node too
        return all(not child.is_missing for child in node.children)


@dataclass
class NotFileFinder(Finder):
    r"""Report paths that are neither a regular file nor a directory."""

    message: str = "{{uni.text}}: no such file or directory"
    severity: DiagnosticSeverity = DiagnosticSeverity.Error

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether ``uni.path`` is absent from the file system.

        Only regular files and directories count as present; any other path,
        including one that does not exist at all, is reported.

        :param uni: candidate node wrapper exposing ``path``
        :type uni: UNI
        :rtype: bool
        """
        target = uni.path
        if os.path.isfile(target):
            return False
        return not os.path.isdir(target)


@dataclass
class RepeatedFinder(Finder):
    r"""Find nodes that repeat an earlier node.

    Nodes accepted by :meth:`filter` are checked, in visiting order, against
    the nodes collected in :attr:`repeated_unis`. The first one for which
    :meth:`compare` holds is paired with the current node in
    :attr:`uni_pairs`, and the current node is reported. Otherwise the current
    node is appended to :attr:`repeated_unis`. By default, two nodes match
    when their source text is identical. Subclasses override :meth:`filter`
    and :meth:`compare` to change that.
    """

    message: str = "{{uni.text}}: is repeated on {{_uni}}"
    severity: DiagnosticSeverity = DiagnosticSeverity.Warning
    # nodes seen so far that did not match any earlier node
    repeated_unis: list[UNI] = field(default_factory=list)
    # ``(later, earlier)`` pairs: a reported node and the node it matched
    uni_pairs: list[tuple[UNI, UNI]] = field(default_factory=list)

    def reset(self) -> None:
        r"""Reset all per-search state, including collected nodes and pairs.

        :rtype: None
        """
        self.level = 0
        self.unis = []
        self.repeated_unis = []
        self.uni_pairs = []

    def filter(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` takes part in the comparison.

        Every node does by default.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        return True

    def compare(self, uni: UNI, _uni: UNI) -> bool:
        r"""Tell whether ``uni`` repeats the earlier node ``_uni``.

        :param uni: node currently being visited
        :type uni: UNI
        :param _uni: node seen earlier
        :type _uni: UNI
        :rtype: bool
        """
        return uni.node.text == _uni.node.text

    def __call__(self, uni: UNI) -> bool:
        r"""Record ``uni`` and tell whether it repeats an earlier node.

        :param uni: node currently being visited
        :type uni: UNI
        :rtype: bool
        """
        # only an explicit ``False`` from :meth:`filter` excludes a node
        if self.filter(uni) is False:
            return False
        for earlier in self.repeated_unis:
            if self.compare(uni, earlier):
                self.uni_pairs.append((uni, earlier))
                return True
        self.repeated_unis.append(uni)
        return False

    def get_definitions(self, uni: UNI) -> list[Location]:
        r"""Get the location of the earlier node that ``uni`` repeats.

        :param uni: a reported (later) node
        :type uni: UNI
        :return: a one-element list, or ``[]`` if ``uni`` was not reported
        :rtype: list[Location]
        """
        for later, earlier in self.uni_pairs:
            if uni == later:
                return [earlier.location]
        return []

    def get_references(self, uni: UNI) -> list[Location]:
        r"""Get the locations of all later nodes that repeat ``uni``.

        :param uni: an earlier node
        :type uni: UNI
        :rtype: list[Location]
        """
        return [
            later.location
            for later, earlier in self.uni_pairs
            if uni == earlier
        ]

    def get_text_edits(self, uri: str, tree: Tree) -> list[TextEdit]:
        r"""Get text edits swapping the first repeated pair.

        Only the first pair is swapped (two edits), to avoid
        ``Overlapping edit`` errors.

        :param uri: document URI
        :type uri: str
        :param tree: parse tree of the document
        :type tree: Tree
        :return: two edits, or ``[]`` if nothing is repeated
        :rtype: list[TextEdit]
        """
        self.find_all(uri, tree)
        if not self.uni_pairs:
            return []
        later, earlier = self.uni_pairs[0]
        return [
            later.get_text_edit(earlier.text),
            earlier.get_text_edit(later.text),
        ]

    def uni2diagnostic(self, uni: UNI) -> Diagnostic:
        r"""Build the diagnostic for ``uni``.

        The earlier matching node is exposed to the message template as
        ``_uni``. It is omitted when ``uni`` is not in any pair.

        :param uni: a reported node
        :type uni: UNI
        :rtype: Diagnostic
        """
        for later, earlier in self.uni_pairs:
            if uni == later:
                return uni.get_diagnostic(
                    self.message, self.severity, _uni=earlier
                )
        return uni.get_diagnostic(self.message, self.severity)


@dataclass
class UnsortedFinder(RepeatedFinder):
    r"""Report nodes whose text sorts before that of an earlier node."""

    message: str = "{{uni.text}}: is unsorted due to {{_uni}}"
    severity: DiagnosticSeverity = DiagnosticSeverity.Warning

    def compare(self, uni: UNI, _uni: UNI) -> bool:
        r"""Tell whether ``uni`` is out of order relative to ``_uni``.

        :param uni: node currently being visited
        :type uni: UNI
        :param _uni: node seen earlier
        :type _uni: UNI
        :rtype: bool
        """
        current = uni.node.text
        previous = _uni.node.text
        if not current or not previous:
            # a pair where either side has no text is always treated as unsorted
            return True
        return current < previous


@dataclass(init=False)
class UnFixedOrderFinder(RepeatedFinder):
    r"""Report nodes that break a caller-supplied ordering.

    Only nodes whose text appears in ``order`` are considered. A node is
    reported when it ranks before an earlier, not-yet-reported node.
    """

    def __init__(
        self,
        order: list[Any],
        message: str = "{{uni.text}}: is unsorted due to {{_uni}}",
        severity: DiagnosticSeverity = DiagnosticSeverity.Warning,
    ) -> None:
        r"""Init.

        :param order: node texts in their expected order
        :type order: list[Any]
        :param message: Jinja template for the diagnostic message
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        super().__init__(message, severity)
        self.order = order

    def filter(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` has a rank in ``order``.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        return uni.text in self.order

    def compare(self, uni: UNI, _uni: UNI) -> bool:
        r"""Tell whether ``uni`` ranks strictly before ``_uni``.

        :param uni: node currently being visited
        :type uni: UNI
        :param _uni: node seen earlier
        :type _uni: UNI
        :rtype: bool
        """
        rank = self.order.index
        return rank(uni.text) < rank(_uni.text)


@dataclass(init=False)
class TypeFinder(Finder):
    r"""Find nodes of one tree-sitter node type."""

    def __init__(
        self,
        type: str,
        message: str = "",
        severity: DiagnosticSeverity = DiagnosticSeverity.Information,
    ) -> None:
        r"""Init.

        :param type: tree-sitter node type to match (shadows the builtin
            name; kept for keyword callers)
        :type type: str
        :param message: Jinja template for the diagnostic message
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        super().__init__(message, severity)
        self.type = type

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether the node wrapped by ``uni`` has the wanted type.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        return uni.node.type == self.type


@dataclass(init=False)
class PositionFinder(Finder):
    r"""Find the leaf node that contains a position.

    ``left_equal`` / ``right_equal`` decide whether a position that sits
    exactly on the node's start / end point counts as inside it.
    """

    def __init__(
        self,
        position: Position,
        left_equal: bool = True,
        right_equal: bool = False,
        message: str = "",
        severity: DiagnosticSeverity = DiagnosticSeverity.Information,
    ) -> None:
        r"""Init.

        :param position: position to look up
        :type position: Position
        :param left_equal: whether the start point itself is inside
        :type left_equal: bool
        :param right_equal: whether the end point itself is inside
        :type right_equal: bool
        :param message: Jinja template for the diagnostic message
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        super().__init__(message, severity)
        self.position = position
        self.left_equal = left_equal
        self.right_equal = right_equal

    @staticmethod
    def belong(
        position: Position,
        node: Node,
        left_equal: bool = True,
        right_equal: bool = False,
    ) -> bool:
        r"""Tell whether ``position`` lies within ``node``'s span.

        :param position: position to test
        :type position: Position
        :param node: tree-sitter node
        :type node: Node
        :param left_equal: whether the start point itself is inside
        :type left_equal: bool
        :param right_equal: whether the end point itself is inside
        :type right_equal: bool
        :rtype: bool
        """
        # NOTE(review): tree-sitter columns are byte offsets while LSP
        # characters default to UTF-16 units; they only agree on ASCII
        # lines -- confirm the negotiated position encoding.
        start = Position(*node.start_point)
        end = Position(*node.end_point)
        after_start = start <= position if left_equal else start < position
        before_end = position <= end if right_equal else position < end
        return after_start and before_end

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` is a leaf containing the wanted position.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        leaf = uni.node
        if leaf.child_count:
            return False
        return self.belong(
            self.position, leaf, self.left_equal, self.right_equal
        )


@dataclass(init=False)
class RangeFinder(Finder):
    r"""Find nodes whose span is exactly a given range."""

    def __init__(
        self,
        range: Range,
        message: str = "",
        severity: DiagnosticSeverity = DiagnosticSeverity.Information,
    ) -> None:
        r"""Init.

        :param range: range the node must span exactly (shadows the builtin
            name; kept for keyword callers)
        :type range: Range
        :param message: Jinja template for the diagnostic message
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        super().__init__(message, severity)
        self.range = range

    @staticmethod
    def equal(_range: Range, node: Node) -> bool:
        r"""Tell whether ``node`` starts and ends exactly at ``_range``.

        :param _range: expected range
        :type _range: Range
        :param node: tree-sitter node
        :type node: Node
        :rtype: bool
        """
        if _range.start == Position(*node.start_point):
            return _range.end == Position(*node.end_point)
        return False

    def __call__(self, uni: UNI) -> bool:
        r"""Tell whether ``uni`` spans exactly the wanted range.

        :param uni: candidate node wrapper
        :type uni: UNI
        :rtype: bool
        """
        return self.equal(self.range, uni.node)


@dataclass(init=False)
class RequiresFinder(Finder):
    r"""Report required items that never show up in the document.

    Each visited node is offered to :meth:`filter` together with every
    outstanding requirement, and requirements it satisfies are dropped.
    Whatever is left after a full search is reported by
    :meth:`get_diagnostics`. The base :meth:`filter` always returns
    ``False``, so nothing is satisfied unless a subclass overrides it.
    """

    def __init__(
        self,
        requires: set[Any],
        message: str = "{{require}}: required",
        severity: DiagnosticSeverity = DiagnosticSeverity.Error,
    ) -> None:
        r"""Init.

        :param requires: items that must be found
        :type requires: set[Any]
        :param message: Jinja template; ``require`` is the missing item
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        # assigned before the base initialiser runs, presumably because it
        # may call reset(), which reads initial_requires -- TODO confirm
        self.initial_requires = requires
        self.requires = deepcopy(self.initial_requires)
        super().__init__(message, severity)

    def reset(self) -> None:
        r"""Reset search state and restore all requirements.

        :rtype: None
        """
        self.level = 0
        self.unis = []
        self.requires = deepcopy(self.initial_requires)

    def filter(self, uni: UNI, require: Any) -> bool:
        r"""Tell whether ``uni`` satisfies ``require``.

        :param uni: node currently being visited
        :type uni: UNI
        :param require: an outstanding requirement
        :type require: Any
        :rtype: bool
        """
        return False

    def __call__(self, uni: UNI) -> bool:
        r"""Drop every requirement satisfied by ``uni``.

        Individual nodes are never reported, so this always returns
        ``False``.

        :param uni: node currently being visited
        :type uni: UNI
        :rtype: bool
        """
        satisfied = {
            require for require in self.requires if self.filter(uni, require)
        }
        self.requires -= satisfied
        return False

    def require2message(self, require: Any, **kwargs: Any) -> str:
        r"""Render the message template for a missing requirement.

        :param require: the missing requirement
        :type require: Any
        :param kwargs: extra template variables
        :type kwargs: Any
        :rtype: str
        """
        template = Template(self.message)
        return template.render(uni=self, require=require, **kwargs)

    def get_diagnostics(self, uri: str, tree: Tree) -> list[Diagnostic]:
        r"""Search the document and report each unmet requirement.

        :param uri: document URI
        :type uri: str
        :param tree: parse tree of the document
        :type tree: Tree
        :rtype: list[Diagnostic]
        """
        self.find_all(uri, tree)
        diagnostics = []
        for require in self.requires:
            diagnostics.append(
                Diagnostic(
                    # If you want to specify a range that contains a line
                    # including the line ending character(s) then use an end
                    # position denoting the start of the next line
                    range=Range(
                        start=Position(line=0, character=0),
                        end=Position(line=1, character=0),
                    ),
                    message=self.require2message(require),
                    severity=self.severity,
                )
            )
        return diagnostics


@dataclass(init=False)
class SchemaFinder(Finder):
    r"""Validate a document against a JSON schema.

    The parse tree is converted with ``cls.from_tree``, and its JSON form is
    validated. Each validation error becomes a diagnostic located at the trie
    node addressed by the error's JSON path.
    """

    def __init__(
        self,
        schema: dict[str, Any],
        cls: type[Trie],
        message: str = "",
        severity: DiagnosticSeverity = DiagnosticSeverity.Error,
    ) -> None:
        r"""Init.

        :param schema: JSON schema the document must satisfy
        :type schema: dict[str, Any]
        :param cls: :class:`Trie` subclass used to convert the parse tree
        :type cls: type[Trie]
        :param message: kept for parity with the other finders; diagnostics
            use the validator's own error messages
        :type message: str
        :param severity: severity of every reported diagnostic
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        self.validator = self.schema2validator(schema)
        self.cls = cls
        # initialise the base fields as every other finder does, so that
        # ``message`` and ``severity`` exist on the instance
        super().__init__(message, severity)

    @staticmethod
    def schema2validator(schema: dict[str, Any]) -> Validator:
        r"""Build a validator for the draft declared by ``schema``.

        :param schema: JSON schema
        :type schema: dict[str, Any]
        :rtype: Validator
        """
        return validator_for(schema)(schema)

    def get_diagnostics(self, uri: str, tree: Tree) -> list[Diagnostic]:
        r"""Get one diagnostic per schema violation.

        :param uri: document URI (unused; only the tree is validated)
        :type uri: str
        :param tree: parse tree of the document
        :type tree: Tree
        :rtype: list[Diagnostic]
        """
        trie = self.cls.from_tree(tree)
        return [
            Diagnostic(
                trie.from_path(error.json_path).range,
                error.message,
                self.severity,
            )
            for error in self.validator.iter_errors(trie.to_json())
        ]


@dataclass(init=False)
class QueryFinder(Finder):
    r"""Find nodes captured by a tree-sitter query.

    :meth:`find_all` is overridden to run the query once over the root node.
    Each capture name yields at most one :class:`UNI` through
    :meth:`capture2uni`.
    """

    def __init__(
        self,
        query: Query,
        message: str = "",
        severity: DiagnosticSeverity = DiagnosticSeverity.Error,
    ) -> None:
        r"""Init.

        :param query: compiled tree-sitter query
        :type query: Query
        :param message: Jinja template for the diagnostic message
        :type message: str
        :param severity: severity of reported diagnostics
        :type severity: DiagnosticSeverity
        :rtype: None
        """
        self.cursor = QueryCursor(query)
        super().__init__(message, severity)

    def find_all(
        self, uri: str, tree: Tree | None = None, reset: bool = True
    ) -> list[UNI]:
        r"""Run the query and convert its captures to node wrappers.

        ``prepare`` (inherited) is expected to return the tree to search,
        presumably parsing ``uri`` when ``tree`` is ``None`` -- confirm
        against :class:`Finder`.

        :param uri: document URI
        :type uri: str
        :param tree: parse tree, if already available
        :type tree: Tree | None
        :param reset: whether to reset search state first
        :type reset: bool
        :rtype: list[UNI]
        """
        root = self.prepare(uri, tree, reset).root_node
        return self.captures2unis(self.cursor.captures(root), uri)

    def captures2unis(
        self, captures: dict[str, list[Node]], uri: str
    ) -> list[UNI]:
        r"""Convert every capture group, dropping those mapped to nothing.

        :param captures: capture name -> captured nodes
        :type captures: dict[str, list[Node]]
        :param uri: document URI
        :type uri: str
        :rtype: list[UNI]
        """
        candidates = (
            self.capture2uni(label, nodes, uri)
            for label, nodes in captures.items()
        )
        return [uni for uni in candidates if uni]

    def capture2uni(
        self, label: str, nodes: list[Node], uri: str
    ) -> UNI | None:
        r"""Convert one capture group; return ``None`` to skip it.

        The default wraps the first captured node.

        :param label: capture name
        :type label: str
        :param nodes: nodes captured under ``label``
        :type nodes: list[Node]
        :param uri: document URI
        :type uri: str
        :rtype: UNI | None
        """
        first_node = nodes[0]
        return UNI(first_node, uri)
