use std::{
    any::Any,
    future::Future,
    io,
    net::IpAddr,
    sync::{Arc, Mutex as StdMutex},
};
use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap};
use futures::{FutureExt, SinkExt};
use log::{debug, error};
use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use tokio::{
    select,
    sync::{broadcast, Mutex},
    task::JoinHandle,
};
use tokio_util::codec::FramedWrite;
use crate::{
    serialization::{CallbackFramed, MsgEncoder},
    tls::{server_config, CertResolver},
    BlockAddr, CallMsg, DeserCallback, Result, Transmitter,
};
/// Unwraps a `Result`, diverting control flow on error.
///
/// Evaluates `$result`. On `Ok` the macro evaluates to the inner value. On `Err` the error is
/// passed to `$on_err` and then `$control_flow` (e.g. `return`, `continue`, `break`) is
/// executed in the caller's context.
macro_rules! handle_err {
    ($result:expr, $on_err:expr, $control_flow:expr) => {
        match $result {
            Ok(inner) => inner,
            Err(err) => {
                $on_err(err);
                // `$control_flow` is expected to diverge; otherwise this arm evaluates to `()`
                // and will generally not type-check against the `Ok` arm.
                $control_flow;
            }
        }
    };
}
/// Unwraps a `Result`, or handles the error and performs a bare `return` from the enclosing
/// function (so that function must return `()`).
///
/// The two-argument form passes the error to `$on_err`. The one-argument form logs it with
/// `error!`.
macro_rules! unwrap_or_return {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, return)
    };
    ($result:expr) => {
        unwrap_or_return!($result, |err| error!("{err}"))
    };
}
/// Unwraps a `Result`, or handles the error and `continue`s the enclosing loop (so it must be
/// used inside a loop).
///
/// The two-argument form passes the error to `$on_err`. The one-argument form logs it with
/// `error!`.
macro_rules! unwrap_or_continue {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, continue)
    };
    ($result:expr) => {
        unwrap_or_continue!($result, |err| error!("{err}"))
    };
}
/// Races `$future` against `$stop_fut` with `tokio::select!`; must be used inside a loop.
///
/// If `$future` resolves to `Some(value)`, the macro evaluates to `value`. If `$stop_fut`
/// completes first (whatever its output), the enclosing loop is exited with `break`. If
/// `$future` resolves to `None`, its pattern does not match, so that branch is disabled and
/// only `$stop_fut` keeps being awaited.
macro_rules! await_or_stop {
    ($future:expr, $stop_fut:expr) => {
        // `connecting` is only a binding name; the macro is also used for futures that do
        // not yield connections.
        select! {
            Some(connecting) = $future => connecting,
            _ = $stop_fut => break,
        }
    };
}
/// A QUIC server that accepts connections for a block and dispatches each received message
/// to a `MsgCallback`. Calling `stop`, or dropping the receiver, broadcasts a stop signal to
/// the background tasks.
pub struct Receiver {
    /// Address this receiver's endpoint was bound to.
    recv_addr: Arc<BlockAddr>,
    /// Sends the stop signal to the server loop and the per-connection tasks.
    stop_tx: broadcast::Sender<()>,
    /// The QUIC endpoint accepting connections; also reused when creating transmitters.
    endpoint: Endpoint,
    /// Certificate resolver used by the endpoint's server config and by transmitters.
    resolver: Arc<CertResolver>,
    /// Handle of the spawned server loop; handed out at most once by `complete`.
    join_handle: StdMutex<Option<JoinHandle<()>>>,
}
impl Receiver {
    pub fn new<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
        ip_addr: IpAddr,
        creds: Arc<C>,
        callback: F,
    ) -> Result<Receiver> {
        let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
        let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
        log::info!("starting Receiver with address {}", recv_addr);
        let socket_addr = recv_addr.socket_addr()?;
        let resolver = Arc::new(CertResolver::new(creds)?);
        let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
        let (stop_tx, stop_rx) = broadcast::channel(1);
        let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
        Ok(Self {
            recv_addr,
            stop_tx,
            endpoint,
            resolver,
            join_handle: StdMutex::new(Some(join_handle)),
        })
    }
    async fn server_loop<F: 'static + MsgCallback>(
        endpoint: Endpoint,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        loop {
            let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
            let connection = unwrap_or_continue!(connecting.await, |err| error!(
                "error accepting QUIC connection: {err}"
            ));
            tokio::spawn(Self::handle_connection(
                connection,
                callback.clone(),
                stop_rx.resubscribe(),
            ));
        }
    }
    async fn handle_connection<F: 'static + MsgCallback>(
        connection: Connection,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        let client_path = unwrap_or_return!(
            Self::client_path(connection.peer_identity()),
            |err| error!("failed to get client path from peer identity: {err}")
        );
        loop {
            let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
            let (send_stream, recv_stream) = match result {
                Ok(pair) => pair,
                Err(err) => match err {
                    ConnectionError::ApplicationClosed(app) => {
                        debug!("connection closed: {app}");
                        return;
                    }
                    _ => {
                        error!("error accepting stream: {err}");
                        continue;
                    }
                },
            };
            let client_path = client_path.clone();
            let callback = callback.clone();
            tokio::task::spawn(Self::handle_message(
                client_path,
                send_stream,
                recv_stream,
                callback,
            ));
        }
    }
    async fn handle_message<F: 'static + MsgCallback>(
        client_path: Arc<BlockPath>,
        send_stream: SendStream,
        recv_stream: RecvStream,
        callback: F,
    ) {
        let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
        let callback =
            MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
        let mut msg_stream = CallbackFramed::new(recv_stream);
        let result = msg_stream
            .next(callback)
            .await
            .ok_or_else(|| bterr!("client closed stream before sending a message"));
        match unwrap_or_return!(result) {
            Err(err) => error!("msg_stream produced an error: {err}"),
            Ok(result) => {
                if let Err(err) = result {
                    error!("callback returned an error: {err}");
                }
            }
        }
    }
    fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
        let peer_identity =
            peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
        let client_certs = peer_identity
            .downcast::<Vec<rustls::Certificate>>()
            .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
        let first = client_certs
            .first()
            .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
        let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
        Ok(Arc::new(writecap.bind_path()))
    }
    pub fn addr(&self) -> &Arc<BlockAddr> {
        &self.recv_addr
    }
    pub async fn transmitter(&self, addr: Arc<BlockAddr>) -> Result<Transmitter> {
        Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
    }
    pub fn complete(&self) -> Result<JoinHandle<()>> {
        let mut guard = self.join_handle.lock().display_err()?;
        let handle = guard
            .take()
            .ok_or_else(|| bterr!("join handle has already been taken"))?;
        Ok(handle)
    }
    pub fn stop(&self) -> Result<()> {
        self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())
    }
}
impl Drop for Receiver {
    /// Signals the server loop and connection tasks to shut down.
    fn drop(&mut self) {
        // `send` only fails when nothing is still subscribed, which is expected during
        // teardown, so the outcome is deliberately discarded.
        let _ignored = self.stop_tx.send(());
    }
}
/// A handler for messages delivered by a `Receiver`.
///
/// The receiver clones the callback for every connection and every stream it serves, hence
/// the `Clone` bound.
pub trait MsgCallback: Clone + Send + Sync + Unpin {
    /// The message type this callback accepts. It is presumably deserialized from the stream
    /// and may borrow from it for `'de` — confirm against `CallMsg`.
    type Arg<'de>: CallMsg<'de>
    where
        Self: 'de;
    /// The future returned by `call`; it must be `Send` because it runs in spawned tasks.
    type CallFut<'de>: Future<Output = Result<()>> + Send
    where
        Self: 'de;
    /// Handles one received message. Returning `Err` causes an error reply to be written
    /// back on the message's stream.
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
}
/// Lets a shared reference to a callback be used wherever a callback is expected, by
/// forwarding every call to the referenced value.
impl<T: MsgCallback> MsgCallback for &T {
    type Arg<'de> = T::Arg<'de> where Self: 'de;
    type CallFut<'de> = T::CallFut<'de> where Self: 'de;
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        // Copy the inner `&T` out and dispatch explicitly to `T`'s implementation.
        <T as MsgCallback>::call(*self, arg)
    }
}
/// Adapter that wraps a user `MsgCallback` for one stream, remembering who sent the message
/// and how to reply on that stream.
struct MsgRecvdCallback<F> {
    /// Block path of the client that opened the stream.
    path: Arc<BlockPath>,
    /// Writes replies (including error replies) back on the stream.
    replier: Replier,
    /// The user-supplied callback.
    inner: F,
}
impl<F: MsgCallback> MsgRecvdCallback<F> {
    fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
        Self {
            path,
            replier: Replier::new(framed_msg),
            inner,
        }
    }
}
/// Adapts a user `MsgCallback` to the `DeserCallback` interface consumed by `CallbackFramed`.
/// Each deserialized `Envelope` is wrapped in a `MsgReceived` and passed to the inner
/// callback.
impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
    type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
    type Return = Result<()>;
    type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;
    fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
        // Only calls receive a replier; one-way sends are delivered without one.
        let replier = match arg.kind {
            MsgKind::Call => Some(self.replier.clone()),
            MsgKind::Send => None,
        };
        async move {
            let result = self
                .inner
                .call(MsgReceived::new(self.path.clone(), arg, replier))
                .await;
            // On failure, the error is reported to the peer over this stream. For an
            // `io::Error` the raw OS error code (if any) is forwarded too. The value returned
            // is the outcome of writing that error reply, so the callback's own error does
            // not propagate further once the reply is sent.
            // NOTE(review): an error reply is written even for `MsgKind::Send` messages, and
            // even if the callback already replied through its `Replier` — confirm the peer
            // tolerates this.
            match result {
                Ok(value) => Ok(value),
                Err(err) => match err.downcast::<io::Error>() {
                    Ok(err) => {
                        self.replier
                            .reply_err(err.to_string(), err.raw_os_error())
                            .await
                    }
                    Err(err) => self.replier.reply_err(err.to_string(), None).await,
                },
            }
        }
    }
}
/// Whether a message is a call that gets a `Replier` (`Call`) or a one-way send (`Send`).
///
/// Serialized as part of `Envelope`. For serde formats that encode variants by index, the
/// variant order is part of the wire format — do not reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
enum MsgKind {
    Call,
    Send,
}
/// A message tagged with its `MsgKind`; this is what the receiver deserializes from a stream.
///
/// Field order is part of the serialized form for order-dependent serde formats — do not
/// reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) struct Envelope<T> {
    kind: MsgKind,
    msg: T,
}
impl<T> Envelope<T> {
    pub(crate) fn send(msg: T) -> Self {
        Self {
            msg,
            kind: MsgKind::Send,
        }
    }
    pub(crate) fn call(msg: T) -> Self {
        Self {
            msg,
            kind: MsgKind::Call,
        }
    }
    fn msg(&self) -> &T {
        &self.msg
    }
}
/// A reply written back on a stream: either the successful value, or an error message with
/// an optional OS error code.
///
/// Variant and field order are part of the serialized form for order-dependent serde
/// formats — do not reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) enum ReplyEnvelope<T> {
    Ok(T),
    Err {
        message: String,
        os_code: Option<i32>,
    },
}
impl<T> ReplyEnvelope<T> {
    fn err(message: String, os_code: Option<i32>) -> Self {
        Self::Err { message, os_code }
    }
}
/// A received message as handed to a `MsgCallback`.
pub struct MsgReceived<T> {
    /// Block path of the sender.
    from: Arc<BlockPath>,
    /// The message and its kind.
    msg: Envelope<T>,
    /// Present only for `MsgKind::Call` messages, until taken.
    replier: Option<Replier>,
}
impl<T> MsgReceived<T> {
    fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
        Self { from, msg, replier }
    }
    pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
        (self.from, self.msg.msg, self.replier)
    }
    pub fn from(&self) -> &Arc<BlockPath> {
        &self.from
    }
    pub fn body(&self) -> &T {
        self.msg.msg()
    }
    pub fn needs_reply(&self) -> bool {
        self.replier.is_some()
    }
    pub fn take_replier(&mut self) -> Option<Replier> {
        self.replier.take()
    }
}
/// The write half of a QUIC stream, framing outgoing values with `MsgEncoder`.
type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
/// A value shared behind an `Arc` and locked with the async (tokio) `Mutex`.
type ArcMutex<T> = Arc<Mutex<T>>;
/// Handle for writing replies back on the stream a message arrived on.
///
/// Clones share the same underlying stream, which is guarded by an async mutex.
#[derive(Clone)]
pub struct Replier {
    stream: ArcMutex<FramedMsg>,
}
impl Replier {
    fn new(stream: ArcMutex<FramedMsg>) -> Self {
        Self { stream }
    }
    pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
        let mut guard = self.stream.lock().await;
        guard.send(ReplyEnvelope::Ok(reply)).await?;
        Ok(())
    }
    pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
        let mut guard = self.stream.lock().await;
        guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
        Ok(())
    }
}