Skip to content

Commit a377d09

Browse files
committed
Update key access control
1 parent 8cf0c05 commit a377d09

File tree

14 files changed

+232
-91
lines changed

14 files changed

+232
-91
lines changed

engine/src/middlewares/auth.rs

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ use tracing::{info, info_span, Instrument};
1010
use tracing_opentelemetry::OpenTelemetrySpanExt;
1111

1212
use crate::{
13-
models::{session::Session, team::Team},
13+
models::{keys::Key, session::Session, team::Team},
1414
routes::error::HttpError,
1515
state::State,
1616
utils::hash::hash_session,
1717
};
1818

1919
pub enum UserAuth {
2020
User(Session, State),
21+
Key(Key, State),
2122
None(State),
2223
}
2324

@@ -60,39 +61,61 @@ impl<'a> ApiExtractor<'a> for UserAuth {
6061
.map(|x| x.replace("Bearer ", ""));
6162

6263
// Token could either be a session token or a pat token
63-
if token.is_none() {
64-
return Ok(UserAuth::None(state.clone()));
65-
}
66-
67-
let token = token.unwrap();
68-
69-
let cache_key = format!("session:{}", token);
70-
71-
let is_user = state
72-
.cache
73-
.raw
74-
.get_with(cache_key, async {
75-
// Use tracing events instead of spans to avoid Send issues
76-
info!("Cache miss for session: {}", token);
64+
let token = match token {
65+
Some(token) => token,
66+
None => return Ok(UserAuth::None(state.clone())),
67+
};
68+
69+
// check the token
70+
if token.starts_with("se_") {
71+
let cache_key = format!("session:{}", token);
72+
73+
let is_user = state
74+
.cache
75+
.raw
76+
.get_with(cache_key, async {
77+
// Use tracing events instead of spans to avoid Send issues
78+
info!("Cache miss for session: {}", token);
79+
80+
// Hash the token
81+
let hash = hash_session(&token);
82+
83+
// Check if active session exists with token
84+
let session = Session::try_access(&state.database, &hash)
85+
.await
86+
.unwrap()
87+
.ok_or(HttpError::Unauthorized)
88+
.unwrap();
89+
90+
serde_json::to_value(session).unwrap()
91+
})
92+
.await;
93+
94+
let session: Option<Session> = serde_json::from_value(is_user).ok();
95+
96+
if let Some(session) = session {
97+
return Ok(UserAuth::User(session, state.clone()));
98+
}
99+
} else if token.starts_with("k_") {
100+
let cache_key = format!("key:{}", token);
77101

78-
// Hash the token
102+
let is_key = state.cache.raw.get_with(cache_key, async {
79103
let hash = hash_session(&token);
80104

81-
// Check if active session exists with token
82-
let session = Session::try_access(&state.database, &hash)
105+
let key = Key::get_by_id(&state.database, hash.as_ref())
83106
.await
84107
.unwrap()
85108
.ok_or(HttpError::Unauthorized)
86109
.unwrap();
87110

88-
serde_json::to_value(session).unwrap()
89-
})
90-
.await;
111+
serde_json::to_value(key).unwrap()
112+
}).await;
91113

92-
let session: Option<Session> = serde_json::from_value(is_user).ok();
114+
let key: Option<Key> = serde_json::from_value(is_key).ok();
93115

94-
if let Some(session) = session {
95-
return Ok(UserAuth::User(session, state.clone()));
116+
if let Some(key) = key {
117+
return Ok(UserAuth::Key(key, state.clone()));
118+
}
96119
}
97120

98121
Err(HttpError::Unauthorized.into())
@@ -124,23 +147,28 @@ impl<'a> ApiExtractor<'a> for UserAuth {
124147
}
125148

126149
impl UserAuth {
127-
pub fn ok(&self) -> Option<&Session> {
150+
/// @deprecated
151+
pub fn ok_session(&self) -> Option<&Session> {
128152
match self {
129153
UserAuth::User(session, _) => Some(session),
154+
UserAuth::Key(_, _) => None,
130155
UserAuth::None(_) => None,
131156
}
132157
}
133158

134-
pub fn required(&self) -> Result<&Session> {
159+
/// @deprecated
160+
pub fn required_session(&self) -> Result<&Session> {
135161
match self {
136162
UserAuth::User(session, _) => Ok(session),
163+
UserAuth::Key(_, _) => Err(HttpError::Unauthorized.into()),
137164
UserAuth::None(_) => Err(HttpError::Unauthorized.into()),
138165
}
139166
}
140167

141168
pub fn user_id(&self) -> Option<&str> {
142169
match self {
143170
UserAuth::User(session, _) => Some(&session.user_id),
171+
UserAuth::Key(_, __) => None,
144172
UserAuth::None(_) => None,
145173
}
146174
}
@@ -168,7 +196,10 @@ impl UserAuth {
168196
}
169197

170198
Ok(())
171-
}
199+
},
200+
UserAuth::Key(key, _) => {
201+
Err(HttpError::Forbidden)
202+
},
172203
UserAuth::None(_) => Err(HttpError::Unauthorized),
173204
}
174205
}
@@ -190,7 +221,16 @@ impl UserAuth {
190221
async move {
191222
match self {
192223
UserAuth::User(session, state) => match resource
193-
.has_access_to(state, &session.user_id)
224+
.has_access(state, "user", &session.user_id)
225+
.await
226+
.map_err(HttpError::from)
227+
{
228+
Ok(true) => Ok(()),
229+
Ok(false) => Err(HttpError::Forbidden),
230+
Err(e) => Err(e),
231+
},
232+
UserAuth::Key(key, state) => match resource
233+
.has_access(state, &key.key_type, &key.key_resource)
194234
.await
195235
.map_err(HttpError::from)
196236
{
@@ -207,9 +247,11 @@ impl UserAuth {
207247
}
208248

209249
pub trait AccessibleResource: Debug {
210-
fn has_access_to(
250+
fn has_access(
211251
&self,
212252
state: &State,
213-
user_id: &str,
253+
// 'user' | 'site' | 'team'
254+
resource: &str,
255+
resource_id: &str,
214256
) -> impl std::future::Future<Output = Result<bool, HttpError>> + Send;
215257
}

engine/src/models/session/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ impl Session {
140140
resource: &impl AccessibleResource,
141141
) -> Result<(), HttpError> {
142142
resource
143-
.has_access_to(state, &self.user_id)
143+
.has_access(state, "user", &self.user_id)
144144
.await
145145
.map_err(HttpError::from)?;
146146
Ok(())

engine/src/models/site.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,22 +136,46 @@ impl Site {
136136
pub struct SiteId<'a>(pub &'a str);
137137

138138
impl<'a> AccessibleResource for SiteId<'a> {
139-
#[tracing::instrument(name = "has_access_to", skip(state))]
140-
async fn has_access_to(&self, state: &State, user_id: &str) -> Result<bool, HttpError> {
139+
#[tracing::instrument(name = "has_access", skip(state))]
140+
async fn has_access(
141+
&self,
142+
state: &State,
143+
resource: &str,
144+
resource_id: &str,
145+
) -> Result<bool, HttpError> {
141146
let cache_key = format!("site:{}", self.0);
142147

143148
let part_of_site = state.cache.raw.get_with(cache_key, async {
144-
// Verify that the user is a member of the team that owns the site
145-
let part_of_site = query_scalar!(
146-
"SELECT EXISTS (SELECT 1 FROM sites WHERE site_id = $1 AND team_id IN (SELECT team_id FROM user_teams WHERE user_id = $2) OR team_id IN (SELECT team_id FROM teams WHERE owner_id = $2))",
147-
self.0,
148-
user_id
149-
)
150-
.fetch_one(&state.database.pool)
151-
.await;
152-
153-
let part_of_site = part_of_site.ok().flatten().unwrap_or(false);
154-
serde_json::to_value(part_of_site).unwrap()
149+
if resource == "user" {
150+
// Verify that the user is a member of the team that owns the site
151+
let part_of_site = query_scalar!(
152+
"SELECT EXISTS (SELECT 1 FROM sites WHERE site_id = $1 AND team_id IN (SELECT team_id FROM user_teams WHERE user_id = $2) OR team_id IN (SELECT team_id FROM teams WHERE owner_id = $2))",
153+
self.0,
154+
resource_id
155+
)
156+
.fetch_one(&state.database.pool)
157+
.await;
158+
159+
let part_of_site = part_of_site.ok().flatten().unwrap_or(false);
160+
161+
serde_json::to_value(part_of_site).unwrap()
162+
} else if resource == "site" {
163+
if self.0 == resource_id {
164+
serde_json::to_value(true).unwrap()
165+
} else {
166+
serde_json::to_value(false).unwrap()
167+
}
168+
} else if resource == "team" {
169+
let site = Site::get_by_id(&state.database, resource_id).await.map_err(HttpError::from).ok();
170+
171+
let site_has_access = site
172+
.map(|s| s.team_id == resource_id)
173+
.unwrap_or(false);
174+
175+
serde_json::to_value(site_has_access).unwrap()
176+
} else {
177+
serde_json::to_value(false).unwrap()
178+
}
155179
}).await;
156180

157181
let part_of_site: bool = serde_json::from_value(part_of_site).unwrap_or_default();

engine/src/models/team/mod.rs

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ impl Team {
5252
}
5353

5454
#[tracing::instrument(name = "get_by_id", skip(db))]
55-
pub async fn get_by_id(db: &Database, team_id: impl AsRef<str> + Debug) -> Result<Self, sqlx::Error> {
55+
pub async fn get_by_id(
56+
db: &Database,
57+
team_id: impl AsRef<str> + Debug,
58+
) -> Result<Self, sqlx::Error> {
5659
let span = info_span!("Team::get_by_id");
5760
span.set_parent(Context::current());
5861
let _guard = span.enter();
@@ -127,11 +130,18 @@ impl Team {
127130
) -> Result<bool, sqlx::Error> {
128131
let cache_key = format!("team:{}:member:{}", team_id.as_ref(), user_id.as_ref());
129132

130-
let is_member = state.cache.raw.get_with(cache_key.clone(), async {
131-
let member = Team::_is_member(state.clone(), team_id, user_id).await.ok().unwrap_or(false);
133+
let is_member = state
134+
.cache
135+
.raw
136+
.get_with(cache_key.clone(), async {
137+
let member = Team::_is_member(state.clone(), team_id, user_id)
138+
.await
139+
.ok()
140+
.unwrap_or(false);
132141

133-
serde_json::Value::from(member)
134-
}).await;
142+
serde_json::Value::from(member)
143+
})
144+
.await;
135145

136146
let is_member: bool = serde_json::from_value(is_member).unwrap_or(false);
137147

@@ -175,31 +185,51 @@ impl Team {
175185
.await
176186
}
177187

178-
pub async fn add_member(db: &Database, team_id: impl AsRef<str>, user_id: impl AsRef<str>) -> Result<(), sqlx::Error> {
188+
pub async fn add_member(
189+
db: &Database,
190+
team_id: impl AsRef<str>,
191+
user_id: impl AsRef<str>,
192+
) -> Result<(), sqlx::Error> {
179193
let span = info_span!("Team::add_member");
180194
span.set_parent(Context::current());
181195
let _guard = span.enter();
182196

183-
query!("INSERT INTO user_teams (team_id, user_id) VALUES ($1, $2)", team_id.as_ref(), user_id.as_ref())
184-
.execute(&db.pool)
185-
.await?;
197+
query!(
198+
"INSERT INTO user_teams (team_id, user_id) VALUES ($1, $2)",
199+
team_id.as_ref(),
200+
user_id.as_ref()
201+
)
202+
.execute(&db.pool)
203+
.await?;
186204

187205
Ok(())
188206
}
189207

190-
pub async fn update_name(db: &Database, team_id: impl AsRef<str>, name: impl AsRef<str>) -> Result<(), sqlx::Error> {
208+
pub async fn update_name(
209+
db: &Database,
210+
team_id: impl AsRef<str>,
211+
name: impl AsRef<str>,
212+
) -> Result<(), sqlx::Error> {
191213
let span = info_span!("Team::update_name");
192214
span.set_parent(Context::current());
193215
let _guard = span.enter();
194216

195-
query!("UPDATE teams SET name = $2 WHERE team_id = $1", team_id.as_ref(), name.as_ref())
196-
.execute(&db.pool)
197-
.await?;
217+
query!(
218+
"UPDATE teams SET name = $2 WHERE team_id = $1",
219+
team_id.as_ref(),
220+
name.as_ref()
221+
)
222+
.execute(&db.pool)
223+
.await?;
198224

199225
Ok(())
200226
}
201227

202-
pub async fn update_avatar(db: &Database, team_id: impl AsRef<str>, avatar_url: impl AsRef<str>) -> Result<Team, sqlx::Error> {
228+
pub async fn update_avatar(
229+
db: &Database,
230+
team_id: impl AsRef<str>,
231+
avatar_url: impl AsRef<str>,
232+
) -> Result<Team, sqlx::Error> {
203233
let span = info_span!("Team::update_avatar");
204234
span.set_parent(Context::current());
205235
let _guard = span.enter();
@@ -219,11 +249,26 @@ impl Team {
219249
pub struct TeamId<'a>(pub &'a str);
220250

221251
impl<'a> AccessibleResource for TeamId<'a> {
222-
async fn has_access_to(&self, state: &State, user_id: &str) -> Result<bool, HttpError> {
223-
let x = Team::is_member(&state, self.0, user_id)
224-
.await
225-
.map_err(HttpError::from)?;
226-
227-
Ok(x)
252+
async fn has_access(
253+
&self,
254+
state: &State,
255+
resource: &str,
256+
resource_id: &str,
257+
) -> Result<bool, HttpError> {
258+
if resource == "user" {
259+
let x = Team::is_member(&state, self.0, resource_id)
260+
.await
261+
.map_err(HttpError::from)?;
262+
263+
Ok(x)
264+
} else if resource == "team" {
265+
if self.0 == resource_id {
266+
Ok(true)
267+
} else {
268+
Ok(false)
269+
}
270+
} else {
271+
Ok(false)
272+
}
228273
}
229274
}

engine/src/routes/invite/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl InviteApi {
9494
) -> Result<PlainText<String>> {
9595
info!("Accepting invite: {:?}", invite_id.0);
9696

97-
let user = user.required()?;
97+
let user = user.required_session()?;
9898

9999
let invite = UserTeamInvite::get_by_invite_id(&state.database, &invite_id.0)
100100
.await

engine/src/routes/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use poem::{
1313
};
1414
use poem_openapi::{OpenApi, OpenApiService, Tags};
1515
use serde_json::{self, Value};
16-
use team::TeamApi;
1716
use tracing::info;
1817
use user::UserApi;
1918

0 commit comments

Comments
 (0)