diff --git a/src/structs.rs b/src/structs.rs index a5dbb02..68a5800 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -1,5 +1,8 @@ -use serde::Deserialize; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer}; +use std::fmt; use std::io; +use std::marker::PhantomData; use std::process::ExitStatus; use std::string::FromUtf8Error; use tokio::task::JoinError; @@ -27,6 +30,7 @@ pub struct VideoData { pub requested_formats: Vec, #[serde(default)] pub url: Option, + #[serde(deserialize_with = "convert_to_u64")] pub duration: u64, pub id: String, pub title: String, @@ -96,3 +100,120 @@ impl From for Error { Error::FromUtf8Error(error) } } + +fn convert_to_u64<'de, D>(deserializer: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + struct ConvertToU64(PhantomData T>); + + impl<'de> Visitor<'de> for ConvertToU64 { + type Value = u64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an integer between 0 and 2^63") + } + + fn visit_i8(self, value: i8) -> std::result::Result + where + E: de::Error, + { + use std::u64; + if value >= 0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + + fn visit_i16(self, value: i16) -> std::result::Result + where + E: de::Error, + { + use std::u64; + if value >= 0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + + fn visit_i32(self, value: i32) -> std::result::Result + where + E: de::Error, + { + use std::u64; + if value >= 0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + + fn visit_i64(self, value: i64) -> std::result::Result + where + E: de::Error, + { + use std::u64; + if value >= 0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + + fn visit_u8(self, value: u8) -> std::result::Result + where + E: de::Error, + { + Ok(u64::from(value)) + } + + fn visit_u16(self, value: u16) -> std::result::Result + where + E: de::Error, + { + Ok(u64::from(value)) + } + + fn visit_u32(self, value: u32) -> std::result::Result + where + E: de::Error, + { + Ok(u64::from(value)) + } + + fn visit_u64(self, value: u64) -> std::result::Result + where + E: de::Error, + { + Ok(value) + } + + fn visit_f32(self, value: f32) -> std::result::Result + where + E: de::Error, + { + let value = value.ceil(); + if value >= 0.0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + + fn visit_f64(self, value: f64) -> std::result::Result + where + E: de::Error, + { + let value = value.ceil(); + if value >= 0.0 { + Ok(value as u64) + } else { + Err(E::custom(format!("u64 out of range: {}", value))) + } + } + } + + deserializer.deserialize_any(ConvertToU64(PhantomData)) +}