summaryrefslogtreecommitdiff
path: root/src/bin/sso/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/sso/main.rs')
-rw-r--r--src/bin/sso/main.rs139
1 files changed, 111 insertions, 28 deletions
diff --git a/src/bin/sso/main.rs b/src/bin/sso/main.rs
index a5055ba..c76a72a 100644
--- a/src/bin/sso/main.rs
+++ b/src/bin/sso/main.rs
@@ -15,14 +15,18 @@
use chrono::{DateTime, Duration, Utc};
use clap::{Parser, Subcommand};
use gethostname::gethostname;
-use oauth2::basic::BasicClient;
+use oauth2::basic::{
+ BasicClient, BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
+ BasicTokenType,
+};
use oauth2::devicecode::StandardDeviceAuthorizationResponse;
use oauth2::reqwest::http_client;
use oauth2::{
- AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, RefreshToken, Scope, TokenResponse,
- TokenUrl,
+ AuthType, AuthUrl, Client, ClientId, DeviceAuthorizationUrl, ExtraTokenFields, RefreshToken,
+ Scope, StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
+use sshcerts::{Certificate, PrivateKey};
use std::collections::{BTreeMap, HashSet};
use std::error::Error;
use std::path::{Path, PathBuf};
@@ -70,6 +74,10 @@ struct SetupOptions {
/// The jesterpm-sso endpoint
#[clap(long)]
endpoint: Option<Url>,
+
+ /// An SSH key to have signed.
+ #[clap(long)]
+ ssh_key: Option<PathBuf>,
}
#[derive(Serialize, Deserialize, Clone)]
@@ -81,8 +89,16 @@ struct Profile {
refresh_token: Option<String>,
#[serde(skip)]
was_modified: bool,
+ ssh_key: Option<String>,
}
+#[derive(Serialize, Deserialize, Debug, Clone)]
+struct ExtraResponseFields {
+ pub ssh_certificate: Option<String>,
+}
+
+impl ExtraTokenFields for ExtraResponseFields {}
+
impl Profile {
/// Add a new scope to this profile.
pub fn add_scope(&mut self, scope: String) {
@@ -110,7 +126,14 @@ impl Profile {
}
pub fn authorize(&mut self, use_browser: bool) -> Result<(), Box<dyn Error>> {
- let client = BasicClient::new(client_id(), None, self.auth_url(), Some(self.token_url()))
+ let client: Client<
+ BasicErrorResponse,
+ StandardTokenResponse<ExtraResponseFields, BasicTokenType>,
+ BasicTokenType,
+ BasicTokenIntrospectionResponse,
+ StandardRevocableToken,
+ BasicRevocationErrorResponse,
+ > = Client::new(client_id(), None, self.auth_url(), Some(self.token_url()))
.set_auth_type(AuthType::RequestBody)
.set_device_authorization_url(self.device_url());
@@ -122,10 +145,15 @@ impl Profile {
.join(" "),
);
- let details: StandardDeviceAuthorizationResponse = client
- .exchange_device_code()?
- .add_scope(scope)
- .request(http_client)?;
+ let mut device_request = client.exchange_device_code()?.add_scope(scope);
+
+ if let Some(ref filename) = self.ssh_key {
+ let private_key = PrivateKey::from_path(filename)?;
+ let pubkey = private_key.pubkey.to_string();
+ device_request = device_request.add_extra_param("ssh_pubkey", pubkey);
+ }
+
+ let details: StandardDeviceAuthorizationResponse = device_request.request(http_client)?;
let mut quiet = false;
@@ -145,11 +173,9 @@ impl Profile {
);
}
- let token_result = client.exchange_device_access_token(&details).request(
- http_client,
- std::thread::sleep,
- None,
- )?;
+ let token_result: StandardTokenResponse<ExtraResponseFields, BasicTokenType> = client
+ .exchange_device_access_token(&details)
+ .request(http_client, std::thread::sleep, None)?;
self.access_token = Some(token_result.access_token().secret().to_string());
self.access_token_expiration = token_result
@@ -157,6 +183,14 @@ impl Profile {
.map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64));
self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string());
self.was_modified = true;
+
+ // Save the new certificate
+ if let Some(ref cert) = token_result.extra_fields().ssh_certificate {
+ if let Some(cert_file) = self.ssh_certificate_file() {
+ fs::write(cert_file, cert)?;
+ }
+ }
+
Ok(())
}
@@ -191,6 +225,11 @@ impl Profile {
self.refresh_token = None;
}
+ pub fn set_ssh_key(&mut self, ssh_key: Option<String>) {
+ self.ssh_key = ssh_key;
+ self.was_modified = true;
+ }
+
pub fn modified(&self) -> bool {
self.was_modified
}
@@ -207,6 +246,29 @@ impl Profile {
DeviceAuthorizationUrl::new(format!("{}/oauth/device", &self.endpoint))
.expect("Bad endpoint url.")
}
+
+ fn ssh_certificate_file(&self) -> Option<String> {
+ self.ssh_key
+ .as_ref()
+ .map(|f| {
+ let mut cert_file = f.to_owned();
+ cert_file.push_str("-cert.pub");
+ cert_file
+ })
+ }
+
+ fn is_ssh_certificate_valid(&self) -> Result<bool, Box<dyn Error>> {
+ if let Some(f) = self.ssh_certificate_file() {
+ let file = PathBuf::from(f);
+ if file.exists() {
+ let cert = Certificate::from_path(file)?;
+ let now = Utc::now().timestamp() as u64;
+ return Ok(now >= cert.valid_after && now <= cert.valid_before);
+ }
+ }
+ // No certificate
+ Ok(false)
+ }
}
impl Default for Profile {
@@ -218,6 +280,7 @@ impl Default for Profile {
access_token_expiration: None,
refresh_token: None,
was_modified: false,
+ ssh_key: None,
}
}
}
@@ -275,6 +338,31 @@ fn do_curl(profile: &Profile, mut args: Vec<String>) -> Result<(), Box<dyn Error
.map_err(|e| e.into())
}
+fn ensure_authorized(explicit_login: bool, profile: &mut Profile, browser: bool) -> Result<(), Box<dyn Error>> {
+ if !explicit_login && profile.valid_access_token() {
+ return Ok(())
+ }
+
+ let mut can_refresh = profile.valid_refresh_token();
+
+ if explicit_login && profile.ssh_key.is_some() && !profile.is_ssh_certificate_valid()? {
+ log::debug!("Full authorization required to refresh ssh certificate");
+ can_refresh = false;
+ }
+
+ if can_refresh {
+ log::debug!("Attempting to refresh access token");
+ match profile.refresh() {
+ Ok(_) => return Ok(()),
+ Err(e) => log::info!("Failed to refresh token: {}", e),
+ }
+ }
+
+ // Acquire credentials
+ log::debug!("Attempting to retrieve access token");
+ profile.authorize(browser)
+}
+
fn main() -> Result<(), Box<dyn Error>> {
env_logger::init();
@@ -308,6 +396,11 @@ fn main() -> Result<(), Box<dyn Error>> {
profile.set_endpoint(endpoint.to_string());
}
+ // Set the SSH key
+ if let Some(filename) = cfg.ssh_key {
+ profile.set_ssh_key(Some(filename.canonicalize()?.to_string_lossy().into_owned()));
+ }
+
// Print out the current configuration.
println!("Profile {}", profile_name);
println!("\tEndpoint: {}", profile.endpoint);
@@ -316,6 +409,9 @@ fn main() -> Result<(), Box<dyn Error>> {
.map(String::as_str)
.collect::<Vec<_>>()
.join(" "));
+ if let Some(ref filename) = profile.ssh_key {
+ println!("\tSSH Key: {}", filename);
+ }
if profile.modified() {
save_profile(config_dir.as_path(), profile_name, &profile)?;
@@ -323,21 +419,8 @@ fn main() -> Result<(), Box<dyn Error>> {
return Ok(());
}
- // Determine if we need a new token
- if command == Commands::Login || !profile.valid_access_token() {
- if profile.valid_refresh_token() {
- // Try a refresh...
- // Ignore any errors
- if let Err(e) = profile.refresh() {
- log::info!("Failed to refresh token: {}", e);
- }
- }
-
- if !profile.valid_access_token() {
- // Acquire access token
- profile.authorize(!args.no_browser)?;
- }
- }
+ // Everything after this point will need a token.
+ ensure_authorized(command == Commands::Login, &mut profile, !args.no_browser)?;
if profile.modified() {
save_profile(config_dir.as_path(), profile_name, &profile)?;