use tokenizers::{
    NormalizedString, Normalizer, PreTokenizedString, PreTokenizer, normalizers::NFC,
    pre_tokenizers::bert::BertPreTokenizer,
};

use crate::{ChunkError, types::Sizer};

/// Sizer that measures chunks in bytes.
///
/// Chunk ends are snapped onto UTF-8 character boundaries so that a multi-byte character is
/// never split across two chunks.
#[derive(Clone, Debug)]
pub struct ByteSizer {
    /// Maximum chunk size, in bytes.
    max_chunk_size: usize,
}

impl ByteSizer {
    /// Creates a sizer whose chunk size limit is `max_chunk_size` bytes.
    pub fn new(max_chunk_size: usize) -> Self {
        Self { max_chunk_size }
    }

    /// Find the nearest character boundary at or before the given byte index.
    ///
    /// An index past the end of `s` resolves to `s.len()`.
    fn find_char_boundary_before(&self, s: &str, byte_index: usize) -> usize {
        // Clamp up front: every index beyond the end is "not a boundary", so without the clamp
        // an oversized index would be walked down one byte at a time.
        let mut boundary = byte_index.min(s.len());

        // Index 0 is always a character boundary, so this cannot underflow.
        while !s.is_char_boundary(boundary) {
            boundary -= 1;
        }
        boundary
    }
}

impl Sizer for ByteSizer {
    /// Finds where the chunk starting at `start_byte` should end.
    ///
    /// Returns `(end_byte, more)`, where `end_byte` is an exclusive byte offset into `content`
    /// and `more` is `true` when content remains after `end_byte`.
    ///
    /// When the rest of the content fits within `max_chunk_size` bytes (or `start_byte` is at or
    /// past the end), everything that remains is consumed. Otherwise the end is placed on the
    /// last character boundary within the limit. If no boundary lies within the limit (the limit
    /// is zero, or smaller than the character at `start_byte`), the chunk is extended to the
    /// next character boundary so that the returned end is always past `start_byte`.
    fn find_end_byte(&self, content: &str, start_byte: usize) -> Result<(usize, bool), ChunkError> {
        // Saturate so that a start offset beyond the end means "nothing left" instead of
        // underflowing.
        let remaining_length = content.len().saturating_sub(start_byte);

        if remaining_length <= self.max_chunk_size {
            return Ok((content.len(), false));
        }

        // Cannot overflow: here `start_byte + max_chunk_size < content.len()`.
        let floor = self.find_char_boundary_before(content, start_byte + self.max_chunk_size);

        let end_byte = if floor > start_byte {
            floor
        } else {
            // The limit is too small to hold even one character. Returning `start_byte` would
            // yield an empty chunk with `more == true`, so step forward to the next boundary.
            // `start_byte < content.len()` here and `content.len()` is itself a boundary, so
            // this loop terminates.
            let mut next = start_byte + 1;
            while !content.is_char_boundary(next) {
                next += 1;
            }
            next
        };

        Ok((end_byte, end_byte < content.len()))
    }
}

/// Sizer that measures chunks in pre-tokenizer splits instead of bytes.
///
/// The text is passed through the normalizer and then the pre-tokenizer; a chunk ends after at
/// most `token_limit` of the resulting splits.
#[derive(Debug, Clone)]
pub struct PretokenizerSizer<N, PT> {
    /// Normalizer applied to the text before it is pre-tokenized.
    normalizer: N,
    /// Pre-tokenizer that cuts the normalized text into splits.
    pre_tokenizer: PT,
    /// Maximum number of splits a single chunk may contain.
    token_limit: usize,
}

impl<N: Normalizer, PT: PreTokenizer> PretokenizerSizer<N, PT> {
    pub fn new(normalizer: N, pre_tokenizer: PT, token_limit: usize) -> Self {
        Self {
            normalizer,
            pre_tokenizer,
            token_limit,
        }
    }
}

impl PretokenizerSizer<NFC, BertPreTokenizer> {
    /// Builds a sizer that pairs NFC normalization with the BERT pre-tokenizer.
    ///
    /// `token_limit` is the maximum number of pre-tokenizer splits per chunk. This constructor
    /// currently always returns `Ok`.
    pub fn with_bert_pre(token_limit: usize) -> Result<Self, ChunkError> {
        let sizer = Self {
            normalizer: NFC {},
            pre_tokenizer: BertPreTokenizer {},
            token_limit,
        };

        Ok(sizer)
    }
}

