Skip to content

Commit

Permalink
add diesel persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
carderne committed Apr 11, 2024
1 parent afd1634 commit fa8a506
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 43 deletions.
43 changes: 43 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"

[dependencies]
chrono = { version = "0.4.37", features = ["serde"] }
diesel = { version = "2.1.0", features = ["sqlite"] }
dotenvy = "0.15"
geo = "0.28.0"
geo-types = "0.7.13"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ API docs at [localhost:8000/rapidoc](http://localhost:8000/rapidoc).

Need the following in your `.env`:
```bash
DATABASE_URL='db.sqlite'
ROCKET_SECRET_KEY=''
REDIRECT_URI='http://localhost:8000/callback'

Expand Down
Empty file added migrations/.keep
Empty file.
1 change: 1 addition & 0 deletions migrations/2024-04-03-181833_create_users/down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE users;
6 changes: 6 additions & 0 deletions migrations/2024-04-03-181833_create_users/up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
refresh_token TEXT NOT NULL,
access_token TEXT NOT NULL,
expires_at INTEGER NOT NULL
);
36 changes: 36 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use diesel::prelude::*;
use diesel::{Connection, SqliteConnection};
use std::env;

use crate::models::UserDb;
use crate::{schema, strava};

fn establish_connection() -> SqliteConnection {
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
SqliteConnection::establish(&database_url)
.unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
}

pub fn save_user(t: &strava::TokenResponse) {
let user = UserDb {
id: t.athlete.id,
refresh_token: t.refresh_token.clone(),
access_token: t.access_token.clone(),
expires_at: t.expires_at,
};

let conn = &mut establish_connection();
diesel::insert_into(schema::users::table)
.values(&user)
.execute(conn)
.expect("Error saving new post");
}

pub fn get_user(user_id: i32) -> Option<UserDb> {
use self::schema::users::dsl::*;
let conn = &mut establish_connection();
users
.find(user_id)
.select(UserDb::as_select())
.first(conn).ok()
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
//! Playing around with the Strava API
pub mod data;
pub mod db;
pub mod h3;
pub mod models;
pub mod routes;
pub mod schema;
pub mod strava;
34 changes: 26 additions & 8 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
use chrono::{DateTime, NaiveDateTime, Utc};
use diesel::prelude::*;
use geojson::GeoJson;
use rocket::http::Status;
use rocket::request::Outcome;
use rocket::request::{FromRequest, Request};
use serde::Serialize;

#[derive(Debug)]
#[derive(Debug, Queryable, Selectable, Insertable)]
#[diesel(table_name = crate::schema::users)]
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
pub struct UserDb {
pub id: i32,
pub access_token: String,
pub refresh_token: String,
pub expires_at: i32,
}

pub struct User {
pub id: i32,
pub token: String,
}

#[rocket::async_trait]
Expand All @@ -19,18 +29,26 @@ impl<'r> FromRequest<'r> for User {
let id = jar
.get_private("id")
.and_then(|cookie| cookie.value().parse::<i32>().ok());
let token = jar
.get_private("token")
.map(|cookie| cookie.value().to_string());
match (id, token) {
(Some(id), Some(token)) => Outcome::Success(User { id, token }),
match id {
Some(id) => Outcome::Success(User { id }),
_ => Outcome::Forward(Status::Unauthorized),
}
}
}

#[derive(Serialize)]
pub struct Data {
pub activities: GeoJson,
pub activities: Option<GeoJson>,
pub cells: Vec<String>,
}

pub fn ts_to_dt(timestamp: i32) -> NaiveDateTime {
DateTime::from_timestamp(timestamp as i64, 0)
.unwrap()
.naive_utc()
}

pub fn is_dt_past(datetime: NaiveDateTime) -> bool {
let now_plus_one_hour = Utc::now().naive_local() + chrono::Duration::hours(1);
datetime < now_plus_one_hour
}
66 changes: 35 additions & 31 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ use rocket_dyn_templates::{context, Template};
use rocket_okapi::settings::UrlObject;
use rocket_okapi::{openapi, openapi_get_routes, rapidoc::*};
use std::env;
use time::{Duration, OffsetDateTime};

use crate::models::{Data, User};
use crate::{data, h3, strava};
use crate::models::{is_dt_past, ts_to_dt, Data, User, UserDb};
use crate::{data, db, h3, strava};

pub async fn build() -> Result<Rocket<Ignite>, Error> {
rocket::build()
Expand Down Expand Up @@ -69,28 +68,46 @@ fn index() -> Redirect {
#[openapi(skip)]
#[get("/")]
fn authed_index(user: User) -> Template {
let User { id, token: _ } = user;
let User { id } = user;
let os_key = env::var("OS_KEY").unwrap();
Template::render("map", context! { id, os_key })
}

#[openapi(skip)]
#[get("/data")]
async fn get_data(user: User) -> Json<Data> {
let User { id: _, token } = user;
// Should get user from db/session
// but for now just using code directly
// let user = db::get_user(id);
let User { id } = user;

let user = db::get_user(id);
let user: UserDb = match user {
Some(user) => user,
None => {
// TODO do something in the UI to handle this
return Json(Data {
activities: None,
cells: vec![],
});
}
};

let expiry = ts_to_dt(user.expires_at);
let expired = is_dt_past(expiry);

// previously I was just getting a token out of the cookie
// which was quite elegant, but didn't provide for refreshing...
let token = if expired {
// get a new token (using refresh_token) if this one expired
let token_response =
strava::get_token(&user.refresh_token, strava::GrantType::Refresh).await;
db::save_user(&token_response);
token_response.access_token
} else {
// otherwise use the current one
user.access_token
};

// TODO flash message on error (eg expired token)
let activities = strava::get_activities(&token).await;

// Could just return Json here and use a `fetch` call from UI
// Json(geo::decode_all(activities))

// But instead return template with injected GeoJSON
let activities = data::decode_all(activities);

let mut cells = h3::polyfill_all(&activities);
cells.sort();
cells.dedup();
Expand All @@ -100,7 +117,7 @@ async fn get_data(user: User) -> Json<Data> {
.collect();

let activities = data::to_geojson(activities);
Json(Data { activities, cells })
Json(Data { activities: Some(activities), cells })
}

#[openapi(tag = "OAuth")]
Expand All @@ -113,7 +130,8 @@ fn auth() -> Redirect {
#[openapi(tag = "OAuth")]
#[get("/callback?<code>")]
async fn callback(code: &str, jar: &CookieJar<'_>) -> Redirect {
let token_response = strava::get_token(code).await;
let token_response = strava::get_token(code, strava::GrantType::Auth).await;
db::save_user(&token_response);

let mut c_id: Cookie = Cookie::new("id", token_response.athlete.id.to_string());
// This happens after the OAuth flow and if SameSite::Strict
Expand All @@ -122,27 +140,13 @@ async fn callback(code: &str, jar: &CookieJar<'_>) -> Redirect {
c_id.set_same_site(SameSite::Lax);
jar.add_private(c_id);

// The Strava token is valid for six hours
// so expire the cookie after 5!
let mut c_token: Cookie = Cookie::new("token", token_response.access_token);
c_token.set_expires(OffsetDateTime::now_utc() + Duration::hours(5));
// Same comment as above about Strict
c_token.set_same_site(SameSite::Lax);
jar.add_private(c_token);

// Not saving users, rather just redirect with cookies
// db::save_user(token_response);
use std::{thread, time};
let ten_millis = time::Duration::from_millis(1000);
thread::sleep(ten_millis);
Redirect::to(uri!(authed_index))
}

#[openapi(skip)]
#[get("/logout")]
fn logout(jar: &CookieJar<'_>) -> Redirect {
jar.remove_private("id");
jar.remove_private("token");
Redirect::to(uri!(logged_out))
}

Expand Down
10 changes: 10 additions & 0 deletions src/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// @generated automatically by Diesel CLI.

diesel::table! {
users (id) {
id -> Integer,
refresh_token -> Text,
access_token -> Text,
expires_at -> Integer,
}
}
17 changes: 13 additions & 4 deletions src/strava.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ use serde::Deserialize;
use std::env;
use url::{ParseError, Url};

pub enum GrantType {
Auth,
Refresh,
}

#[derive(Deserialize)]
pub struct Athlete {
pub id: i32,
Expand Down Expand Up @@ -63,10 +68,14 @@ pub fn create_oauth_url() -> Result<String, ParseError> {
Ok(url.to_string())
}

fn create_token_url(code: &str) -> Result<String, ParseError> {
fn create_token_url(code: &str, grant_type: GrantType) -> Result<String, ParseError> {
let base = env::var("STRAVA_BASE").unwrap();
let client_id = env::var("STRAVA_CLIENT_ID").unwrap();
let client_secret = env::var("STRAVA_CLIENT_SECRET").unwrap();
let grant_type = match grant_type {
GrantType::Auth => "authorization_code",
GrantType::Refresh => "refresh_token",
};

let mut url = Url::parse(&base)?;
let path = "/oauth/token";
Expand All @@ -75,7 +84,7 @@ fn create_token_url(code: &str) -> Result<String, ParseError> {
.append_pair("client_id", &client_id)
.append_pair("client_secret", &client_secret)
.append_pair("code", code)
.append_pair("grant_type", "authorization_code");
.append_pair("grant_type", grant_type);
Ok(url.to_string())
}

Expand All @@ -94,8 +103,8 @@ pub async fn get_activities(token: &str) -> Vec<ActivityResponse> {
.unwrap()
}

pub async fn get_token(code: &str) -> TokenResponse {
let url = create_token_url(code).unwrap();
pub async fn get_token(code: &str, grant_type: GrantType) -> TokenResponse {
let url = create_token_url(code, grant_type).unwrap();
let client = reqwest::Client::new();
client
.post(url)
Expand Down

0 comments on commit fa8a506

Please sign in to comment.