summaryrefslogtreecommitdiff
path: root/src/bin/sso/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/sso/main.rs')
-rw-r--r--src/bin/sso/main.rs276
1 files changed, 276 insertions, 0 deletions
diff --git a/src/bin/sso/main.rs b/src/bin/sso/main.rs
new file mode 100644
index 0000000..b3ca612
--- /dev/null
+++ b/src/bin/sso/main.rs
@@ -0,0 +1,276 @@
+//! Command line tool for getting access tokens
+//!
+//! Usage: sso [OPTIONS] [COMMAND] [ARGS]
+//!
+//! Options:
+//! --scope <SCOPE> Request an additional scope
+//! --endpoint <URL> The jesterpm-sso endpoint
+//!
+//! Commands:
+//! login - default: get or renew an access token
+//! curl - pass the
+
+use std::collections::{BTreeMap, HashSet};
+use std::{env, fs};
+use std::error::Error;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+use chrono::{DateTime, Duration, Utc};
+use clap::{Parser, Subcommand};
+use serde::{Serialize, Deserialize};
+use oauth2::{AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, RefreshToken, Scope, TokenResponse, TokenUrl};
+use oauth2::basic::BasicClient;
+use oauth2::devicecode::StandardDeviceAuthorizationResponse;
+use oauth2::reqwest::http_client;
+use url::Url;
+use gethostname::gethostname;
+
+#[derive(Parser)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ /// The profile to use. A profile is an endpoint and set of scopes.
+ #[clap(short = 'P', long, default_value = "default")]
+ profile: String,
+
+ /// Request an additional scope
+ #[clap(short, long)]
+ scope: Vec<String>,
+
+ /// The jesterpm-sso endpoint
+ #[clap(long)]
+ endpoint: Option<Url>,
+
+ /// Turn debugging information on
+ #[clap(short, long, parse(from_occurrences))]
+ verbose: usize,
+
+ #[clap(subcommand)]
+ command: Option<Commands>,
+}
+
+#[derive(Subcommand, PartialEq)]
+enum Commands {
+ /// does testing things
+ Login,
+ Curl { args: Vec<String> }
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+struct Profile {
+ endpoint: String,
+ scopes: HashSet<String>,
+ access_token: Option<String>,
+ access_token_expiration: Option<DateTime<Utc>>,
+ refresh_token: Option<String>,
+ #[serde(skip)]
+ was_modified: bool,
+}
+
+impl Profile {
+ /// Add a new scope to this profile.
+ pub fn add_scope(&mut self, scope: String) {
+ if self.scopes.insert(scope) {
+ // Since we didn't have this scope before, our old access
+ // and refresh tokens are useless.
+ self.access_token = None;
+ self.access_token_expiration = None;
+ self.refresh_token = None;
+ }
+ }
+
+ /// Check if the access token should be valid.
+ pub fn valid_access_token(&self) -> bool {
+ self.access_token.is_some() &&
+ self.access_token_expiration
+ .map(|expiration| Utc::now() < expiration)
+ .unwrap_or(true)
+ }
+
+ /// Check if there is a refresh token.
+ pub fn valid_refresh_token(&self) -> bool {
+ self.refresh_token.is_some()
+ }
+
+ pub fn authorize(&mut self, ) -> Result<(), Box<dyn Error>> {
+ let client = BasicClient::new(
+ client_id(),
+ None,
+ self.auth_url(),
+ Some(self.token_url()),
+ )
+ .set_auth_type(AuthType::RequestBody)
+ .set_device_authorization_url(self.device_url());
+
+ let scope = Scope::new(self.scopes.iter().map(|s| s.to_string()).collect::<Vec<String>>().join(" "));
+
+ let details: StandardDeviceAuthorizationResponse = client
+ .exchange_device_code()?
+ .add_scope(scope)
+ .request(http_client)?;
+
+ println!(
+ "Open this URL in your browser:\n{}\nand enter the code: {}",
+ details.verification_uri().to_string(),
+ details.user_code().secret().to_string()
+ );
+
+ let token_result =
+ client
+ .exchange_device_access_token(&details)
+ .request(http_client, std::thread::sleep, None)?;
+
+ self.access_token = Some(token_result.access_token().secret().to_string());
+ self.access_token_expiration = token_result.expires_in().map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64));
+ self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string());
+ self.was_modified = true;
+ Ok(())
+ }
+
+ pub fn refresh(&mut self) -> Result<(), Box<dyn Error>> {
+ let client =
+ BasicClient::new(
+ client_id(),
+ None,
+ self.auth_url(),
+ Some(self.token_url()),
+ )
+ .set_auth_type(AuthType::RequestBody);
+
+ let refresh_token = RefreshToken::new(self.refresh_token.as_deref().map(|s| s.to_string()).expect("Missing refresh token"));
+ let token_result = client.exchange_refresh_token(&refresh_token)
+ .request(http_client)?;
+
+ self.access_token = Some(token_result.access_token().secret().to_string());
+ self.access_token_expiration = token_result.expires_in().map(|d| Utc::now() + Duration::seconds(d.as_secs() as i64));
+ self.refresh_token = token_result.refresh_token().map(|t| t.secret().to_string());
+ self.was_modified = true;
+ Ok(())
+ }
+
+ pub fn set_endpoint(&mut self, endpoint: String) {
+ self.endpoint = endpoint;
+ self.was_modified = true;
+ self.access_token = None;
+ self.access_token_expiration = None;
+ self.refresh_token = None;
+ }
+
+ pub fn modified(&self) -> bool {
+ self.was_modified
+ }
+
+ fn auth_url(&self) -> AuthUrl {
+ AuthUrl::new(format!("{}/oauth/authorize", &self.endpoint))
+ .expect("Bad endpoint url.")
+ }
+
+ fn token_url(&self) -> TokenUrl {
+ TokenUrl::new(format!("{}/oauth/token", &self.endpoint))
+ .expect("Bad endpoint url.")
+ }
+
+ fn device_url(&self) -> DeviceAuthorizationUrl {
+ DeviceAuthorizationUrl::new(format!("{}/oauth/device", &self.endpoint))
+ .expect("Bad endpoint url.")
+ }
+}
+
+impl Default for Profile {
+ fn default() -> Self {
+ Profile {
+ endpoint: "https://login.jesterpm.net".to_string(),
+ scopes: HashSet::new(),
+ access_token: None,
+ access_token_expiration: None,
+ refresh_token: None,
+ was_modified: false,
+ }
+ }
+}
+
+fn client_id() -> ClientId {
+ ClientId::new(format!("device:{}", gethostname().to_string_lossy()))
+}
+
+fn load_profile(config_dir: &Path, profile_name: &str) -> Result<Profile, Box<dyn Error>> {
+ let filename = config_dir.join("profiles.json");
+ if filename.exists() {
+ let file = fs::File::open(filename)?;
+ let mut profiles: BTreeMap<String, Profile> = serde_json::from_reader(file)?;
+ Ok(profiles.remove(profile_name).unwrap_or_else(Profile::default))
+ } else {
+ Ok(Profile::default())
+ }
+}
+
+fn save_profile(config_dir: &Path, profile_name: &str, profile: &Profile) -> Result<(), Box<dyn Error>> {
+ let filename = config_dir.join("profiles.json");
+ let mut profiles: BTreeMap<String, Profile> = if filename.exists() {
+ let file = fs::File::open(&filename)?;
+ serde_json::from_reader(file)?
+ } else {
+ BTreeMap::new()
+ };
+
+ profiles.insert(profile_name.to_string(), profile.clone());
+
+ let file = fs::File::create(&filename)?;
+ serde_json::to_writer(file, &profiles)
+ .map_err(|e| e.into())
+}
+
+fn do_curl(profile: &Profile, mut args: Vec<String>) -> Result<(), Box<dyn Error>> {
+ args.push("-H".to_string());
+ args.push(format!("Authorization: Bearer {}", profile.access_token.as_deref().expect("Must have valid access token")));
+ Command::new("curl").args(args).spawn()?.wait().map(|_| ()).map_err(|e| e.into())
+
+}
+
+fn main() -> Result<(), Box<dyn Error>> {
+ let args: Args = Args::parse();
+
+ let command = args.command.unwrap_or(Commands::Login);
+
+ // Find the config files.
+ let home: PathBuf = env::var("HOME").expect("No $HOME?").parse().expect("Bad $HOME?");
+ let config_dir = home.join(".config/jesterpm-sso");
+ if !config_dir.exists() {
+ fs::create_dir(config_dir.as_path())?;
+ }
+
+ // Load the profile from the config.
+ let profile_name = args.profile.as_str();
+ let mut profile = load_profile(config_dir.as_path(), profile_name)?;
+
+ // Add any new scopes to the profile.
+ for scope in args.scope {
+ profile.add_scope(scope);
+ }
+
+ // Set the endpoint
+ if let Some(endpoint) = args.endpoint {
+ profile.set_endpoint(endpoint.to_string());
+ }
+
+ // Determine if we need a new token
+ if command == Commands::Login || !profile.valid_access_token() {
+ if profile.valid_refresh_token() {
+ // Try a refresh...
+ profile.refresh()?;
+ }
+
+ if !profile.valid_access_token() {
+ // Acquire access token
+ profile.authorize()?;
+ }
+ }
+
+ if profile.modified() {
+ save_profile(config_dir.as_path(), profile_name, &profile)?;
+ }
+
+ match command {
+ Commands::Login {} => { Ok(()) /* No-op, we already took care of it above */ },
+ Commands::Curl { args } => do_curl(&profile, args),
+ }
+} \ No newline at end of file