use core::cmp::min;
use core::error::Error;
use core::ops::Add;
use bitcoin::block::Header;
use bitcoin::consensus::encode::serialize_hex;
use bitcoin::BlockHash;
use bitcoin::Work;
use floresta_common::prelude::Box;
use floresta_common::prelude::String;
use floresta_common::prelude::Vec;
use crate::BlockchainInterface;
const MEDIAN_TIME_PAST_BLOCK_COUNT: usize = 11;
pub trait HeaderExt {
fn calculate_median_time_past(
&self,
chain: &impl BlockchainInterface,
) -> Result<u32, HeaderExtError>;
fn calculate_chain_work(
&self,
chain: &impl BlockchainInterface,
) -> Result<Work, HeaderExtError>;
fn get_next_block_hash(
&self,
chain: &impl BlockchainInterface,
) -> Result<Option<BlockHash>, HeaderExtError>;
fn get_previous_block_header(
&self,
chain: &impl BlockchainInterface,
) -> Result<Header, HeaderExtError>;
fn get_bits_hex(&self) -> String;
fn get_confirmations(&self, chain: &impl BlockchainInterface) -> Result<u32, HeaderExtError>;
fn get_difficulty(&self) -> f64;
fn get_height(&self, chain: &impl BlockchainInterface) -> Result<u32, HeaderExtError>;
fn get_target_hex(&self) -> String;
fn get_version_hex(&self) -> String;
}
#[derive(Debug)]
pub enum HeaderExtError {
Chain(Box<dyn Error + Send + Sync>),
BlockNotFound,
ChainWorkOverflow,
}
impl HeaderExt for Header {
fn calculate_median_time_past(
&self,
chain: &impl BlockchainInterface,
) -> Result<u32, HeaderExtError> {
let mut block_timestamps = Vec::with_capacity(MEDIAN_TIME_PAST_BLOCK_COUNT);
let mut current_header = *self;
for _ in 0..MEDIAN_TIME_PAST_BLOCK_COUNT {
block_timestamps.push(current_header.time);
let Ok(prev_header) = current_header.get_previous_block_header(chain) else {
break;
};
current_header = prev_header;
}
block_timestamps.sort();
let median_time_past = block_timestamps[block_timestamps.len() / 2];
Ok(median_time_past)
}
fn calculate_chain_work(
&self,
chain: &impl BlockchainInterface,
) -> Result<Work, HeaderExtError> {
let block_height = self.get_height(chain)?;
let mut total_chainwork = Work::from_be_bytes([0u8; 32]);
for epoch_start_height in (0..=block_height).step_by(2016) {
let epoch_end_height = min(epoch_start_height + 2015, block_height);
let blocks_in_epoch = epoch_end_height - epoch_start_height + 1;
let epoch_block_hash = chain
.get_block_hash(epoch_start_height)
.map_err(|e| HeaderExtError::Chain(Box::new(e)))?;
let epoch_block_header = chain
.get_block_header(&epoch_block_hash)
.map_err(|e| HeaderExtError::Chain(Box::new(e)))?;
let epoch_chainwork = epoch_block_header
.work()
.multiply_work_by_u32(blocks_in_epoch)?;
total_chainwork = total_chainwork.add(epoch_chainwork);
}
Ok(total_chainwork)
}
fn get_next_block_hash(
&self,
chain: &impl BlockchainInterface,
) -> Result<Option<BlockHash>, HeaderExtError> {
let height = self.get_height(chain)?;
match chain.get_block_hash(height + 1) {
Ok(opt_hash) => Ok(Some(opt_hash)),
Err(_) => Ok(None),
}
}
fn get_previous_block_header(
&self,
chain: &impl BlockchainInterface,
) -> Result<Header, HeaderExtError> {
let prev_header = chain
.get_block_header(&self.prev_blockhash)
.map_err(|e| HeaderExtError::Chain(Box::new(e)))?;
Ok(prev_header)
}
fn get_bits_hex(&self) -> String {
serialize_hex(&self.bits.to_consensus().to_be())
}
fn get_confirmations(&self, chain: &impl BlockchainInterface) -> Result<u32, HeaderExtError> {
let height = self.get_height(chain)?;
let chain_height = chain
.get_height()
.map_err(|e| HeaderExtError::Chain(Box::new(e)))?;
Ok(chain_height - height + 1)
}
fn get_difficulty(&self) -> f64 {
self.difficulty_float()
}
fn get_height(&self, chain: &impl BlockchainInterface) -> Result<u32, HeaderExtError> {
let height = match chain.get_block_height(&self.block_hash()) {
Ok(Some(height)) => height,
Ok(None) => return Err(HeaderExtError::BlockNotFound),
Err(e) => return Err(HeaderExtError::Chain(Box::new(e))),
};
Ok(height)
}
fn get_target_hex(&self) -> String {
serialize_hex(&self.target().to_be_bytes())
}
fn get_version_hex(&self) -> String {
serialize_hex(&(self.version.to_consensus() as u32).to_be())
}
}
impl From<ChainWorkOverflow> for HeaderExtError {
fn from(_: ChainWorkOverflow) -> Self {
Self::ChainWorkOverflow
}
}
#[derive(Debug, PartialEq)]
pub struct ChainWorkOverflow;
pub trait WorkExt {
fn multiply_work_by_u32(self, factor: u32) -> Result<Work, ChainWorkOverflow>;
fn to_string_hex(&self) -> String;
}
impl WorkExt for Work {
fn multiply_work_by_u32(self, factor: u32) -> Result<Work, ChainWorkOverflow> {
if factor == 0 {
return Ok(Work::from_be_bytes([0u8; 32]));
}
if factor == 1 {
return Ok(self);
}
let work_bytes = self.to_le_bytes();
let mut carry_high: u64 = 0;
let mut result_bytes = [0u8; 32];
let word_size = 4_usize;
let by_chunks: Vec<u32> = work_bytes
.chunks_exact(word_size)
.map(|chunk| {
let mut array = [0u8; 4];
array.copy_from_slice(chunk);
u32::from_le_bytes(array)
})
.collect();
for (word_index, word) in by_chunks.iter().enumerate() {
let product: u64 = (*word as u64) * (factor as u64) + carry_high;
carry_high = product >> 32;
let byte_index = by_chunks.len() - word_index;
result_bytes[(byte_index - 1) * word_size..byte_index * word_size]
.copy_from_slice(&(product as u32).to_be_bytes());
}
if carry_high > 0 {
return Err(ChainWorkOverflow);
}
Ok(Work::from_be_bytes(result_bytes))
}
fn to_string_hex(&self) -> String {
serialize_hex(&self.to_be_bytes())
}
}
#[cfg(test)]
mod tests {
use core::fmt;
use core::fmt::Display;
use core::fmt::Formatter;
use std::collections::HashMap;
use std::sync::Arc;
use bitcoin::block::Header;
use bitcoin::consensus::encode::deserialize_hex;
use bitcoin::hashes::sha256::Hash as Sha256Hash;
use bitcoin::params::Params;
use bitcoin::Block;
use bitcoin::BlockHash;
use bitcoin::OutPoint;
use bitcoin::Transaction;
use bitcoin::Txid;
use rustreexo::proof::Proof;
use rustreexo::stump::Stump;
use super::*;
use crate::BlockConsumer;
use crate::BlockchainError;
use crate::UtxoData;
#[derive(Debug)]
pub enum MockBlockchainError {
NotFound,
}
impl Display for MockBlockchainError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "MockBlockchainError")
}
}
impl core::error::Error for MockBlockchainError {}
pub struct MockBlockchainInterface {
pub headers: HashMap<BlockHash, Header>,
pub heights: HashMap<BlockHash, u32>,
pub chain_height: u32,
}
impl MockBlockchainInterface {
pub fn new() -> Self {
Self {
headers: HashMap::new(),
heights: HashMap::new(),
chain_height: 0,
}
}
pub fn add_block(&mut self, hash: BlockHash, header: Header, height: u32) {
self.headers.insert(hash, header);
self.heights.insert(hash, height);
self.chain_height = self.chain_height.max(height);
}
}
impl BlockchainInterface for MockBlockchainInterface {
type Error = MockBlockchainError;
fn get_block_header(&self, hash: &BlockHash) -> Result<Header, Self::Error> {
self.headers
.get(hash)
.cloned()
.ok_or(MockBlockchainError::NotFound)
}
fn get_block_hash(&self, height: u32) -> Result<BlockHash, Self::Error> {
self.heights
.iter()
.find(|(_, &h)| h == height)
.map(|(hash, _)| *hash)
.ok_or(MockBlockchainError::NotFound)
}
fn get_block_height(&self, hash: &BlockHash) -> Result<Option<u32>, Self::Error> {
Ok(self.heights.get(hash).cloned())
}
fn get_height(&self) -> Result<u32, Self::Error> {
Ok(self.chain_height)
}
fn get_tx(&self, _: &Txid) -> Result<Option<Transaction>, Self::Error> {
unimplemented!()
}
fn estimate_fee(&self, _: usize) -> Result<f64, Self::Error> {
unimplemented!()
}
fn get_block(&self, _: &BlockHash) -> Result<Block, Self::Error> {
unimplemented!()
}
fn get_best_block(&self) -> Result<(u32, BlockHash), Self::Error> {
unimplemented!()
}
fn subscribe(&self, _: Arc<dyn BlockConsumer>) {
unimplemented!()
}
fn is_in_ibd(&self) -> bool {
unimplemented!()
}
fn is_coinbase_mature(&self, _: u32, _: BlockHash) -> Result<bool, Self::Error> {
unimplemented!()
}
fn get_block_locator(&self) -> Result<Vec<BlockHash>, Self::Error> {
unimplemented!()
}
fn get_block_locator_for_tip(
&self,
_: BlockHash,
) -> Result<Vec<BlockHash>, BlockchainError> {
unimplemented!()
}
fn get_validation_index(&self) -> Result<u32, Self::Error> {
unimplemented!()
}
fn update_acc(
&self,
_: Stump,
_: &Block,
_: u32,
_: Proof,
_: Vec<Sha256Hash>,
) -> Result<Stump, Self::Error> {
unimplemented!()
}
fn get_chain_tips(&self) -> Result<Vec<BlockHash>, Self::Error> {
unimplemented!()
}
fn validate_block(
&self,
_: &Block,
_: Proof,
_: HashMap<OutPoint, UtxoData>,
_: Vec<Sha256Hash>,
_: Stump,
) -> Result<(), Self::Error> {
unimplemented!()
}
fn get_fork_point(&self, _: BlockHash) -> Result<BlockHash, Self::Error> {
unimplemented!()
}
fn get_params(&self) -> Params {
unimplemented!()
}
fn acc(&self) -> Stump {
unimplemented!()
}
}
fn get_genesis_header() -> Header {
let genesis_header = "0100000000000000000000000000000000000000000000000000000000000000000000003ba3edfd7a7b12b27ac72c3e67768f617fc81bc3888a51323a9fb8aa4b1e5e4a29ab5f49ffff001d1dac2b7c";
let header: Header = deserialize_hex(genesis_header).expect("Failed to deserialize header");
header
}
fn get_chain_and_headers(height: u32) -> (MockBlockchainInterface, Vec<Header>) {
let mut mock_chain = MockBlockchainInterface::new();
let mut headers = vec![];
let mut prev_blockhash = get_genesis_header().block_hash();
let genesis_header = get_genesis_header();
mock_chain.add_block(prev_blockhash, genesis_header, 0);
headers.push(genesis_header);
for i in 1..height {
let header = Header {
time: 1231006505 + i * 600,
prev_blockhash,
..genesis_header
};
headers.push(header);
let hash = header.block_hash();
mock_chain.add_block(hash, header, i);
prev_blockhash = header.block_hash();
}
(mock_chain, headers)
}
#[test]
fn test_calculate_median_time_past_more_than_11_blocks() {
let (mock_chain, headers) = get_chain_and_headers(21);
let median_header = headers[headers.len() - 1];
let mtp = median_header
.calculate_median_time_past(&mock_chain)
.expect("Failed to calculate MTP");
let mut times = headers
.iter()
.rev()
.take(11)
.map(|h| h.time)
.collect::<Vec<_>>();
times.sort();
let expected_mtp = times[times.len() / 2];
assert_eq!(mtp, expected_mtp);
}
#[test]
fn test_calculate_median_time_past_less_than_11_blocks() {
let (mock_chain, headers) = get_chain_and_headers(7);
let median_header = headers[headers.len() - 1];
let mtp = median_header
.calculate_median_time_past(&mock_chain)
.expect("Failed to calculate MTP");
let mut times = headers.iter().map(|h| h.time).collect::<Vec<_>>();
times.sort();
let expected_mtp = times[times.len() / 2];
assert_eq!(mtp, expected_mtp);
}
#[test]
fn test_calculate_median_time_past_genesis_only() {
let (mock_chain, headers) = get_chain_and_headers(1);
let median_header = headers[0];
let mtp = median_header
.calculate_median_time_past(&mock_chain)
.expect("Failed to calculate MTP");
let expected_mtp = headers[0].time;
assert_eq!(mtp, expected_mtp);
}
#[test]
fn test_calculate_chain_work() {
let (mock_chain, headers) = get_chain_and_headers(3000);
let header = headers[headers.len() - 1];
let work = header
.calculate_chain_work(&mock_chain)
.expect("Failed to calculate chain work");
let expected_hex_string =
"00000000000000000000000000000000000000000000000000000bb80bb80bb8";
let expected_work = Work::from_hex(&format!("0x{expected_hex_string}")).unwrap();
assert_eq!(work.to_string_hex(), expected_hex_string);
assert_eq!(work, expected_work);
}
#[test]
fn test_get_next_block_hash() {
let (mock_chain, headers) = get_chain_and_headers(5);
let header = headers[2];
let next_hash = header
.get_next_block_hash(&mock_chain)
.expect("Failed to get next block hash")
.expect("Next block hash is None");
let expected_hash = headers[3].block_hash();
assert_eq!(next_hash, expected_hash);
let last_header = headers[headers.len() - 1];
let next_hash = last_header
.get_next_block_hash(&mock_chain)
.expect("Failed to get next block hash");
assert!(next_hash.is_none());
}
#[test]
fn test_get_bits() {
let header = get_genesis_header();
let bits_hex = header.get_bits_hex();
assert_eq!(bits_hex, "1d00ffff");
}
#[test]
fn test_get_confirmations() {
let (mock_chain, headers) = get_chain_and_headers(5);
let header = headers[2];
let confirmations = header
.get_confirmations(&mock_chain)
.expect("Failed to get confirmations");
let expected_confirmations = headers.len() - 2;
assert_eq!(confirmations, expected_confirmations as u32);
}
#[test]
fn test_get_difficulty() {
let header = get_genesis_header();
let difficulty = header.get_difficulty();
assert_eq!(difficulty, 1.0);
}
#[test]
fn test_get_height() {
let (mock_chain, headers) = get_chain_and_headers(5);
let height_expected = 3;
let header = headers[height_expected];
let height = header
.get_height(&mock_chain)
.expect("Failed to get block height");
assert_eq!(height, height_expected as u32);
let mut header_missing = headers[0];
header_missing.nonce = 0;
let result = header_missing.get_height(&mock_chain);
assert!(result.is_err());
}
#[test]
fn test_get_target() {
let header = get_genesis_header();
let target_hex = header.get_target_hex();
assert_eq!(
target_hex,
"00000000ffff0000000000000000000000000000000000000000000000000000"
);
}
#[test]
fn test_get_version_hex() {
let header = get_genesis_header();
let version_hex = header.get_version_hex();
assert_eq!(version_hex, "00000001");
}
#[test]
fn test_multiply_work_by_u32_success() {
let work_bytes: [u8; 32] = [
0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
0, 0, 4,
];
let work = Work::from_be_bytes(work_bytes);
let factor = 2;
let result = work.multiply_work_by_u32(factor).unwrap();
let expected_bytes: [u8; 32] = [
0, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0,
0, 0, 8,
];
let expected = Work::from_be_bytes(expected_bytes);
assert_eq!(result, expected);
}
#[test]
fn test_multiply_work_by_u32_overflow() {
let work_bytes: [u8; 32] = [
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
];
let work = Work::from_be_bytes(work_bytes);
let factor = u32::MAX;
let result = work.multiply_work_by_u32(factor);
assert_eq!(result, Err(ChainWorkOverflow));
}
}