use join_to_string::join;

use ra_syntax::{
    algo::{find_covering_node, find_leaf_at_offset},
    ast::{self, AstNode, AttrsOwner, NameOwner, TypeParamsOwner},
    Direction, SourceFileNode,
    SyntaxKind::{COMMA, WHITESPACE, COMMENT},
    SyntaxNodeRef, TextRange, TextUnit,
};

use crate::{find_node_at_offset, TextEdit, TextEditBuilder};

/// Result produced by the code actions in this file: the text edit to apply
/// plus an optional cursor offset.
#[derive(Debug)]
pub struct LocalEdit {
    /// The edit that the action wants applied to the file text.
    pub edit: TextEdit,
    /// Offset the action suggests for the cursor, or `None` when the action
    /// gives no suggestion (e.g. `flip_comma`).
    // NOTE(review): the actions here compute this offset including the text
    // they insert, so it looks like a post-edit offset — confirm with callers.
    pub cursor_position: Option<TextUnit>,
}

/// Offers to swap the two non-trivia siblings surrounding the comma at `offset`.
///
/// Returns `None` when there is no `COMMA` leaf at `offset`, or when the comma
/// lacks a non-trivia sibling on either side. Otherwise returns a closure that
/// builds the edit; the resulting `LocalEdit` carries no cursor position.
pub fn flip_comma<'a>(
    file: &'a SourceFileNode,
    offset: TextUnit,
) -> Option<impl FnOnce() -> LocalEdit + 'a> {
    let comma = find_leaf_at_offset(file.syntax(), offset).find(|leaf| leaf.kind() == COMMA)?;
    let left = non_trivia_sibling(comma, Direction::Prev)?;
    let right = non_trivia_sibling(comma, Direction::Next)?;

    Some(move || {
        let mut builder = TextEditBuilder::new();
        // Each side receives the text of the other one.
        builder.replace(left.range(), right.text().to_string());
        builder.replace(right.range(), left.text().to_string());
        let edit = builder.finish();
        LocalEdit {
            edit,
            cursor_position: None,
        }
    })
}

/// Offers to add a `#[derive(...)]` attribute to the nominal definition at
/// `offset`, or to jump into the existing one.
///
/// Returns `None` when no `ast::NominalDef` is found at `offset` or when the
/// definition has no child other than comments and whitespace.
///
/// The returned closure behaves as follows:
/// * no attribute called `derive` exists: inserts `#[derive()]\n` in front of
///   the first child that is neither a comment nor whitespace, and puts the
///   cursor between the parentheses;
/// * a `derive` attribute exists: produces an empty edit and puts the cursor
///   one `)`-width before the end of the attribute's argument.
pub fn add_derive<'a>(
    file: &'a SourceFileNode,
    offset: TextUnit,
) -> Option<impl FnOnce() -> LocalEdit + 'a> {
    let nominal = find_node_at_offset::<ast::NominalDef>(file.syntax(), offset)?;

    // Insert `derive` after doc comments: skip leading comment/whitespace
    // children and anchor on the first child of any other kind.
    let insert_at = nominal
        .syntax()
        .children()
        .find(|child| !(child.kind() == COMMENT || child.kind() == WHITESPACE))?
        .range()
        .start();

    Some(move || {
        // Argument of the first attribute whose call name is `derive`, if any.
        let existing_args = nominal
            .attrs()
            .filter_map(|attr| attr.as_call())
            .find(|(name, _arg)| name == "derive")
            .map(|(_name, arg)| arg);

        let mut builder = TextEditBuilder::new();
        let cursor = if let Some(args) = existing_args {
            // Step back over the closing `)` so the cursor lands inside the list.
            args.syntax().range().end() - TextUnit::of_char(')')
        } else {
            builder.insert(insert_at, "#[derive()]\n".to_string());
            insert_at + TextUnit::of_str("#[derive(")
        };

        LocalEdit {
            edit: builder.finish(),
            cursor_position: Some(cursor),
        }
    })
}

