use core::fmt;
use core::marker::PhantomData;
use io::Write;
use serde::de::{SeqAccess, Unexpected, Visitor};
use serde::ser::SerializeSeq;
use serde::{Deserializer, Serializer};
use super::encode::Error as ConsensusError;
use super::{Decodable, Encodable};
use crate::consensus::{DecodeError, IterReader};
/// Hex text encoding for use with [`With`].
///
/// The `Case` parameter selects lowercase ([`hex::Lower`], the default) or uppercase
/// ([`hex::Upper`]) digits when encoding.
pub struct Hex<Case: hex::Case = hex::Lower>(PhantomData<Case>);

impl<C: hex::Case> Default for Hex<C> {
    fn default() -> Self { Self(PhantomData) }
}

// Encoding goes through the buffered streaming encoder defined in the `hex` module.
impl<C: hex::Case> ByteEncoder for Hex<C> {
    type Encoder = hex::Encoder<C>;
}
/// Hex backend for [`Hex`]: the case markers, the streaming encoder, the decoder, and the
/// decoding error types.
pub mod hex {
    use core::fmt;
    use core::marker::PhantomData;

    use hex::buf_encoder::BufEncoder;

    /// Selects the letter case of emitted hex digits.
    ///
    /// This trait is sealed. Only [`Lower`] and [`Upper`] implement it.
    pub trait Case: sealed::Case {}
    impl<T: sealed::Case> Case for T {}

    /// Marker selecting lowercase hex digits.
    pub enum Lower {}
    /// Marker selecting uppercase hex digits.
    pub enum Upper {}

    mod sealed {
        /// Private supertrait mapping each public marker to the `hex` crate's case value.
        pub trait Case {
            const INTERNAL_CASE: hex::Case;
        }

        impl Case for super::Lower {
            const INTERNAL_CASE: hex::Case = hex::Case::Lower;
        }

        impl Case for super::Upper {
            const INTERNAL_CASE: hex::Case = hex::Case::Upper;
        }
    }

    /// Capacity parameter of the staging `BufEncoder`.
    ///
    /// Staged output is written out whenever the buffer fills up.
    const HEX_BUF_SIZE: usize = 512;

    /// Streaming hex encoder that stages output in a fixed-size buffer.
    pub struct Encoder<C: Case>(BufEncoder<{ HEX_BUF_SIZE }>, PhantomData<C>);

    impl<C: Case> From<super::Hex<C>> for Encoder<C> {
        fn from(_: super::Hex<C>) -> Self { Self(BufEncoder::new(), PhantomData) }
    }

    impl<C: Case> super::EncodeBytes for Encoder<C> {
        fn encode_chunk<W: fmt::Write>(&mut self, writer: &mut W, bytes: &[u8]) -> fmt::Result {
            let mut remaining = bytes;
            loop {
                if remaining.is_empty() {
                    return Ok(());
                }
                // Make room before staging more digits.
                if self.0.is_full() {
                    self.flush(writer)?;
                }
                // Stage as many bytes as fit; the rest are handled on the next iteration.
                remaining = self.0.put_bytes_min(remaining, C::INTERNAL_CASE);
            }
        }

        fn flush<W: fmt::Write>(&mut self, writer: &mut W) -> fmt::Result {
            writer.write_str(self.0.as_str())?;
            self.0.clear();
            Ok(())
        }
    }

    /// Error returned when the input string has an odd number of characters.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct DecodeInitError(hex::OddLengthStringError);

    /// Error returned when the input contains a character that is not a hex digit.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct DecodeError(hex::InvalidCharError);

