1use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
10use rand::RngCore;
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13use thiserror::Error;
14use tracing::{info, warn};
15
/// Static settings describing how to run the OAuth flow against one provider.
///
/// `Debug` is implemented by hand (instead of derived) so that
/// `client_secret` is never written verbatim into logs or panic messages.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Sent as `client_id` in both the authorization and token requests.
    pub client_id: String,
    /// Sent as `client_secret` to the token endpoint when present.
    pub client_secret: Option<String>,
    /// Authorization endpoint; the query string is appended after a `?`.
    pub auth_url: String,
    /// Token endpoint used for the code exchange and for refreshes.
    pub token_url: String,
    /// Joined with single spaces into the `scope` parameter; omitted when empty.
    pub scopes: Vec<String>,
    /// Local port the callback listener binds to and that appears in the redirect URI.
    pub redirect_port: u16,
    /// Extra key/value pairs appended to the authorization query string.
    pub extra_auth_params: Vec<(String, String)>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            // Show only whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("scopes", &self.scopes)
            .field("redirect_port", &self.redirect_port)
            .field("extra_auth_params", &self.extra_auth_params)
            .finish()
    }
}
31
/// Client credentials for a single provider, as stored in the process-wide
/// credentials table.
///
/// `Default` yields an empty `client_id` and no secret. `Debug` is implemented
/// by hand so the secret value is never printed.
#[derive(Clone, Default)]
pub struct OAuthClientCreds {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret, if the provider issues one.
    pub client_secret: Option<String>,
}

