diff options
Diffstat (limited to 'src/bin/sso/main.rs')
| -rw-r--r-- | src/bin/sso/main.rs | 139 |
1 files changed, 111 insertions, 28 deletions
diff --git a/src/bin/sso/main.rs b/src/bin/sso/main.rs index a5055ba..c76a72a 100644 --- a/src/bin/sso/main.rs +++ b/src/bin/sso/main.rs @@ -15,14 +15,18 @@ use chrono::{DateTime, Duration, Utc}; use clap::{Parser, Subcommand}; use gethostname::gethostname; -use oauth2::basic::BasicClient; +use oauth2::basic::{ + BasicClient, BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse, + BasicTokenType, +}; use oauth2::devicecode::StandardDeviceAuthorizationResponse; use oauth2::reqwest::http_client; use oauth2::{ - AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, RefreshToken, Scope, TokenResponse, - TokenUrl, + AuthType, AuthUrl, Client, ClientId, DeviceAuthorizationUrl, ExtraTokenFields, RefreshToken, + Scope, StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; +use sshcerts::{Certificate, PrivateKey}; use std::collections::{BTreeMap, HashSet}; use std::error::Error; use std::path::{Path, PathBuf}; @@ -70,6 +74,10 @@ struct SetupOptions { /// The jesterpm-sso endpoint #[clap(long)] endpoint: Option<Url>, + + /// An SSH key to have signed. + #[clap(long)] + ssh_key: Option<PathBuf>, } #[derive(Serialize, Deserialize, Clone)] @@ -81,8 +89,16 @@ struct Profile { refresh_token: Option<String>, #[serde(skip)] was_modified: bool, + ssh_key: Option<String>, } +#[derive(Serialize, Deserialize, Debug, Clone)] +struct ExtraResponseFields { + pub ssh_certificate: Option<String>, +} + +impl ExtraTokenFields for ExtraResponseFields {} + impl Profile { /// Add a new scope to this profile. pub fn add_scope(&mut self, scope: String) { @@ -110,7 +126,14 @@ impl Profile { } pub fn authorize(&mut self, use_browser: bool) -> Result<(), Box<dyn Error>> { - let client = BasicClient::new(client_id(), None, self.auth_url(), Some(self.token_url())) + let client: Client< + BasicErrorResponse, + StandardTokenResponse<ExtraResponseFields, BasicTokenType>, + BasicTokenType, + BasicTokenIntrospectionResponse, + StandardRevocableToken, + BasicRevocationErrorResponse, + > = Client::new(client_id(), None, self.auth_url(), Some(self.token_url())) .set_auth_type(AuthType::RequestBody) .set_device_authorization_url(self.device_url()); @@ -122,10 +145,15 @@ impl Profile { .join(" "), ); - let details: StandardDeviceAuthorizationResponse = client - .exchange_device_code()? - .add_scope(scope) - .request(http_client)?; + let mut device_request = client.exchange_device_code()?.add_scope(scope); + + if let Some(ref filename) = self.ssh_key { + let private_key = PrivateKey::from_path(filename)?; + let pubkey = private_key.pubkey.to_string(); + device_request = device_request.add_extra_param("ssh_pubkey", pubkey); + } + + let details: StandardDeviceAuthorizationResponse = device_request.request(http_client)?; let mut quiet = false; @@ -145,11 +173,9 @@ impl Profile { ); } - let token_result = client.exchange_device_access_token(&details).request( - http_client, - std::thread::sleep, - None, - )?; + let token_result: StandardTokenResponse<ExtraResponseFields, BasicTokenType> = client + .exchange_device_access_token(&details) + .request(http_client, std::thread::sleep, None)?; self.access_token = Some(token_result.access_token().secret().to_string()); self.access_token_expiration = token_result @@ -157,6 +183,14 @@ impl Profile { .map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64)); self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string()); self.was_modified = true; + + // Save the new certificate + if let Some(ref cert) = token_result.extra_fields().ssh_certificate { + if let Some(cert_file) = self.ssh_certificate_file() { + fs::write(cert_file, cert)?; + } + } + Ok(()) } @@ -191,6 +225,11 @@ impl Profile { self.refresh_token = None; } + pub fn set_ssh_key(&mut self, ssh_key: Option<String>) { + self.ssh_key = ssh_key; + self.was_modified = true; + } + pub fn modified(&self) -> bool { self.was_modified } @@ -207,6 +246,29 @@ impl Profile { DeviceAuthorizationUrl::new(format!("{}/oauth/device", &self.endpoint)) .expect("Bad endpoint url.") } + + fn ssh_certificate_file(&self) -> Option<String> { + self.ssh_key + .as_ref() + .map(|f| { + let mut cert_file = f.to_owned(); + cert_file.push_str("-cert.pub"); + cert_file + }) + } + + fn is_ssh_certificate_valid(&self) -> Result<bool, Box<dyn Error>> { + if let Some(f) = self.ssh_certificate_file() { + let file = PathBuf::from(f); + if file.exists() { + let cert = Certificate::from_path(file)?; + let now = Utc::now().timestamp() as u64; + return Ok(now >= cert.valid_after && now <= cert.valid_before); + } + } + // No certificate + Ok(false) + } } impl Default for Profile { @@ -218,6 +280,7 @@ impl Default for Profile { access_token_expiration: None, refresh_token: None, was_modified: false, + ssh_key: None, } } } @@ -275,6 +338,31 @@ fn do_curl(profile: &Profile, mut args: Vec<String>) -> Result<(), Box<dyn Error .map_err(|e| e.into()) } +fn ensure_authorized(explicit_login: bool, profile: &mut Profile, browser: bool) -> Result<(), Box<dyn Error>> { + if !explicit_login && profile.valid_access_token() { + return Ok(()) + } + + let mut can_refresh = profile.valid_refresh_token(); + + if explicit_login && profile.ssh_key.is_some() && !profile.is_ssh_certificate_valid()? { + log::debug!("Full authorization required to refresh ssh certificate"); + can_refresh = false; + } + + if can_refresh { + log::debug!("Attempting to refresh access token"); + match profile.refresh() { + Ok(_) => return Ok(()), + Err(e) => log::info!("Failed to refresh token: {}", e), + } + } + + // Acquire credentials + log::debug!("Attempting to retrieve access token"); + profile.authorize(browser) +} + fn main() -> Result<(), Box<dyn Error>> { env_logger::init(); @@ -308,6 +396,11 @@ fn main() -> Result<(), Box<dyn Error>> { profile.set_endpoint(endpoint.to_string()); } + // Set the SSH key + if let Some(filename) = cfg.ssh_key { + profile.set_ssh_key(Some(filename.canonicalize()?.to_string_lossy().into_owned())); + } + // Print out the current configuration. println!("Profile {}", profile_name); println!("\tEndpoint: {}", profile.endpoint); @@ -316,6 +409,9 @@ fn main() -> Result<(), Box<dyn Error>> { .map(String::as_str) .collect::<Vec<_>>() .join(" ")); + if let Some(ref filename) = profile.ssh_key { + println!("\tSSH Key: {}", filename); + } if profile.modified() { save_profile(config_dir.as_path(), profile_name, &profile)?; @@ -323,21 +419,8 @@ fn main() -> Result<(), Box<dyn Error>> { return Ok(()); } - // Determine if we need a new token - if command == Commands::Login || !profile.valid_access_token() { - if profile.valid_refresh_token() { - // Try a refresh... - // Ignore any errors - if let Err(e) = profile.refresh() { - log::info!("Failed to refresh token: {}", e); - } - } - - if !profile.valid_access_token() { - // Acquire access token - profile.authorize(!args.no_browser)?; - } - } + // Everything after this point will need a token. + ensure_authorized(command == Commands::Login, &mut profile, !args.no_browser)?; if profile.modified() { save_profile(config_dir.as_path(), profile_name, &profile)?; |