    /// Iterator over the bytes decoded from a hex string.
    pub struct Decoder<'a>(hex::HexSliceToBytesIter<'a>);

    impl<'a> Decoder<'a> {
        fn new(s: &'a str) -> Result<Self, DecodeInitError> {
            hex::HexToBytesIter::new(s).map(Decoder).map_err(DecodeInitError)
        }
    }

    impl<'a> Iterator for Decoder<'a> {
        type Item = Result<u8, DecodeError>;

        fn next(&mut self) -> Option<Self::Item> {
            let decoded = self.0.next()?;
            Some(decoded.map_err(DecodeError))
        }
    }

    impl<'a, C: Case> super::ByteDecoder<'a> for super::Hex<C> {
        type InitError = DecodeInitError;
        type DecodeError = DecodeError;
        type Decoder = Decoder<'a>;

        fn from_str(s: &'a str) -> Result<Self::Decoder, Self::InitError> { Decoder::new(s) }
    }

    impl super::IntoDeError for DecodeInitError {
        fn into_de_error<E: serde::de::Error>(self) -> E {
            E::invalid_length(self.0.length(), &"an even number of ASCII-encoded hex digits")
        }
    }

    impl super::IntoDeError for DecodeError {
        fn into_de_error<E: serde::de::Error>(self) -> E {
            use serde::de::Unexpected;

            const EXPECTED_CHAR: &str = "an ASCII-encoded hex digit";

            let offending = self.0.invalid_char();
            // ASCII bytes are reported as characters; other bytes are reported numerically.
            let unexpected = if offending.is_ascii() {
                Unexpected::Char(offending as _)
            } else {
                Unexpected::Unsigned(offending.into())
            };
            E::invalid_value(unexpected, &EXPECTED_CHAR)
        }
    }
}
/// Formats a consensus-encodable value as text.
///
/// On `Display`, the value's consensus bytes are streamed through the encoder provided by the
/// [`ByteEncoder`] `E`.
struct DisplayWrapper<'a, T: 'a + Encodable, E>(&'a T, PhantomData<E>);

impl<'a, T: 'a + Encodable, E: ByteEncoder> fmt::Display for DisplayWrapper<'a, T, E> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut writer = IoWrapper::<'_, _, E::Encoder>::new(f, E::default().into());
        self.0.consensus_encode(&mut writer).map_err(|error| {
            // `IoWrapper::write` only ever fails with an opaque `ErrorKind::Other` error, and only
            // after the underlying formatter failed. Any other error is a bug in `T`'s encoding.
            #[cfg(debug_assertions)]
            {
                if error.kind() != io::ErrorKind::Other
                    || error.get_ref().is_some()
                    || !writer.writer.was_error
                {
                    panic!(
                        "{} returned an unexpected error: {:?}",
                        core::any::type_name::<T>(),
                        error
                    );
                }
            }
            // Without the debug check above, `error` would otherwise trigger `unused_variables`.
            #[cfg(not(debug_assertions))]
            let _ = error;
            fmt::Error
        })?;
        // The encoder may still hold buffered output; write it out now.
        let result = writer.actually_flush();
        if result.is_err() {
            writer.writer.assert_was_error::<E>();
        }
        result
    }
}
/// `fmt::Write` wrapper that, in debug builds, records whether the inner writer has failed.
///
/// Callers can use this record for two checks: that an error really originated in the
/// underlying writer rather than being invented by an encoder, and that nobody keeps writing
/// after a failure. In release builds the checks compile to nothing.
struct ErrorTrackingWriter<W: fmt::Write> {
    writer: W,
    /// Set once any write to `writer` has returned an error (debug builds only).
    #[cfg(debug_assertions)]
    was_error: bool,
}

impl<W: fmt::Write> ErrorTrackingWriter<W> {
    fn new(writer: W) -> Self {
        ErrorTrackingWriter {
            writer,
            #[cfg(debug_assertions)]
            was_error: false,
        }
    }

    /// In debug builds, panics if the inner writer has already failed.
    ///
    /// `fun` names the operation being attempted and appears in the panic message.
    #[track_caller]
    fn assert_no_error(&self, fun: &str) {
        #[cfg(debug_assertions)]
        {
            if self.was_error {
                panic!("`{}` called on errored writer", fun);
            }
        }
        // Without the debug check above, `fun` would otherwise trigger `unused_variables`.
        #[cfg(not(debug_assertions))]
        let _ = fun;
    }

    /// In debug builds, panics if the inner writer has NOT failed.
    ///
    /// `Offender` is blamed for reporting an error that the writer never produced.
    fn assert_was_error<Offender>(&self) {
        #[cfg(debug_assertions)]
        {
            if !self.was_error {
                panic!("{} returned an error unexpectedly", core::any::type_name::<Offender>());
            }
        }
    }

    /// Records a failure when `was` is true. The flag is sticky: once set it is never cleared.
    fn set_error(&mut self, was: bool) {
        #[cfg(debug_assertions)]
        {
            self.was_error |= was;
        }
        // Without the debug block above, `was` would otherwise trigger `unused_variables`.
        #[cfg(not(debug_assertions))]
        let _ = was;
    }

    /// Records whether `result` is an error, then returns it unchanged.
    fn check_err<T, E>(&mut self, result: Result<T, E>) -> Result<T, E> {
        self.set_error(result.is_err());
        result
    }
}

impl<W: fmt::Write> fmt::Write for ErrorTrackingWriter<W> {
    fn write_str(&mut self, s: &str) -> fmt::Result {
        self.assert_no_error("write_str");
        let result = self.writer.write_str(s);
        self.check_err(result)
    }