/// Offers to generate an empty `impl` block for the nominal definition at
/// `offset`.
///
/// Returns `None` when no `ast::NominalDef` is found at `offset` or when the
/// definition has no name.
///
/// The returned closure inserts, right after the definition, the text
/// `\n\nimpl<PARAMS> NAME<ARGS> {\n\n}` where `<PARAMS>` is the definition's
/// type parameter list copied verbatim and `<ARGS>` lists just the lifetime
/// and type parameter names. The cursor is placed on the empty line inside
/// the braces.
pub fn add_impl<'a>(
    file: &'a SourceFileNode,
    offset: TextUnit,
) -> Option<impl FnOnce() -> LocalEdit + 'a> {
    let nominal = find_node_at_offset::<ast::NominalDef>(file.syntax(), offset)?;
    let name = nominal.name()?;

    Some(move || {
        let insert_at = nominal.syntax().range().end();
        let generics = nominal.type_param_list();

        let mut text = String::from("\n\nimpl");
        // `impl<...>`: the declaration's parameter list, bounds included.
        if let Some(generics) = generics {
            generics.syntax().text().push_to(&mut text);
        }
        text.push_str(" ");
        text.push_str(name.text().as_str());
        // `Name<...>`: only the parameter names, lifetimes first.
        if let Some(generics) = generics {
            let lifetime_names = generics
                .lifetime_params()
                .filter_map(|param| param.lifetime())
                .map(|lifetime| lifetime.text());
            let type_names = generics
                .type_params()
                .filter_map(|param| param.name())
                .map(|param_name| param_name.text());
            join(lifetime_names.chain(type_names))
                .surround_with("<", ">")
                .to_buf(&mut text);
        }
        text.push_str(" {\n");
        // Measure before appending the closing brace so the cursor ends up
        // inside the body.
        let cursor = insert_at + TextUnit::of_str(&text);
        text.push_str("\n}");

        let mut builder = TextEditBuilder::new();
        builder.insert(insert_at, text);
        LocalEdit {
            edit: builder.finish(),
            cursor_position: Some(cursor),
        }
    })
}

/// Offers to extract the expression covering `range` into
/// `let var_name = <expr>;`.
///
/// Returns `None` when no expression is found among the covering node and its
/// ancestors, when no anchor statement exists for it, or when the anchor is
/// not immediately preceded by a whitespace node (that whitespace is reused
/// as the separator between the new `let` and the anchor).
///
/// The cursor is placed right after `let ` in the introduced statement.
pub fn introduce_variable<'a>(
    file: &'a SourceFileNode,
    range: TextRange,
) -> Option<impl FnOnce() -> LocalEdit + 'a> {
    /// Statement or last in the block expression, which will follow
    /// the freshly introduced var.
    fn anchor_stmt(expr: ast::Expr) -> Option<SyntaxNodeRef> {
        expr.syntax().ancestors().find(|&node| {
            ast::Stmt::cast(node).is_some()
                || node
                    .parent()
                    .and_then(ast::Block::cast)
                    .and_then(|block| block.expr())
                    .map_or(false, |tail| tail.syntax() == node)
        })
    }

    let covering = find_covering_node(file.syntax(), range);
    let expr = covering.ancestors().filter_map(ast::Expr::cast).next()?;

    let anchor = anchor_stmt(expr)?;
    let indent = anchor
        .prev_sibling()
        .filter(|sibling| sibling.kind() == WHITESPACE)?;

    Some(move || {
        let mut decl = String::from("let var_name = ");
        expr.syntax().text().push_to(&mut decl);

        // True when the anchor is an expression statement consisting of
        // exactly `expr`.
        let is_full_stmt = match ast::ExprStmt::cast(anchor) {
            Some(stmt) => Some(expr.syntax()) == stmt.expr().map(|e| e.syntax()),
            None => false,
        };

        let cursor_position = anchor.range().start() + TextUnit::of_str("let ");

        let mut builder = TextEditBuilder::new();
        if is_full_stmt {
            // Only the expression is replaced; the statement text around it
            // is left untouched.
            builder.replace(expr.syntax().range(), decl);
        } else {
            decl.push_str(";");
            // Repeat the whitespace that precedes the anchor after the new
            // `let` statement.
            indent.text().push_to(&mut decl);
            builder.replace(expr.syntax().range(), "var_name".to_string());
            builder.insert(anchor.range().start(), decl);
        }

        LocalEdit {
            edit: builder.finish(),
            cursor_position: Some(cursor_position),
        }
    })
}

/// Returns the nearest sibling of `node` in `direction` whose kind is not
/// trivia, or `None` if there is no such sibling.
fn non_trivia_sibling(node: SyntaxNodeRef, direction: Direction) -> Option<SyntaxNodeRef> {
    let mut siblings = node.siblings(direction);
    // Discard the first item the iterator yields (presumably `node` itself —
    // confirm against `siblings`).
    siblings.next();
    siblings.find(|candidate| !candidate.kind().is_trivia())
}