impl std::fmt::Debug for OAuthClientCreds {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthClientCreds")
            .field("client_id", &self.client_id)
            // Show only whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
41
// Process-wide table of client credentials keyed by provider name.
// A `OnceLock` can be initialised at most once; lookups before that see no table.
static OAUTH_CLIENT_CREDS: OnceLock<HashMap<String, OAuthClientCreds>> = OnceLock::new();
43
44pub fn set_oauth_client_creds(creds: HashMap<String, OAuthClientCreds>) {
48 let _ = OAUTH_CLIENT_CREDS.set(creds);
49}
50
51pub fn oauth_client_creds(provider: &str) -> OAuthClientCreds {
53 OAUTH_CLIENT_CREDS
54 .get()
55 .and_then(|m| m.get(provider).cloned())
56 .unwrap_or_default()
57}
58
59#[derive(Clone, Debug, Serialize, Deserialize)]
61pub struct OAuthTokens {
62 pub access_token: String,
63 pub refresh_token: Option<String>,
64 pub expires_at: Option<i64>,
66}
67
/// Failure modes of the OAuth authorization, exchange and refresh operations.
///
/// The `#[error(...)]` strings are the `Display` output for each variant.
#[derive(Error, Debug)]
pub enum OAuthError {
    /// The authorization URL could not be opened in a browser.
    #[error("failed to open browser: {0}")]
    BrowserOpen(String),
    /// Local callback-server problems: encoding the authorization parameters,
    /// binding the port, or the callback channel closing without a result.
    #[error("callback server error: {0}")]
    Server(String),
    /// The token endpoint request failed, its body could not be read or
    /// parsed, or the provider reported an error.
    #[error("token exchange error: {0}")]
    TokenExchange(String),
    /// The callback carried an error or no code, the flow was cancelled, or
    /// the requested provider does not use OAuth.
    #[error("authorization denied: {0}")]
    Denied(String),
    /// No callback arrived within the waiting period.
    #[error("timeout waiting for authorization callback")]
    Timeout,
}
81
82struct AbortOnDrop(Option<tokio::task::JoinHandle<()>>);
83
84impl AbortOnDrop {
85 fn new(handle: tokio::task::JoinHandle<()>) -> Self {
86 Self(Some(handle))
87 }
88
89 fn take_handle(mut self) -> Option<tokio::task::JoinHandle<()>> {
93 self.0.take()
94 }
95}
96
97impl Drop for AbortOnDrop {
98 fn drop(&mut self) {
99 if let Some(h) = self.0.take() {
100 h.abort();
101 }
102 }
103}
104
/// JSON body of a token-endpoint response as this module reads it.
///
/// NOTE(review): `access_token` has no `#[serde(default)]`, so a body that
/// lacks it fails to deserialize before `error` can be inspected.
#[derive(Deserialize)]
struct TokenResponse {
    /// Issued access token (required for deserialization to succeed).
    access_token: String,
    /// Refresh token, if the provider returned one.
    refresh_token: Option<String>,
    /// Token lifetime; added to the current timestamp to compute the expiry.
    expires_in: Option<i64>,
    /// Provider error code, if present in the body.
    #[serde(default)]
    error: Option<String>,
    /// Provider error text; preferred over `error` in error messages.
    #[serde(default)]
    error_description: Option<String>,
}
116
117pub fn generate_code_verifier() -> String {
119 let mut bytes = [0u8; 32];
120 rand::rng().fill_bytes(&mut bytes);
121 URL_SAFE_NO_PAD.encode(bytes)
122}
123
124pub fn code_challenge(verifier: &str) -> String {
126 let hash = Sha256::digest(verifier.as_bytes());
127 URL_SAFE_NO_PAD.encode(hash)
128}
129
/// Runs the interactive authorization-code flow with PKCE and returns tokens.
///
/// Steps, in order:
/// 1. Build the authorization URL (PKCE challenge, extra params, scopes).
/// 2. Start a local HTTP server on `127.0.0.1:{redirect_port}` serving `/callback`.
/// 3. Open the URL in the browser.
/// 4. Wait for the first callback, a cancel signal, or a 300 second timeout.
/// 5. Shut the callback server down, then exchange the code via [`exchange_code`].
///
/// # Parameters
/// - `config`: provider endpoints, client credentials, scopes and callback port.
/// - `cancel`: watch channel; the wait ends once its value is `true`.
/// - `clock`: passed through to [`exchange_code`] to compute the expiry.
///
/// # Errors
/// - [`OAuthError::Server`]: the parameters cannot be encoded, the port cannot
///   be bound, or the callback channel closes without a result.
/// - [`OAuthError::BrowserOpen`]: the browser cannot be launched.
/// - [`OAuthError::Denied`]: the callback reports an error or has no `code`,
///   or the flow is cancelled.
/// - [`OAuthError::Timeout`]: no callback arrives within 300 seconds.
/// - Any error returned by [`exchange_code`].
pub async fn authorize(
    config: &OAuthConfig,
    cancel: tokio::sync::watch::Receiver<bool>,
    clock: &dyn crate::clock::Clock,
) -> Result<OAuthTokens, OAuthError> {
    // PKCE: the verifier stays local until the token exchange; only its
    // S256 challenge goes into the authorization URL.
    let verifier = generate_code_verifier();
    let challenge = code_challenge(&verifier);
    let redirect_uri = format!("http://localhost:{}/callback", config.redirect_port);

    let mut auth_params = vec![
        ("response_type", "code".to_string()),
        ("client_id", config.client_id.clone()),
        ("redirect_uri", redirect_uri.clone()),
        ("code_challenge", challenge),
        ("code_challenge_method", "S256".to_string()),
    ];

    // Provider-specific extras are appended after the standard parameters.
    for (k, v) in &config.extra_auth_params {
        auth_params.push((k.as_str(), v.clone()));
    }

    // `scope` is omitted entirely when no scopes are configured.
    if !config.scopes.is_empty() {
        auth_params.push(("scope", config.scopes.join(" ")));
    }

    let auth_url = format!(
        "{}?{}",
        config.auth_url,
        serde_urlencoded::to_string(&auth_params)
            .map_err(|e| OAuthError::Server(format!("failed to encode params: {e}")))?
    );

    // The callback outcome travels over a oneshot channel: `Ok(code)` or
    // `Err(description)`. The sender sits in a shared `Option` so that only
    // the first callback request can deliver; later ones find `None`.
    let (tx, rx) = tokio::sync::oneshot::channel::<Result<String, String>>();
    let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));

    let tx_for_handler = tx.clone();
    let app = axum::Router::new().route(
        "/callback",
        axum::routing::get(
            move |axum::extract::Query(params): axum::extract::Query<
                std::collections::HashMap<String, String>,
            >| {
                let tx = tx_for_handler.clone();
                async move {
                    let mut guard = tx.lock().await;
                    // A callback counts as an error if it carries `error` or lacks `code`.
                    let is_error = params.contains_key("error") || !params.contains_key("code");
                    if let Some(sender) = guard.take() {
                        if let Some(error) = params.get("error") {
                            // Prefer the human-readable description over the error code.
                            let desc = params
                                .get("error_description")
                                .cloned()
                                .unwrap_or_else(|| error.clone());
                            let _ = sender.send(Err(desc));
                        } else if let Some(code) = params.get("code") {
                            let _ = sender.send(Ok(code.clone()));
                        } else {
                            let _ = sender.send(Err("no code in callback".to_string()));
                        }
                    }
                    // The page is chosen from this request's parameters, even
                    // if an earlier request already consumed the sender.
                    // The "denied" page is the success page with two strings swapped.
                    let html = if is_error {
                        include_str!("oauth_success.html")
                            .replace("Authorization complete", "Authorization denied")
                            .replace(
                                "You can close this window and return to bae.",
                                "Authorization was denied. You can close this window and try again in bae.",
                            )
                    } else {
                        include_str!("oauth_success.html").to_string()
                    };
                    (
                        [
                            (axum::http::header::CACHE_CONTROL, "no-store"),
                            (axum::http::header::CONNECTION, "close"),
                        ],
                        axum::response::Html(html),
                    )
                }
            },
        ),
    );

    // Bound to the IPv4 loopback address only.
    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", config.redirect_port))
        .await
        .map_err(|e| OAuthError::Server(format!("failed to bind port: {e}")))?;

    // The server runs in its own task. The guard aborts that task if this
    // function returns early (e.g. via `?` below) while still holding it.
    let server_guard = AbortOnDrop::new(tokio::spawn(async move {
        // The shutdown future completes after 300 seconds, the same
        // duration as the wait timeout below.
        if let Err(e) = axum::serve(listener, app)
            .with_graceful_shutdown(async {
                tokio::time::sleep(std::time::Duration::from_secs(300)).await;
            })
            .await
        {
            warn!("OAuth callback server exited with error: {e}");
        }
    }));

    open::that(&auth_url).map_err(|e| OAuthError::BrowserOpen(format!("{e}")))?;

    info!("Opened browser for OAuth authorization, waiting for callback");

    let mut cancel = cancel;
    // Whichever of callback / cancel / timeout finishes first decides `result`.
    let result = tokio::select! {
        result = rx => {
            result
                .map_err(|_| OAuthError::Server("callback channel closed".to_string()))
                .and_then(|r| r.map_err(OAuthError::Denied))
        }
        // NOTE(review): the `_` pattern accepts whatever `wait_for` resolves
        // to, including an error result (presumably a dropped sender) —
        // confirm that treating that as cancellation is intended.
        _ = cancel.wait_for(|&v| v) => {
            Err(OAuthError::Denied("cancelled".to_string()))
        }
        _ = tokio::time::sleep(std::time::Duration::from_secs(300)) => {
            Err(OAuthError::Timeout)
        }
    };

    // Tear the server down on every outcome before looking at `result`:
    // abort the task, then wait up to 500ms for it to finish.
    if let Some(handle) = server_guard.take_handle() {
        handle.abort();
        match tokio::time::timeout(std::time::Duration::from_millis(500), handle).await {
            Ok(Ok(())) => {}
            // Cancellation is the expected result of the abort above.
            Ok(Err(e)) if e.is_cancelled() => {}
            Ok(Err(e)) => {
                warn!("OAuth callback server task panicked on shutdown: {e}");
            }
            Err(_) => {
                warn!(
                    "OAuth callback server did not exit within 500ms; \
                     port {} may briefly remain in use",
                    config.redirect_port
                );
            }
        }
    }

    // Errors from the wait are propagated only after the cleanup above.
    let code = result?;

    info!("Received authorization code, exchanging for tokens");

    exchange_code(config, &code, &verifier, &redirect_uri, clock).await
}
289
290pub async fn exchange_code(
292 config: &OAuthConfig,
293 code: &str,
294 verifier: &str,
295 redirect_uri: &str,
296 clock: &dyn crate::clock::Clock,
297) -> Result<OAuthTokens, OAuthError> {
298 let client = reqwest::Client::new();
299 let mut params = vec![
300 ("grant_type", "authorization_code"),
301 ("code", code),
302 ("redirect_uri", redirect_uri),
303 ("client_id", &config.client_id),
304 ("code_verifier", verifier),
305 ];
306
307 let secret_ref;
308 if let Some(ref secret) = config.client_secret {
309 secret_ref = secret.clone();
310 params.push(("client_secret", &secret_ref));
311 }
312
313 let resp = client
314 .post(&config.token_url)
315 .form(¶ms)
316 .send()
317 .await
318 .map_err(|e| OAuthError::TokenExchange(format!("request failed: {e}")))?;
319
320 let status = resp.status();
321 let body = resp
322 .text()
323 .await
324 .map_err(|e| OAuthError::TokenExchange(format!("read body: {e}")))?;
325
326 let token_resp: TokenResponse = serde_json::from_str(&body)
327 .map_err(|e| OAuthError::TokenExchange(format!("parse response: {e} (body: {body})")))?;
328
329 if let Some(error) = token_resp.error {
330 let desc = token_resp.error_description.unwrap_or(error);
331 return Err(OAuthError::TokenExchange(format!(
332 "provider error (HTTP {status}): {desc}"
333 )));
334 }
335
336 let expires_at = token_resp
337 .expires_in
338 .map(|secs| clock.now().timestamp() + secs);
339
340 Ok(OAuthTokens {
341 access_token: token_resp.access_token,
342 refresh_token: token_resp.refresh_token,
343 expires_at,
344 })
345}
346
347pub async fn refresh(
349 config: &OAuthConfig,
350 refresh_token: &str,
351 clock: &dyn crate::clock::Clock,
352) -> Result<OAuthTokens, OAuthError> {
353 let client = reqwest::Client::new();
354 let mut params = vec![
355 ("grant_type", "refresh_token"),
356 ("refresh_token", refresh_token),
357 ("client_id", &config.client_id),
358 ];
359
360 let secret_ref;
361 if let Some(ref secret) = config.client_secret {
362 secret_ref = secret.clone();
363 params.push(("client_secret", &secret_ref));
364 }
365
366 let resp = client
367 .post(&config.token_url)
368 .form(¶ms)
369 .send()
370 .await
371 .map_err(|e| OAuthError::TokenExchange(format!("refresh request failed: {e}")))?;
372
373 let status = resp.status();
374 let body = resp
375 .text()
376 .await
377 .map_err(|e| OAuthError::TokenExchange(format!("read body: {e}")))?;
378
379 let token_resp: TokenResponse = serde_json::from_str(&body)
380 .map_err(|e| OAuthError::TokenExchange(format!("parse response: {e} (body: {body})")))?;
381
382 if let Some(error) = token_resp.error {
383 let desc = token_resp.error_description.unwrap_or(error);
384 return Err(OAuthError::TokenExchange(format!(
385 "provider error (HTTP {status}): {desc}"
386 )));
387 }
388
389 let expires_at = token_resp
390 .expires_in
391 .map(|secs| clock.now().timestamp() + secs);
392
393 let new_refresh = token_resp
395 .refresh_token
396 .or_else(|| Some(refresh_token.to_string()));
397
398 Ok(OAuthTokens {
399 access_token: token_resp.access_token,
400 refresh_token: new_refresh,
401 expires_at,
402 })
403}
404
405pub async fn authorize_provider(
410 provider: crate::config::CloudProvider,
411 cancel: tokio::sync::watch::Receiver<bool>,
412 clock: &dyn crate::clock::Clock,
413) -> Result<OAuthTokens, OAuthError> {
414 use crate::config::CloudProvider;
415 use crate::storage::cloud::{dropbox, google_drive, onedrive};
416
417 let config = match provider {
418 CloudProvider::GoogleDrive => google_drive::GoogleDriveCloudHome::oauth_config(),
419 CloudProvider::Dropbox => dropbox::DropboxCloudHome::oauth_config(),
420 CloudProvider::OneDrive => onedrive::OneDriveCloudHome::oauth_config(),
421 other => {
422 return Err(OAuthError::Denied(format!("{other:?} does not use OAuth")));
423 }
424 };
425
426 authorize(&config, cancel, clock).await
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every character belongs to the unpadded base64url alphabet.
    fn is_base64url(s: &str) -> bool {
        s.chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    }

    /// Verifiers must be at least 43 characters and URL-safe.
    #[test]
    fn pkce_verifier_is_url_safe() {
        let verifier = generate_code_verifier();
        assert!(verifier.len() >= 43);
        assert!(is_base64url(&verifier));
    }

    /// The same verifier always maps to the same challenge.
    #[test]
    fn pkce_challenge_is_deterministic() {
        let verifier = "test-verifier-string";
        let c1 = code_challenge(verifier);
        let c2 = code_challenge(verifier);
        assert_eq!(c1, c2);
    }

    /// Challenges use only the unpadded base64url alphabet.
    #[test]
    fn pkce_challenge_is_base64url() {
        let verifier = generate_code_verifier();
        let challenge = code_challenge(&verifier);
        assert!(is_base64url(&challenge));
    }

    /// Known-answer check against the S256 example in RFC 7636, Appendix B.
    /// Determinism alone would not catch a wrong hash or a wrong encoding.
    #[test]
    fn pkce_challenge_matches_rfc7636_vector() {
        assert_eq!(
            code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw_cM"
        );
    }

    /// Tokens survive a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn oauth_tokens_serialization_roundtrip() {
        let tokens = OAuthTokens {
            access_token: "at_123".to_string(),
            refresh_token: Some("rt_456".to_string()),
            expires_at: Some(1700000000),
        };
        let json = serde_json::to_string(&tokens).unwrap();
        let parsed: OAuthTokens = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.access_token, "at_123");
        assert_eq!(parsed.refresh_token, Some("rt_456".to_string()));
        assert_eq!(parsed.expires_at, Some(1700000000));
    }
}