    fn write_char(&mut self, c: char) -> fmt::Result {
        self.assert_no_error("write_char");
        let result = self.writer.write_char(c);
        self.check_err(result)
    }
}
/// Bridges `io::Write`, which consensus encoding writes into, to `fmt::Write`.
///
/// Every chunk of bytes is pushed through an [`EncodeBytes`] encoder.
struct IoWrapper<'a, W: fmt::Write, E: EncodeBytes> {
    writer: ErrorTrackingWriter<&'a mut W>,
    encoder: E,
}

impl<'a, W: fmt::Write, E: EncodeBytes> IoWrapper<'a, W, E> {
    fn new(writer: &'a mut W, encoder: E) -> Self {
        let writer = ErrorTrackingWriter::new(writer);
        IoWrapper { encoder, writer }
    }

    /// Writes out whatever the encoder is still buffering.
    ///
    /// This is separate from `io::Write::flush`, which is a no-op here.
    fn actually_flush(&mut self) -> fmt::Result {
        let IoWrapper { writer, encoder } = self;
        encoder.flush(writer)
    }
}

impl<'a, W: fmt::Write, E: EncodeBytes> Write for IoWrapper<'a, W, E> {
    fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
        if self.encoder.encode_chunk(&mut self.writer, bytes).is_ok() {
            return Ok(bytes.len());
        }
        // Only the underlying formatter may fail; an encoder reporting an error on its own is
        // treated as a bug (checked in debug builds).
        self.writer.assert_was_error::<E>();
        Err(io::Error::from(io::ErrorKind::Other))
    }

    fn flush(&mut self) -> io::Result<()> { Ok(()) }
}
/// A byte-to-text encoding scheme that [`With`] can use for human-readable serialization.
pub trait ByteEncoder: Default {
    /// The stateful streaming encoder created from a value of this scheme.
    type Encoder: EncodeBytes + From<Self>;
}

/// Incremental encoder that turns chunks of bytes into text written to a `fmt::Write`.
pub trait EncodeBytes {
    /// Encodes `chunk` into `out`.
    ///
    /// Implementations may buffer part of the output until [`EncodeBytes::flush`] is called.
    fn encode_chunk<W: fmt::Write>(&mut self, out: &mut W, chunk: &[u8]) -> fmt::Result;

    /// Writes any still-buffered output to `out`.
    fn flush<W: fmt::Write>(&mut self, out: &mut W) -> fmt::Result;
}

/// A text-to-byte decoding scheme that [`With`] can use for human-readable deserialization.
pub trait ByteDecoder<'a> {
    /// Error produced when the input is rejected before any byte is decoded.
    type InitError: IntoDeError + fmt::Debug;
    /// Error produced while decoding an individual byte.
    type DecodeError: IntoDeError + fmt::Debug;
    /// Iterator yielding the decoded bytes, or a per-byte error.
    type Decoder: Iterator<Item = Result<u8, Self::DecodeError>>;

    /// Starts decoding `input`.
    fn from_str(input: &'a str) -> Result<Self::Decoder, Self::InitError>;
}

/// Converts a decoding error into a serde deserialization error.
pub trait IntoDeError {
    /// Performs the conversion into the deserializer's error type `E`.
    fn into_de_error<E: serde::de::Error>(self) -> E;
}
/// `io::Write` sink that emits every byte as one element of a serde sequence.
///
/// When the serializer fails, its error is kept in `error`. The encoder only sees an opaque
/// `ErrorKind::Other` I/O error.
struct BinWriter<S: SerializeSeq> {
    serializer: S,
    /// The serializer error from the most recent failed write, if any.
    error: Option<S::Error>,
}

impl<S: SerializeSeq> Write for BinWriter<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_all(buf)?;
        Ok(buf.len())
    }

    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        let serializer = &mut self.serializer;
        // Serialize bytes in order, stopping at the first failure.
        match buf.iter().find_map(|byte| serializer.serialize_element(byte).err()) {
            None => Ok(()),
            Some(error) => {
                self.error = Some(error);
                Err(io::ErrorKind::Other.into())
            }
        }
    }

    fn flush(&mut self) -> io::Result<()> { Ok(()) }
}
/// Adapts any `Display` value into a serde `Expected` description.
struct DisplayExpected<D: fmt::Display>(D);

