diff --git a/migrations/sqlite/0002_alter_session_user_rename_pmsid_add_oauthid.sql b/migrations/sqlite/0002_alter_session_user_rename_pmsid_add_oauthid.sql new file mode 100644 index 0000000000000000000000000000000000000000..015420c92bcad70f1dcc74e37aecf1c134e4ff1d --- /dev/null +++ b/migrations/sqlite/0002_alter_session_user_rename_pmsid_add_oauthid.sql @@ -0,0 +1,3 @@ +ALTER TABLE session_user ADD COLUMN oauth_foreign_id TEXT; +ALTER TABLE session_user ADD COLUMN oauth_provider TEXT; +UPDATE session_user SET (oauth_foreign_id, oauth_provider) = (pms_id, "pms") WHERE pms_id IS NOT NULL; diff --git a/src/db_conn_sqlite.rs b/src/db_conn_sqlite.rs index 325a37b2939c22d4c7d529cf2a9e995b65b6261c..f51386fe050213c0c7f2d96dc7f447f0c0b6deff 100644 --- a/src/db_conn_sqlite.rs +++ b/src/db_conn_sqlite.rs @@ -43,7 +43,7 @@ impl MedalConnection for Connection { let tx = self.transaction().unwrap(); - tx.execute(&contents, &[]).unwrap(); + tx.execute_batch(&contents).unwrap(); tx.execute("INSERT INTO migrations (name) VALUES (?1)", &[&name]).unwrap(); tx.commit().unwrap(); @@ -53,7 +53,7 @@ impl MedalConnection for Connection { // fn get_session<T: ToSql>(&self, key: T, keyname: &str) -> Option<SessionUser> { fn get_session(&self, key: &str) -> Option<SessionUser> { - let res = self.query_row("SELECT id, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, managed_by, pms_id, pms_school_id, salt FROM session_user WHERE session_token = ?1", &[&key], |row| { + let res = self.query_row("SELECT id, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, managed_by, oauth_provider, oauth_foreign_id, salt FROM session_user WHERE session_token = ?1", &[&key], |row| { SessionUser { id: row.get(0), session_token: Some(key.to_string()), @@ -80,8 +80,9 @@ impl MedalConnection for Connection { is_teacher: row.get(18), managed_by: row.get(19), - pms_id: row.get(20), - pms_school_id: row.get(21), + + oauth_provider: row.get(20), + oauth_foreign_id: row.get(21), } }); match res { @@ -137,7 +138,7 @@ impl MedalConnection for Connection { } fn get_user_by_id(&self, user_id: u32) -> Option<SessionUser> { - let res = self.query_row("SELECT session_token, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, managed_by, pms_id, pms_school_id FROM session_user WHERE id = ?1", &[&user_id], |row| { + let res = self.query_row("SELECT session_token, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, managed_by, oauth_provider, oauth_foreign_id, salt FROM session_user WHERE id = ?1", &[&user_id], |row| { SessionUser { id: user_id, session_token: row.get(0), @@ -148,7 +149,7 @@ impl MedalConnection for Connection { username: row.get(5), password: row.get(6), - salt: None,//"".to_string(), + salt: row.get(22), logincode: row.get(7), email: row.get(8), email_unconfirmed: row.get(9), @@ -164,8 +165,9 @@ impl MedalConnection for Connection { is_teacher: row.get(18), managed_by: row.get(19), - pms_id: row.get(20), - pms_school_id: row.get(21), + + oauth_provider: row.get(20), + oauth_foreign_id: row.get(21), } }); res.ok() @@ -252,7 +254,7 @@ impl MedalConnection for Connection { let csrf_token: String = thread_rng().sample_iter(&Alphanumeric).take(10).collect(); let now = time::get_time(); - match self.query_row("SELECT id FROM session_user WHERE pms_id = ?1", &[&foreign_id], |row| -> u32 { + match self.query_row("SELECT id FROM session_user WHERE oauth_foreign_id = ?1", &[&foreign_id], |row| -> u32 { row.get(0) }) { Ok(id) => { @@ -262,7 +264,7 @@ impl MedalConnection for Connection { } // Add! _ => { - self.execute("INSERT INTO session_user (session_token, csrf_token, last_login, last_activity, permanent_login, grade, is_teacher, pms_id, firstname, lastname) VALUES (?1, ?2, ?3, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", &[&session_token, &csrf_token, &now, &false, &0, &(foreign_type != functions::UserType::User), &foreign_id, &firstname, &lastname]).unwrap(); + self.execute("INSERT INTO session_user (session_token, csrf_token, last_login, last_activity, permanent_login, grade, is_teacher, oauth_foreign_id, firstname, lastname) VALUES (?1, ?2, ?3, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", &[&session_token, &csrf_token, &now, &false, &0, &(foreign_type != functions::UserType::User), &foreign_id, &firstname, &lastname]).unwrap(); Ok(session_token) } @@ -653,36 +655,35 @@ impl MedalConnection for Connection { members: Vec::new() }) .unwrap(); // TODO handle error - let mut stmt = self.prepare("SELECT id, session_token, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, pms_id, pms_school_id FROM session_user WHERE managed_by = ?1").unwrap(); - let rows = stmt.query_map(&[&group_id], |row| { - SessionUser { id: row.get(0), - session_token: row.get(1), - csrf_token: row.get(2), - last_login: row.get(3), - last_activity: row.get(4), - permanent_login: row.get(5), - - username: row.get(6), - password: row.get(7), - salt: None, //"".to_string(), - logincode: row.get(8), - email: row.get(9), - email_unconfirmed: row.get(10), - email_confirmationcode: row.get(11), - - firstname: row.get(12), - lastname: row.get(13), - street: row.get(14), - zip: row.get(15), - city: row.get(16), - nation: row.get(17), - grade: row.get(18), - - is_teacher: row.get(19), - managed_by: Some(group_id), - pms_id: row.get(20), - pms_school_id: row.get(21) } - }) + let mut stmt = self.prepare("SELECT id, session_token, csrf_token, last_login, last_activity, permanent_login, username, password, logincode, email, email_unconfirmed, email_confirmationcode, firstname, lastname, street, zip, city, nation, grade, is_teacher, oauth_provider, oauth_foreign_id, salt FROM session_user WHERE managed_by = ?1").unwrap(); + let rows = stmt.query_map(&[&group_id], |row| SessionUser { id: row.get(0), + session_token: row.get(1), + csrf_token: row.get(2), + last_login: row.get(3), + last_activity: row.get(4), + permanent_login: row.get(5), + + username: row.get(6), + password: row.get(7), + salt: row.get(22), + logincode: row.get(8), + email: row.get(9), + email_unconfirmed: row.get(10), + email_confirmationcode: row.get(11), + + firstname: row.get(12), + lastname: row.get(13), + street: row.get(14), + zip: row.get(15), + city: row.get(16), + nation: row.get(17), + grade: row.get(18), + + is_teacher: row.get(19), + managed_by: Some(group_id), + + oauth_provider: row.get(20), + oauth_foreign_id: row.get(21) }) .unwrap(); for user in rows { @@ -825,7 +826,7 @@ impl MedalObject<Connection> for Grade { impl MedalObject<Connection> for Participation { fn save(&mut self, conn: &Connection) { - conn.execute("INSERT INTO participation (contest, user, start_date) VALUES (?1, ?2, ?3)", + conn.execute("INSERT INTO0 participation (contest, user, start_date) VALUES (?1, ?2, ?3)", &[&self.contest, &self.user, &self.start]) .unwrap(); } diff --git a/src/db_objects.rs b/src/db_objects.rs index ec07442522457cec24a023b4c39bde139964c0ad..0c8429c8f32461a08cece6ad9cb97c1dcb3f01c8 100644 --- a/src/db_objects.rs +++ b/src/db_objects.rs @@ -29,8 +29,13 @@ pub struct SessionUser { pub is_teacher: bool, pub managed_by: Option<u32>, - pub pms_id: Option<u32>, - pub pms_school_id: Option<u32>, + + pub oauth_foreign_id: Option<String>, + pub oauth_provider: Option<String>, + // pub oauth_extra_data: Option<String>, + + // pub pms_id: Option<u32>, + // pub pms_school_id: Option<u32>, } // Short version for display @@ -176,8 +181,14 @@ impl SessionUser { is_teacher: false, managed_by: None, - pms_id: None, - pms_school_id: None } + + oauth_foreign_id: None, + oauth_provider: None, + // oauth_extra_data: Option<String>, + + //pms_id: None, + //pms_school_id: None, + } } pub fn ensure_alive(self) -> Option<Self> { @@ -191,7 +202,7 @@ impl SessionUser { } pub fn ensure_logged_in(self) -> Option<Self> { - if self.password.is_some() || self.logincode.is_some() || self.pms_id.is_some() { + if self.password.is_some() || self.logincode.is_some() || self.oauth_foreign_id.is_some() { self.ensure_alive() } else { None diff --git a/src/functions.rs b/src/functions.rs index e21595e54a8a519e4a386a69652ea74b71302588..78f6801175532e5e05e74a0ef09c02d2aeb2258f 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -51,8 +51,10 @@ type MedalValue = (String, json_val::Map<String, json_val::Value>); type MedalResult<T> = Result<T, MedalError>; type MedalValueResult = MedalResult<MedalValue>; +use oauth_provider::OauthProvider; + pub fn index<T: MedalConnection>(conn: &T, session_token: Option<String>, - (self_url, oauth_url): (Option<String>, Option<String>)) + (self_url, oauth_providers): (Option<String>, Option<Vec<OauthProvider>>)) -> (String, json_val::Map<String, json_val::Value>) { let mut data = json_val::Map::new(); @@ -69,8 +71,19 @@ pub fn index<T: MedalConnection>(conn: &T, session_token: Option<String>, } } + let mut oauth_links: Vec<(String, String, String)> = Vec::new(); + if let Some(oauth_providers) = oauth_providers { + println!("tblub"); + for oauth_provider in oauth_providers { + oauth_links.push((oauth_provider.provider_id.to_owned(), + oauth_provider.login_link_text.to_owned(), + oauth_provider.url.to_owned())); + println!("testayy {}", oauth_provider.provider_id.to_owned()); + } + } + data.insert("self_url".to_string(), to_json(&self_url)); - data.insert("oauth_url".to_string(), to_json(&oauth_url)); + data.insert("oauth_links".to_string(), to_json(&oauth_links)); /*contests.push("blaa".to_string()); data.insert("contest".to_string(), to_json(&contests));*/ diff --git a/src/main.rs b/src/main.rs index 84172418b96e274ab63bfc470e19ecfe9e1c16b8..3ef8ae878c8a07acfab9876378ee2791a05db800 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,16 +47,14 @@ use std::path; use std::path::{Path, PathBuf}; use structopt::StructOpt; -#[derive(Serialize, Deserialize, Clone, Default)] +mod oauth_provider; + +#[derive(Serialize, Deserialize, Clone, Default, Debug)] pub struct Config { host: Option<String>, port: Option<u16>, self_url: Option<String>, - oauth_url: Option<String>, - oauth_client_id: Option<String>, - oauth_client_secret: Option<String>, - oauth_access_token_url: Option<String>, - oauth_user_data_url: Option<String>, + oauth_providers: Option<Vec<oauth_provider::OauthProvider>>, database_file: Option<PathBuf>, } @@ -74,6 +72,13 @@ fn read_config_from_file(file: &Path) -> Config { Default::default() }; + if let Some(ref oap) = config.oauth_providers { + println!("OAuth providers:"); + for oap in oap { + println!(" * {}", oap.provider_id); + } + } + if config.host.is_none() { config.host = Some("[::]".to_string()) } @@ -185,10 +190,10 @@ fn add_admin_user(conn: &mut Connection, resetpw: bool) { admin.username = Some("admin".into()); match admin.set_password(&password) { - None => println!("FAILED! (Password hashing error)"), + None => println!(" FAILED! (Password hashing error)"), _ => { conn.save_session(admin); - println!("Done"); + println!(" Done"); } } } diff --git a/src/oauth_provider.rs b/src/oauth_provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..bbb59a9f72ae34f1baaf95f8308fe59e3a2264ba --- /dev/null +++ b/src/oauth_provider.rs @@ -0,0 +1,10 @@ +#[derive(Serialize, Deserialize, Clone, Default, Debug)] +pub struct OauthProvider { + pub provider_id: String, + pub url: String, + pub client_id: String, + pub client_secret: String, + pub access_token_url: String, + pub user_data_url: String, + pub login_link_text: String, +} diff --git a/src/webfw_iron.rs b/src/webfw_iron.rs index e13fb3a714318542e619a75380f045e608871e09..b8956d2ddfd160206beb781a0124924f164e9f83 100644 --- a/src/webfw_iron.rs +++ b/src/webfw_iron.rs @@ -155,6 +155,7 @@ trait RequestRouterParam { fn get_str(self: &mut Self, key: &str) -> Option<String>; fn get_int<T: ::std::str::FromStr>(self: &mut Self, key: &str) -> Option<T>; fn expect_int<T: ::std::str::FromStr>(self: &mut Self, key: &str) -> IronResult<T>; + fn expect_str(self: &mut Self, key: &str) -> IronResult<String>; } impl<'a, 'b> RequestRouterParam for Request<'a, 'b> { @@ -174,6 +175,15 @@ impl<'a, 'b> RequestRouterParam for Request<'a, 'b> { response: Response::with(status::Forbidden) }), } } + + fn expect_str(self: &mut Self, key: &str) -> IronResult<String> { + match self.get_str(key) { + Some(s) => Ok(s), + _ => Err(IronError { error: Box::new(SessionError { message: + "Routing parameter missing".to_string() }), + response: Response::with(status::Forbidden) }), + } + } } use functions; @@ -231,10 +241,10 @@ fn greet_personal(req: &mut Request) -> IronResult<Response> { let session_token = req.get_session_token(); // hier ggf. Daten aus dem Request holen - let (self_url, oauth_url) = { + let (self_url, oauth_providers) = { let mutex = req.get::<Write<SharedConfiguration>>().unwrap(); let config = mutex.lock().unwrap_or_else(|e| e.into_inner()); - (config.self_url.clone(), config.oauth_url.clone()) + (config.self_url.clone(), config.oauth_providers.clone()) }; let (template, data) = { @@ -243,7 +253,7 @@ fn greet_personal(req: &mut Request) -> IronResult<Response> { let conn = mutex.lock().unwrap_or_else(|e| e.into_inner()); // Antwort erstellen und zurücksenden - functions::index(&*conn, session_token, (self_url, oauth_url)) + functions::index(&*conn, session_token, (self_url, oauth_providers)) }; // Daten verarbeiten @@ -299,16 +309,18 @@ fn contest_post(req: &mut Request) -> IronResult<Response> { } fn login(req: &mut Request) -> IronResult<Response> { - let (self_url, oauth_url) = { + // TODO: Use OAuth providers + let (self_url, _oauth_providers) = { let mutex = req.get::<Write<SharedConfiguration>>().unwrap(); let config = mutex.lock().unwrap_or_else(|e| e.into_inner()); - (config.self_url.clone(), config.oauth_url.clone()) + (config.self_url.clone(), config.oauth_providers.clone()) }; let mut data = json_val::Map::new(); data.insert("self_url".to_string(), to_json(&self_url)); - data.insert("oauth_url".to_string(), to_json(&oauth_url)); + // TODO: Generate list of links as in greet_personal + // data.insert("oauth_url".to_string(), to_json(&oauth_url)); let mut resp = Response::new(); resp.set_mut(Template::new("login", data)).set_mut(status::Ok); @@ -588,15 +600,30 @@ fn oauth(req: &mut Request) -> IronResult<Response> { use params::{Params, Value}; use reqwest::header; + let oauth_id = req.expect_str("oauthid")?; + let (client_id, client_secret, access_token_url, user_data_url) = { let mutex = req.get::<Write<SharedConfiguration>>().unwrap(); let config = mutex.lock().unwrap_or_else(|e| e.into_inner()); - if let (Some(id), Some(secret), Some(atu), Some(udu)) = (&config.oauth_client_id, - &config.oauth_client_secret, - &config.oauth_access_token_url, - &config.oauth_user_data_url) - { - (id.clone(), secret.clone(), atu.clone(), udu.clone()) + + let mut result: Option<(String, String, String, String)> = None; + + if let Some(ref oauth_providers) = config.oauth_providers { + for oauth_provider in oauth_providers { + if oauth_provider.provider_id == oauth_id { + result = Some((oauth_provider.client_id.clone(), + oauth_provider.client_secret.clone(), + oauth_provider.access_token_url.clone(), + oauth_provider.user_data_url.clone())); + break; + } + } + + if let Some(result) = result { + result + } else { + return Ok(Response::with(iron::status::NotFound)); + } } else { return Ok(Response::with(iron::status::NotFound)); } @@ -765,10 +792,11 @@ pub fn start_server(conn: Connection, config: ::Config) -> iron::error::HttpResu user: get "/user/:userid" => user, user_post: post "/user/:userid" => user_post, task: get "/task/:taskid" => task, - oauth: get "/oauth" => oauth, + oauth: get "/oauth/:oauthid" => oauth, check_cookie: get "/cookie" => cookie_warning, ); + // TODO: how important is this? Should this be in the config? let my_secret = b"verysecret".to_vec(); let mut mount = Mount::new(); diff --git a/templates/index.hbs b/templates/index.hbs index d7534ff83ef6d99f3985f22f55c5d7c29bb4eb2c..6056d5f00ea3d7fd38476e9e2f7da02c7c38c050 100644 --- a/templates/index.hbs +++ b/templates/index.hbs @@ -51,16 +51,17 @@ <div class="column"><input class="button is-small is-success" type="submit" value="↪ Login"></div> </div> </form> - {{#if oauth_url}}{{#if self_url}} - <div class="columns blogin"> - <div class="column is-two-fifths"></div> - <div class="column"> - <a class="button is-small is-info" href="{{ oauth_url }}{{ self_url }}/oauth">PMS-Login für Lehrer</a> + {{#if self_url}} + {{#each oauth_links}} + <div class="columns blogin"> + <div class="column is-two-fifths"></div> + <div class="column"> + <a class="button is-small is-info" href="{{ this.2 }}{{ self_url }}/oauth/{{ this.0 }}">{{ this.1 }}</a> + </div> </div> - </div> - {{/if}}{{/if}} + {{/each}} + {{/if}} {{/if}} - </div> </div>