impl<N: Normalizer, PT: PreTokenizer> Sizer for PretokenizerSizer<N, PT> {
    /// Finds where the chunk starting at `start_byte` should end, counting pre-tokenizer splits.
    ///
    /// Returns `(end_byte, more)`, where `end_byte` is an exclusive byte offset into `content`
    /// and `more` is `true` when content remains after `end_byte`.
    ///
    /// The text from `start_byte` onwards is normalized and pre-tokenized, and the chunk ends at
    /// the end of the `token_limit`-th split. If the text holds no more than `token_limit`
    /// splits, or no splits at all, the rest of the content is consumed. A `token_limit` of zero
    /// is treated as one so that every call makes progress.
    ///
    /// # Errors
    ///
    /// Returns `ChunkError::TokenizationError` if `start_byte` is not on a UTF-8 character
    /// boundary, or if the normalizer or pre-tokenizer fails.
    fn find_end_byte(&self, content: &str, start_byte: usize) -> Result<(usize, bool), ChunkError> {
        if start_byte >= content.len() {
            // Nothing left to size; also avoids slicing past the end of `content`.
            return Ok((content.len(), false));
        }

        // `start_byte` is in range here, so `get` only fails on a non-boundary offset.
        let remaining = content.get(start_byte..).ok_or_else(|| {
            ChunkError::TokenizationError(format!(
                "start_byte {start_byte} is not on a UTF-8 character boundary"
            ))
        })?;

        let mut normalized = NormalizedString::from(remaining);
        self.normalizer
            .normalize(&mut normalized)
            .map_err(|e| ChunkError::TokenizationError(e.to_string()))?;

        let mut pretokenized = PreTokenizedString::from(normalized);
        self.pre_tokenizer
            .pre_tokenize(&mut pretokenized)
            .map_err(|e| ChunkError::TokenizationError(e.to_string()))?;

        // Offsets are requested against the original (un-normalized) text, in bytes, so they can
        // be added directly to `start_byte`.
        let splits = pretokenized.get_splits(
            tokenizers::OffsetReferential::Original,
            tokenizers::OffsetType::Byte,
        );
        if splits.is_empty() {
            // Failed to find any tokens, probably due to content being empty/whitespace.
            return Ok((content.len(), false));
        }

        // Index of the last split that fits in this chunk. The lower bound of 1 keeps a zero
        // `token_limit` from underflowing; `splits` is non-empty, so the bounds are valid.
        let limit_idx = self.token_limit.clamp(1, splits.len()) - 1;

        if limit_idx + 1 >= splits.len() {
            // There aren't any tokens after this, so consume the rest of the content.
            return Ok((content.len(), false));
        }

        let (_, (_, token_end), _) = splits[limit_idx];
        let end_byte = start_byte + token_end;

        Ok((end_byte, end_byte < content.len()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a mixed ASCII/Cyrillic string in 20-byte steps and checks that every chunk end
    /// lands on a character boundary.
    #[test]
    fn test_byte_sizer() {
        let text = "Hello тест мир! This is a тест.";
        let sizer = ByteSizer::new(20);

        // Byte 20 falls inside a two-byte character, so the end backs up to 19.
        let first = sizer.find_end_byte(text, 0).unwrap();
        assert_eq!(first, (19, true));

        // 19 + 20 = 39 is already a character boundary.
        let second = sizer.find_end_byte(text, first.0).unwrap();
        assert_eq!(second, (39, true));

        // The remainder fits within the limit, so it is consumed whole.
        let third = sizer.find_end_byte(text, second.0).unwrap();
        assert_eq!(third, (text.len(), false));
    }

    /// With a limit of two splits, the first chunk ends after "banana" and the second chunk
    /// takes the rest, trailing whitespace included.
    #[test]
    fn test_pre_tokenization_sizer() {
        let text = "apple banana carrot   ";
        let sizer = PretokenizerSizer::with_bert_pre(2).unwrap();

        let first = sizer.find_end_byte(text, 0).unwrap();
        assert_eq!(first, (12, true));

        let second = sizer.find_end_byte(text, first.0).unwrap();
        assert_eq!(second, (text.len(), false));
    }
}