// Tests for the code actions above. Each case passes a "before" text, the
// expected "after" text and the action under test to `check_action` /
// `check_action_range`.
// NOTE(review): `<|>` presumably marks the cursor offset (or, in pairs, the
// selected range) in the input and the expected cursor in the output —
// confirm against `crate::test_utils`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{check_action, check_action_range};

    // `flip_comma` swaps the two parameters around the comma.
    #[test]
    fn test_swap_comma() {
        check_action(
            "fn foo(x: i32,<|> y: Result<(), ()>) {}",
            "fn foo(y: Result<(), ()>,<|> x: i32) {}",
            |file, off| flip_comma(file, off).map(|f| f()),
        )
    }

    // Without an existing attribute, `add_derive` inserts `#[derive()]` on its
    // own line in front of the struct.
    #[test]
    fn add_derive_new() {
        check_action(
            "struct Foo { a: i32, <|>}",
            "#[derive(<|>)]\nstruct Foo { a: i32, }",
            |file, off| add_derive(file, off).map(|f| f()),
        );
        check_action(
            "struct Foo { <|> a: i32, }",
            "#[derive(<|>)]\nstruct Foo {  a: i32, }",
            |file, off| add_derive(file, off).map(|f| f()),
        );
    }

    // With an existing `derive`, the text stays the same and only the cursor
    // is expected to move inside the parentheses.
    #[test]
    fn add_derive_existing() {
        check_action(
            "#[derive(Clone)]\nstruct Foo { a: i32<|>, }",
            "#[derive(Clone<|>)]\nstruct Foo { a: i32, }",
            |file, off| add_derive(file, off).map(|f| f()),
        );
    }

    // The new attribute goes after the doc comments, not before them.
    #[test]
    fn add_derive_new_with_doc_comment() {
        check_action(
            "
/// `Foo` is a pretty important struct.
/// It does stuff.
struct Foo { a: i32<|>, }
            ",
            "
/// `Foo` is a pretty important struct.
/// It does stuff.
#[derive(<|>)]
struct Foo { a: i32, }
            ",
            |file, off| add_derive(file, off).map(|f| f()),
        );
    }

    // `add_impl` for a plain struct, a struct with a bounded type parameter,
    // and a struct with both lifetime and type parameters.
    #[test]
    fn test_add_impl() {
        check_action(
            "struct Foo {<|>}\n",
            "struct Foo {}\n\nimpl Foo {\n<|>\n}\n",
            |file, off| add_impl(file, off).map(|f| f()),
        );
        check_action(
            "struct Foo<T: Clone> {<|>}",
            "struct Foo<T: Clone> {}\n\nimpl<T: Clone> Foo<T> {\n<|>\n}",
            |file, off| add_impl(file, off).map(|f| f()),
        );
        check_action(
            "struct Foo<'a, T: Foo<'a>> {<|>}",
            "struct Foo<'a, T: Foo<'a>> {}\n\nimpl<'a, T: Foo<'a>> Foo<'a, T> {\n<|>\n}",
            |file, off| add_impl(file, off).map(|f| f()),
        );
    }

    // Extracting a call argument: the `let` is inserted before the statement.
    #[test]
    fn test_introduce_var_simple() {
        check_action_range(
            "
fn foo() {
    foo(<|>1 + 1<|>);
}",
            "
fn foo() {
    let <|>var_name = 1 + 1;
    foo(var_name);
}",
            |file, range| introduce_variable(file, range).map(|f| f()),
        );
    }

    // Extracting a whole expression statement: it is turned into the `let`.
    #[test]
    fn test_introduce_var_expr_stmt() {
        check_action_range(
            "
fn foo() {
    <|>1 + 1<|>;
}",
            "
fn foo() {
    let <|>var_name = 1 + 1;
}",
            |file, range| introduce_variable(file, range).map(|f| f()),
        );
    }

    // Extracting only part of an expression statement keeps the statement and
    // substitutes `var_name` into it.
    #[test]
    fn test_introduce_var_part_of_expr_stmt() {
        check_action_range(
            "
fn foo() {
    <|>1<|> + 1;
}",
            "
fn foo() {
    let <|>var_name = 1;
    var_name + 1;
}",
            |file, range| introduce_variable(file, range).map(|f| f()),
        );
    }

    // Extracting from inside the block's trailing expression.
    #[test]
    fn test_introduce_var_last_expr() {
        check_action_range(
            "
fn foo() {
    bar(<|>1 + 1<|>)
}",
            "
fn foo() {
    let <|>var_name = 1 + 1;
    bar(var_name)
}",
            |file, range| introduce_variable(file, range).map(|f| f()),
        );
    }

    // Extracting the entire trailing expression leaves `var_name` as the
    // block's value.
    #[test]
    fn test_introduce_var_last_full_expr() {
        check_action_range(
            "
fn foo() {
    <|>bar(1 + 1)<|>
}",
            "
fn foo() {
    let <|>var_name = bar(1 + 1);
    var_name
}",
            |file, range| introduce_variable(file, range).map(|f| f()),
        );
    }

}