From 8413555d2fe25b12a30cecb3b73caa0f0a6ab862 Mon Sep 17 00:00:00 2001 From: jannikac Date: Sat, 11 Feb 2023 04:08:54 +0100 Subject: [PATCH] Implement thiserror and improve output formatting (#91) * implemented an error struct with all possible errors using thiserror * Replaced die_with with ClientError enum derived with thiserror. This enables prettier and more structured error handling than anyhow * Added proper parsing of Gandi API Responses. This also makes it possible to output prettier logs. * improved error message * anyhow is better for main fn because it outputs anyhow errors correctly --- Cargo.lock | 1 + Cargo.toml | 1 + src/config.rs | 30 +++++-- src/ip_source/icanhazip.rs | 8 +- src/ip_source/ip_source.rs | 6 +- src/ip_source/ipify.rs | 8 +- src/ip_source/seeip.rs | 8 +- src/main.rs | 176 ++++++++++++++++++++++++++++++------- 8 files changed, 189 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 359c276..3eab244 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -666,6 +666,7 @@ dependencies = [ "regex", "reqwest", "serde", + "thiserror", "tokio", "toml", ] diff --git a/Cargo.toml b/Cargo.toml index 7c894be..9351ff1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ anyhow = "1.0" governor = "0.5" async-trait = "0.1" die-exit = "0.4" +thiserror = "1.0.38" [dev-dependencies] httpmock = "0.6" diff --git a/src/config.rs b/src/config.rs index 2cfbc20..e56eef0 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,26 @@ use crate::opts; use directories::ProjectDirs; use serde::Deserialize; -use std::fs; use std::path::PathBuf; +use std::{fs, io}; +use thiserror::Error; fn default_types() -> Vec { DEFAULT_TYPES.iter().map(|v| v.to_string()).collect() } +#[derive(Error, Debug)] +pub enum ConfigError { + #[error("Failed to read config file: {0} ")] + Io(#[from] io::Error), + #[error("Failed to parse config file: {0}")] + Parse(#[from] toml::de::Error), + #[error("Entry '{0}' has invalid type '{1}'")] + Validation(String, String), + #[error("Can't find config directory")] + ConfigNotFound(), +} + #[derive(Deserialize, Debug)] pub struct Entry { pub name: String, @@ -65,18 +78,20 @@ impl Config { } } -fn load_config_from>(path: P) -> anyhow::Result { +fn load_config_from>( + path: P, +) -> Result { let contents = fs::read_to_string(path)?; Ok(toml::from_str(&contents)?) } -pub fn load_config(opts: &opts::Opts) -> anyhow::Result { +pub fn load_config(opts: &opts::Opts) -> Result { let mut config = match &opts.config { Some(config_path) => load_config_from(config_path), None => { let confpath = ProjectDirs::from("me", "kaangenc", "gandi-dynamic-dns") .map(|dir| PathBuf::from(dir.config_dir()).join("config.toml")) - .ok_or_else(|| anyhow::anyhow!("Can't find config directory")); + .ok_or(ConfigError::ConfigNotFound()); confpath .and_then(|path| { println!("Checking for config: {}", path.to_string_lossy()); @@ -105,11 +120,14 @@ pub fn load_config(opts: &opts::Opts) -> anyhow::Result { Ok(config) } -pub fn validate_config(config: &Config) -> anyhow::Result<()> { +pub fn validate_config(config: &Config) -> Result<(), ConfigError> { for entry in &config.entry { for entry_type in Config::types(entry) { if entry_type != "A" && entry_type != "AAAA" { - anyhow::bail!("Entry {} has invalid type {}", entry.name, entry_type); + return Err(ConfigError::Validation( + entry.name.clone(), + entry_type.to_string(), + )); } } } diff --git a/src/ip_source/icanhazip.rs b/src/ip_source/icanhazip.rs index 77394c3..9d26aac 100644 --- a/src/ip_source/icanhazip.rs +++ b/src/ip_source/icanhazip.rs @@ -1,10 +1,12 @@ use async_trait::async_trait; +use crate::ClientError; + use super::ip_source::IPSource; pub(crate) struct IPSourceIcanhazip; -async fn get_ip(api_url: &str) -> anyhow::Result { +async fn get_ip(api_url: &str) -> Result { let response = reqwest::get(api_url).await?; let text = response.text().await?; Ok(text) @@ -12,14 +14,14 @@ async fn get_ip(api_url: &str) -> anyhow::Result { #[async_trait] impl IPSource for IPSourceIcanhazip { - async fn get_ipv4(&self) -> anyhow::Result { + async fn get_ipv4(&self) -> Result { Ok(get_ip("https://ipv4.icanhazip.com") .await? // icanazip puts a newline at the end .trim() .to_string()) } - async fn get_ipv6(&self) -> anyhow::Result { + async fn get_ipv6(&self) -> Result { Ok(get_ip("https://ipv6.icanhazip.com") .await? // icanazip puts a newline at the end diff --git a/src/ip_source/ip_source.rs b/src/ip_source/ip_source.rs index e3e49fc..60d6c96 100644 --- a/src/ip_source/ip_source.rs +++ b/src/ip_source/ip_source.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; +use crate::ClientError; + #[async_trait] pub trait IPSource { - async fn get_ipv4(&self) -> anyhow::Result; - async fn get_ipv6(&self) -> anyhow::Result; + async fn get_ipv4(&self) -> Result; + async fn get_ipv6(&self) -> Result; } diff --git a/src/ip_source/ipify.rs b/src/ip_source/ipify.rs index 19fb8eb..2a0eb96 100644 --- a/src/ip_source/ipify.rs +++ b/src/ip_source/ipify.rs @@ -1,10 +1,12 @@ use async_trait::async_trait; +use crate::ClientError; + use super::ip_source::IPSource; pub(crate) struct IPSourceIpify; -async fn get_ip(api_url: &str) -> anyhow::Result { +async fn get_ip(api_url: &str) -> Result { let response = reqwest::get(api_url).await?; let text = response.text().await?; Ok(text) @@ -12,10 +14,10 @@ async fn get_ip(api_url: &str) -> anyhow::Result { #[async_trait] impl IPSource for IPSourceIpify { - async fn get_ipv4(&self) -> anyhow::Result { + async fn get_ipv4(&self) -> Result { get_ip("https://api.ipify.org").await } - async fn get_ipv6(&self) -> anyhow::Result { + async fn get_ipv6(&self) -> Result { get_ip("https://api6.ipify.org").await } } diff --git a/src/ip_source/seeip.rs b/src/ip_source/seeip.rs index 9945e96..d87c526 100644 --- a/src/ip_source/seeip.rs +++ b/src/ip_source/seeip.rs @@ -1,10 +1,12 @@ use async_trait::async_trait; +use crate::ClientError; + use super::ip_source::IPSource; pub(crate) struct IPSourceSeeIP; -async fn get_ip(api_url: &str) -> anyhow::Result { +async fn get_ip(api_url: &str) -> Result { let response = reqwest::get(api_url).await?; let text = response.text().await?; Ok(text) @@ -12,10 +14,10 @@ async fn get_ip(api_url: &str) -> anyhow::Result { #[async_trait] impl IPSource for IPSourceSeeIP { - async fn get_ipv4(&self) -> anyhow::Result { + async fn get_ipv4(&self) -> Result { get_ip("https://ip4.seeip.org").await } - async fn get_ipv6(&self) -> anyhow::Result { + async fn get_ipv6(&self) -> Result { get_ip("https://ip6.seeip.org").await } } diff --git a/src/main.rs b/src/main.rs index 4ff6e08..dcac06f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,13 @@ use crate::config::Config; use crate::gandi::GandiAPI; use crate::ip_source::{ip_source::IPSource, ipify::IPSourceIpify}; use clap::Parser; -use config::IPSourceName; +use config::{ConfigError, IPSourceName}; use ip_source::icanhazip::IPSourceIcanhazip; use ip_source::seeip::IPSourceSeeIP; use opts::Opts; +use reqwest::header::InvalidHeaderValue; use reqwest::{header, Client, ClientBuilder, StatusCode}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::{num::NonZeroU32, sync::Arc, time::Duration}; use tokio::join; use tokio::{self, task::JoinHandle, time::sleep}; @@ -16,13 +17,44 @@ mod gandi; mod ip_source; mod opts; use die_exit::*; +use thiserror::Error; /// 30 requests per minute, see https://api.gandi.net/docs/reference/ const GANDI_RATE_LIMIT: u32 = 30; /// If we hit the rate limit, wait up to this many seconds before next attempt const GANDI_DELAY_JITTER: u64 = 20; -fn api_client(api_key: &str) -> anyhow::Result { +#[derive(Error, Debug)] +pub enum ClientError { + #[error("Error occured while reading config: {0}")] + Config(#[from] ConfigError), + #[error("Error while accessing the Gandi API: {0}")] + Api(#[from] ApiError), + #[error("Error while converting the API key to a header: {0}")] + InvalidHeader(#[from] InvalidHeaderValue), + #[error("Error while sending request: {0}")] + Request(#[from] reqwest::Error), + #[error("Error while joining async tasks: {0}")] + TaskJoin(#[from] tokio::task::JoinError), + #[error("Unexpected type in config: {0}")] + BadEntry(String), + #[error("Entry '{0}' includes type A which requires an IPv4 adress but no IPv4 adress could be determined because: {1}")] + Ipv4missing(String, String), + #[error("Entry '{0}' includes type AAAA which requires an IPv6 adress but no IPv6 adress could be determined because: {1}")] + Ipv6missing(String, String), +} + +#[derive(Error, Debug)] +pub enum ApiError { + #[error("API returned 403 - Forbidden. Message: {message:?}")] + Forbidden { message: String }, + #[error("API returned 403 - Unauthorized. Provided API key is possibly incorrect")] + Unauthorized(), + #[error("API returned {0} - {0}")] + Unknown(StatusCode, String), +} + +fn api_client(api_key: &str) -> Result { let client_builder = ClientBuilder::new(); let key = format!("Apikey {}", api_key); @@ -42,12 +74,27 @@ pub struct APIPayload { pub rrset_ttl: u32, } +#[derive(Debug)] +struct ResponseFeedback { + entry_name: String, + entry_type: String, + response: Result, +} + +#[derive(Deserialize)] +struct ApiResponse { + message: String, + cause: Option, + code: Option, + object: Option, +} + async fn run( base_url: &str, ip_source: &Box, conf: &Config, opts: &Opts, -) -> anyhow::Result<()> { +) -> Result<(), ClientError> { let mut last_ipv4: Option = None; let mut last_ipv6: Option = None; @@ -80,7 +127,7 @@ async fn run( if !ipv4_same || !ipv6_same || conf.always_update { let client = api_client(&conf.api_key)?; - let mut tasks: Vec> = Vec::new(); + let mut tasks: Vec>> = Vec::new(); println!("Attempting to update DNS entries now"); let governor = Arc::new(governor::RateLimiter::direct(governor::Quota::per_minute( @@ -100,10 +147,22 @@ async fn run( } .url(); let ip = match entry_type { - "A" => ipv4.die_with(|error| format!("Needed IPv4 for {fqdn}: {error}")), - "AAAA" => ipv6.die_with(|error| format!("Needed IPv6 for {fqdn}: {error}")), - bad_entry_type => die!("Unexpected type in config: {}", bad_entry_type), - }; + "A" => match ipv4 { + Ok(ref value) => Ok(value), + Err(ref err) => Err(ClientError::Ipv4missing( + entry.name.clone(), + err.to_string(), + )), + }, + "AAAA" => match ipv6 { + Ok(ref value) => Ok(value), + Err(ref err) => Err(ClientError::Ipv6missing( + entry.name.clone(), + err.to_string(), + )), + }, + &_ => Err(ClientError::BadEntry(entry_type.to_string())), + }?; let payload = APIPayload { rrset_values: vec![ip.to_string()], rrset_ttl: Config::ttl(entry, conf), @@ -111,28 +170,82 @@ async fn run( let req = client.put(url).json(&payload); let task_governor = governor.clone(); let entry_type = entry_type.to_string(); - let task = tokio::task::spawn(async move { - task_governor.until_ready_with_jitter(retry_jitter).await; - println!("Updating {} record for {}", entry_type, &fqdn); - match req.send().await { - Ok(response) => ( - response.status(), - response - .text() - .await - .unwrap_or_else(|error| error.to_string()), - ), - Err(error) => (StatusCode::IM_A_TEAPOT, error.to_string()), - } - }); + let entry_name = entry.name.to_string(); + + let task: JoinHandle> = + tokio::task::spawn(async move { + task_governor.until_ready_with_jitter(retry_jitter).await; + println!("Updating {} record for {}", entry_type, &fqdn); + + let resp = req.send().await?; + + let response_feedback = match resp.status() { + StatusCode::CREATED => { + let body: ApiResponse = resp.json().await?; + ResponseFeedback { + entry_name, + entry_type, + response: Ok(body.message), + } + } + StatusCode::UNAUTHORIZED => ResponseFeedback { + entry_name, + entry_type, + response: Err(ApiError::Unauthorized()), + }, + StatusCode::FORBIDDEN => { + let body: ApiResponse = resp.json().await?; + ResponseFeedback { + entry_name: entry_name.clone(), + entry_type, + response: Err(ApiError::Forbidden { + message: body.message, + }), + } + } + _ => { + let status = resp.status(); + let body: ApiResponse = resp.json().await?; + ResponseFeedback { + entry_name, + entry_type, + response: Err(ApiError::Unknown(status, body.message)), + } + } + }; + Ok(response_feedback) + }); tasks.push(task); } } let results = futures::future::try_join_all(tasks).await?; - println!("Updates done for {} entries", results.len()); - for (status, body) in results { - println!("{} - {}", status, body); + // Only count successfull requests + println!( + "Updates done for {} entries", + results + .iter() + .filter_map(|item| item.as_ref().ok()) + .filter(|item| item.response.is_ok()) + .count() + ); + for item in results { + match item { + Ok(value) => println!( + "{}", + match value.response { + Ok(val) => format!( + "Record '{}' ({}): {}", + value.entry_name, value.entry_type, val + ), + Err(err) => format!( + "Record '{}' ({}): {}", + value.entry_name, value.entry_type, err + ), + } + ), + Err(err) => println!("{}", err), + } } } else { println!("IP address has not changed since last update"); @@ -153,15 +266,14 @@ async fn run( #[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { let opts = opts::Opts::parse(); - let conf = config::load_config(&opts) - .die_with(|error| format!("Failed to read config file: {}", error)); + let conf = config::load_config(&opts)?; let ip_source: Box = match conf.ip_source { IPSourceName::Ipify => Box::new(IPSourceIpify), IPSourceName::Icanhazip => Box::new(IPSourceIcanhazip), IPSourceName::SeeIP => Box::new(IPSourceSeeIP), }; - config::validate_config(&conf).die_with(|error| format!("Invalid config: {}", error)); + config::validate_config(&conf)?; run("https://api.gandi.net", &ip_source, &conf, &opts).await?; Ok(()) } @@ -170,7 +282,7 @@ async fn main() -> anyhow::Result<()> { mod tests { use std::{env::temp_dir, time::Duration}; - use crate::{config, ip_source::ip_source::IPSource, opts::Opts, run}; + use crate::{config, ip_source::ip_source::IPSource, opts::Opts, run, ClientError}; use async_trait::async_trait; use httpmock::MockServer; use tokio::{fs, task::LocalSet, time::sleep}; @@ -179,10 +291,10 @@ mod tests { #[async_trait] impl IPSource for IPSourceMock { - async fn get_ipv4(&self) -> anyhow::Result { + async fn get_ipv4(&self) -> Result { Ok("192.168.0.0".to_string()) } - async fn get_ipv6(&self) -> anyhow::Result { + async fn get_ipv6(&self) -> Result { Ok("fe80:0000:0208:74ff:feda:625c".to_string()) } }