impl<D: fmt::Display> serde::de::Expected for DisplayExpected<D> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let DisplayExpected(inner) = self;
        fmt::Display::fmt(inner, formatter)
    }
}
fn consensus_error_into_serde<E: serde::de::Error>(error: ConsensusError) -> E {
match error {
ConsensusError::Io(error) => panic!("unexpected IO error {:?}", error),
ConsensusError::OversizedVectorAllocation { requested, max } => E::custom(format_args!(
"the requested allocation of {} items exceeds maximum of {}",
requested, max
)),
ConsensusError::InvalidChecksum { expected, actual } => E::invalid_value(
Unexpected::Bytes(&actual),
&DisplayExpected(format_args!(
"checksum {:02x}{:02x}{:02x}{:02x}",
expected[0], expected[1], expected[2], expected[3]
)),
),
ConsensusError::NonMinimalVarInt =>
E::custom(format_args!("compact size was not encoded minimally")),
ConsensusError::ParseFailed(msg) => E::custom(msg),
ConsensusError::UnsupportedSegwitFlag(flag) =>
E::invalid_value(Unexpected::Unsigned(flag.into()), &"segwit version 1 flag"),
}
}
impl<E> DecodeError<E>
where
E: serde::de::Error,
{
fn unify(self) -> E {
match self {
DecodeError::Other(error) => error,
DecodeError::TooManyBytes => E::custom(format_args!("got more bytes than expected")),
DecodeError::Consensus(error) => consensus_error_into_serde(error),
}
}
}
/// Converts a decode error into a serde error.
///
/// The inner error is converted through its own `IntoDeError` implementation.
impl<E: IntoDeError> IntoDeError for DecodeError<E> {
    fn into_de_error<DE: serde::de::Error>(self) -> DE {
        match self {
            Self::Consensus(consensus) => consensus_error_into_serde(consensus),
            Self::TooManyBytes => DE::custom(format_args!("got more bytes than expected")),
            Self::Other(inner) => inner.into_de_error(),
        }
    }
}
pub struct With<E>(PhantomData<E>);
impl<E> With<E> {
pub fn serialize<T: Encodable, S: Serializer>(
value: &T,
serializer: S,
) -> Result<S::Ok, S::Error>
where
E: ByteEncoder,
{
if serializer.is_human_readable() {
serializer.collect_str(&DisplayWrapper::<'_, _, E>(value, Default::default()))
} else {
let serializer = serializer.serialize_seq(None)?;
let mut writer = BinWriter { serializer, error: None };
let result = value.consensus_encode(&mut writer);
match (result, writer.error) {
(Ok(_), None) => writer.serializer.end(),
(Ok(_), Some(error)) =>
panic!("{} silently ate an IO error: {:?}", core::any::type_name::<T>(), error),
(Err(io_error), Some(ser_error))
if io_error.kind() == io::ErrorKind::Other && io_error.get_ref().is_none() =>
Err(ser_error),
(Err(io_error), ser_error) => panic!(
"{} returned an unexpected IO error: {:?} serialization error: {:?}",
core::any::type_name::<T>(),
io_error,
ser_error
),
}
}
}
pub fn deserialize<'d, T: Decodable, D: Deserializer<'d>>(
deserializer: D,
) -> Result<T, D::Error>
where
for<'a> E: ByteDecoder<'a>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(HRVisitor::<_, E>(Default::default()))
} else {
deserializer.deserialize_seq(BinVisitor(Default::default()))
}
}
}
/// Visitor that decodes a `T` from a string using the text decoder `D`.
struct HRVisitor<T: Decodable, D: for<'a> ByteDecoder<'a>>(PhantomData<fn() -> (T, D)>);

impl<'de, T: Decodable, D: for<'a> ByteDecoder<'a>> Visitor<'de> for HRVisitor<T, D> {
    type Value = T;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("bytes encoded as a hex string")
    }

    fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<T, E> {
        match D::from_str(s) {
            Ok(decoder) => IterReader::new(decoder).decode().map_err(IntoDeError::into_de_error),
            Err(init_error) => Err(init_error.into_de_error()),
        }
    }
}
/// Visitor that decodes a `T` from a serde sequence of `u8` elements.
struct BinVisitor<T: Decodable>(PhantomData<fn() -> T>);

impl<'de, T: Decodable> Visitor<'de> for BinVisitor<T> {
    type Value = T;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("a sequence of bytes")
    }

    fn visit_seq<S: SeqAccess<'de>>(self, s: S) -> Result<T, S::Error> {
        let bytes = SeqIterator(s, PhantomData);
        IterReader::new(bytes).decode().map_err(DecodeError::unify)
    }
}
/// Adapts a serde `SeqAccess` into an iterator over its `u8` elements.
struct SeqIterator<'a, S: serde::de::SeqAccess<'a>>(S, PhantomData<&'a ()>);

impl<'a, S: serde::de::SeqAccess<'a>> Iterator for SeqIterator<'a, S> {
    type Item = Result<u8, S::Error>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.0.next_element::<u8>() {
            Ok(Some(byte)) => Some(Ok(byte)),
            Ok(None) => None,
            Err(error) => Some(Err(error)),
        }
    }
}