diff options
Diffstat (limited to 'src/bin/sso/main.rs')
-rw-r--r-- | src/bin/sso/main.rs | 276 |
1 files changed, 276 insertions, 0 deletions
diff --git a/src/bin/sso/main.rs b/src/bin/sso/main.rs new file mode 100644 index 0000000..b3ca612 --- /dev/null +++ b/src/bin/sso/main.rs @@ -0,0 +1,276 @@ +//! Command line tool for getting access tokens +//! +//! Usage: sso [OPTIONS] [COMMAND] [ARGS] +//! +//! Options: +//! --scope <SCOPE> Request an additional scope +//! --endpoint <URL> The jesterpm-sso endpoint +//! +//! Commands: +//! login - default: get or renew an access token +//! curl - pass the + +use std::collections::{BTreeMap, HashSet}; +use std::{env, fs}; +use std::error::Error; +use std::path::{Path, PathBuf}; +use std::process::Command; +use chrono::{DateTime, Duration, Utc}; +use clap::{Parser, Subcommand}; +use serde::{Serialize, Deserialize}; +use oauth2::{AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, RefreshToken, Scope, TokenResponse, TokenUrl}; +use oauth2::basic::BasicClient; +use oauth2::devicecode::StandardDeviceAuthorizationResponse; +use oauth2::reqwest::http_client; +use url::Url; +use gethostname::gethostname; + +#[derive(Parser)] +#[clap(author, version, about, long_about = None)] +struct Args { + /// The profile to use. A profile is an endpoint and set of scopes. + #[clap(short = 'P', long, default_value = "default")] + profile: String, + + /// Request an additional scope + #[clap(short, long)] + scope: Vec<String>, + + /// The jesterpm-sso endpoint + #[clap(long)] + endpoint: Option<Url>, + + /// Turn debugging information on + #[clap(short, long, parse(from_occurrences))] + verbose: usize, + + #[clap(subcommand)] + command: Option<Commands>, +} + +#[derive(Subcommand, PartialEq)] +enum Commands { + /// does testing things + Login, + Curl { args: Vec<String> } +} + +#[derive(Serialize, Deserialize, Clone)] +struct Profile { + endpoint: String, + scopes: HashSet<String>, + access_token: Option<String>, + access_token_expiration: Option<DateTime<Utc>>, + refresh_token: Option<String>, + #[serde(skip)] + was_modified: bool, +} + +impl Profile { + /// Add a new scope to this profile. + pub fn add_scope(&mut self, scope: String) { + if self.scopes.insert(scope) { + // Since we didn't have this scope before, our old access + // and refresh tokens are useless. + self.access_token = None; + self.access_token_expiration = None; + self.refresh_token = None; + } + } + + /// Check if the access token should be valid. + pub fn valid_access_token(&self) -> bool { + self.access_token.is_some() && + self.access_token_expiration + .map(|expiration| Utc::now() < expiration) + .unwrap_or(true) + } + + /// Check if there is a refresh token. + pub fn valid_refresh_token(&self) -> bool { + self.refresh_token.is_some() + } + + pub fn authorize(&mut self, ) -> Result<(), Box<dyn Error>> { + let client = BasicClient::new( + client_id(), + None, + self.auth_url(), + Some(self.token_url()), + ) + .set_auth_type(AuthType::RequestBody) + .set_device_authorization_url(self.device_url()); + + let scope = Scope::new(self.scopes.iter().map(|s| s.to_string()).collect::<Vec<String>>().join(" ")); + + let details: StandardDeviceAuthorizationResponse = client + .exchange_device_code()? + .add_scope(scope) + .request(http_client)?; + + println!( + "Open this URL in your browser:\n{}\nand enter the code: {}", + details.verification_uri().to_string(), + details.user_code().secret().to_string() + ); + + let token_result = + client + .exchange_device_access_token(&details) + .request(http_client, std::thread::sleep, None)?; + + self.access_token = Some(token_result.access_token().secret().to_string()); + self.access_token_expiration = token_result.expires_in().map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64)); + self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string()); + self.was_modified = true; + Ok(()) + } + + pub fn refresh(&mut self) -> Result<(), Box<dyn Error>> { + let client = + BasicClient::new( + client_id(), + None, + self.auth_url(), + Some(self.token_url()), + ) + .set_auth_type(AuthType::RequestBody); + + let refresh_token = RefreshToken::new(self.refresh_token.as_deref().map(|s| s.to_string()).expect("Missing refresh token")); + let token_result = client.exchange_refresh_token(&refresh_token) + .request(http_client)?; + + self.access_token = Some(token_result.access_token().secret().to_string()); + self.access_token_expiration = token_result.expires_in().map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64)); + self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string()); + self.was_modified = true; + Ok(()) + } + + pub fn set_endpoint(&mut self, endpoint: String) { + self.endpoint = endpoint; + self.was_modified = true; + self.access_token = None; + self.access_token_expiration = None; + self.refresh_token = None; + } + + pub fn modified(&self) -> bool { + self.was_modified + } + + fn auth_url(&self) -> AuthUrl { + AuthUrl::new(format!("{}/oauth/authorize", &self.endpoint)) + .expect("Bad endpoint url.") + } + + fn token_url(&self) -> TokenUrl { + TokenUrl::new(format!("{}/oauth/token", &self.endpoint)) + .expect("Bad endpoint url.") + } + + fn device_url(&self) -> DeviceAuthorizationUrl { + DeviceAuthorizationUrl::new(format!("{}/oauth/device", &self.endpoint)) + .expect("Bad endpoint url.") + } +} + +impl Default for Profile { + fn default() -> Self { + Profile { + endpoint: "https://login.jesterpm.net".to_string(), + scopes: HashSet::new(), + access_token: None, + access_token_expiration: None, + refresh_token: None, + was_modified: false, + } + } +} + +fn client_id() -> ClientId { + ClientId::new(format!("device:{}", gethostname().to_string_lossy())) +} + +fn load_profile(config_dir: &Path, profile_name: &str) -> Result<Profile, Box<dyn Error>> { + let filename = config_dir.join("profiles.json"); + if filename.exists() { + let file = fs::File::open(filename)?; + let mut profiles: BTreeMap<String, Profile> = serde_json::from_reader(file)?; + Ok(profiles.remove(profile_name).unwrap_or_else(Profile::default)) + } else { + Ok(Profile::default()) + } +} + +fn save_profile(config_dir: &Path, profile_name: &str, profile: &Profile) -> Result<(), Box<dyn Error>> { + let filename = config_dir.join("profiles.json"); + let mut profiles: BTreeMap<String, Profile> = if filename.exists() { + let file = fs::File::open(&filename)?; + serde_json::from_reader(file)? + } else { + BTreeMap::new() + }; + + profiles.insert(profile_name.to_string(), profile.clone()); + + let file = fs::File::create(&filename)?; + serde_json::to_writer(file, &profiles) + .map_err(|e| e.into()) +} + +fn do_curl(profile: &Profile, mut args: Vec<String>) -> Result<(), Box<dyn Error>> { + args.push("-H".to_string()); + args.push(format!("Authorization: Bearer {}", profile.access_token.as_deref().expect("Must have valid access token"))); + Command::new("curl").args(args).spawn()?.wait().map(|_| ()).map_err(|e| e.into()) + +} + +fn main() -> Result<(), Box<dyn Error>> { + let args: Args = Args::parse(); + + let command = args.command.unwrap_or(Commands::Login); + + // Find the config files. + let home: PathBuf = env::var("HOME").expect("No $HOME?").parse().expect("Bad $HOME?"); + let config_dir = home.join(".config/jesterpm-sso"); + if !config_dir.exists() { + fs::create_dir(config_dir.as_path())?; + } + + // Load the profile from the config. + let profile_name = args.profile.as_str(); + let mut profile = load_profile(config_dir.as_path(), profile_name)?; + + // Add any new scopes to the profile. + for scope in args.scope { + profile.add_scope(scope); + } + + // Set the endpoint + if let Some(endpoint) = args.endpoint { + profile.set_endpoint(endpoint.to_string()); + } + + // Determine if we need a new token + if command == Commands::Login || !profile.valid_access_token() { + if profile.valid_refresh_token() { + // Try a refresh... + profile.refresh()?; + } + + if !profile.valid_access_token() { + // Acquire access token + profile.authorize()?; + } + } + + if profile.modified() { + save_profile(config_dir.as_path(), profile_name, &profile)?; + } + + match command { + Commands::Login {} => { Ok(()) /* No-op, we already took care of it above */ }, + Commands::Curl { args } => do_curl(&profile, args), + } +}
\ No newline at end of file |