use std::{borrow::Cow, fmt};
use bytes::BufMut;
use ruma_common::{
api::{error::IntoHttpError, EndpointError, OutgoingResponse},
serde::{from_raw_json_value, JsonObject, StringEnum},
thirdparty::Medium,
OwnedClientSecret, OwnedSessionId, OwnedUserId,
};
use serde::{
de::{self, DeserializeOwned},
Deserialize, Deserializer, Serialize,
};
use serde_json::{
from_slice as from_json_slice, value::RawValue as RawJsonValue, Value as JsonValue,
};
use crate::{
error::{Error as MatrixError, StandardErrorBody},
PrivOwnedStr,
};
// NOTE(review): judging by its name, this is the endpoint for fetching the UIAA fallback page —
// confirm against the module itself.
pub mod get_uiaa_fallback_page;
// `UserIdentifier` derives neither `Serialize` nor `Deserialize` in this file, yet `Password`
// (de)serializes it; those impls presumably live here — TODO confirm.
mod user_serde;
/// Data submitted for one stage of a User-Interactive Authentication (UIAA) flow.
///
/// Serialization is untagged: each variant is written using its inner type's own `Serialize`
/// impl. Deserialization is implemented by hand and dispatches on the JSON `type` field.
#[derive(Clone, Serialize)]
#[non_exhaustive]
#[serde(untagged)]
pub enum AuthData {
    /// Password-based authentication (`m.login.password`).
    Password(Password),
    /// reCAPTCHA-based authentication (`m.login.recaptcha`).
    ReCaptcha(ReCaptcha),
    /// Email-based authentication (`m.login.email.identity`).
    EmailIdentity(EmailIdentity),
    /// Phone-number (MSISDN) based authentication (`m.login.msisdn`).
    Msisdn(Msisdn),
    /// Dummy authentication (`m.login.dummy`).
    Dummy(Dummy),
    /// Registration-token authentication (`m.login.registration_token`).
    RegistrationToken(RegistrationToken),
    /// Data without a `type` field, carrying only a session.
    ///
    /// Presumably sent after a stage was completed through the fallback page.
    FallbackAcknowledgement(FallbackAcknowledgement),
    /// Terms-of-service acceptance (`m.login.terms`).
    Terms(Terms),
    // Any other `type`, kept as raw JSON fields.
    #[doc(hidden)]
    _Custom(CustomAuthData),
}
impl AuthData {
pub fn new(
auth_type: &str,
session: Option<String>,
data: JsonObject,
) -> serde_json::Result<Self> {
fn deserialize_variant<T: DeserializeOwned>(
session: Option<String>,
mut obj: JsonObject,
) -> serde_json::Result<T> {
if let Some(session) = session {
obj.insert("session".into(), session.into());
}
serde_json::from_value(JsonValue::Object(obj))
}
Ok(match auth_type {
"m.login.password" => Self::Password(deserialize_variant(session, data)?),
"m.login.recaptcha" => Self::ReCaptcha(deserialize_variant(session, data)?),
"m.login.email.identity" => Self::EmailIdentity(deserialize_variant(session, data)?),
"m.login.msisdn" => Self::Msisdn(deserialize_variant(session, data)?),
"m.login.dummy" => Self::Dummy(deserialize_variant(session, data)?),
"m.registration_token" => Self::RegistrationToken(deserialize_variant(session, data)?),
"m.login.terms" => Self::Terms(deserialize_variant(session, data)?),
_ => {
Self::_Custom(CustomAuthData { auth_type: auth_type.into(), session, extra: data })
}
})
}
pub fn fallback_acknowledgement(session: String) -> Self {
Self::FallbackAcknowledgement(FallbackAcknowledgement::new(session))
}
pub fn auth_type(&self) -> Option<AuthType> {
match self {
Self::Password(_) => Some(AuthType::Password),
Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
Self::Msisdn(_) => Some(AuthType::Msisdn),
Self::Dummy(_) => Some(AuthType::Dummy),
Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
Self::FallbackAcknowledgement(_) => None,
Self::Terms(_) => Some(AuthType::Terms),
Self::_Custom(c) => Some(AuthType::_Custom(PrivOwnedStr(c.auth_type.as_str().into()))),
}
}
pub fn session(&self) -> Option<&str> {
match self {
Self::Password(x) => x.session.as_deref(),
Self::ReCaptcha(x) => x.session.as_deref(),
Self::EmailIdentity(x) => x.session.as_deref(),
Self::Msisdn(x) => x.session.as_deref(),
Self::Dummy(x) => x.session.as_deref(),
Self::RegistrationToken(x) => x.session.as_deref(),
Self::FallbackAcknowledgement(x) => Some(&x.session),
Self::Terms(x) => x.session.as_deref(),
Self::_Custom(x) => x.session.as_deref(),
}
}
pub fn data(&self) -> Cow<'_, JsonObject> {
fn serialize<T: Serialize>(obj: T) -> JsonObject {
match serde_json::to_value(obj).expect("auth data serialization to succeed") {
JsonValue::Object(obj) => obj,
_ => panic!("all auth data variants must serialize to objects"),
}
}
match self {
Self::Password(x) => Cow::Owned(serialize(Password {
identifier: x.identifier.clone(),
password: x.password.clone(),
session: None,
})),
Self::ReCaptcha(x) => {
Cow::Owned(serialize(ReCaptcha { response: x.response.clone(), session: None }))
}
Self::EmailIdentity(x) => Cow::Owned(serialize(EmailIdentity {
thirdparty_id_creds: x.thirdparty_id_creds.clone(),
session: None,
})),
Self::Msisdn(x) => Cow::Owned(serialize(Msisdn {
thirdparty_id_creds: x.thirdparty_id_creds.clone(),
session: None,
})),
Self::RegistrationToken(x) => {
Cow::Owned(serialize(RegistrationToken { token: x.token.clone(), session: None }))
}
Self::Dummy(_) | Self::FallbackAcknowledgement(_) | Self::Terms(_) => {
Cow::Owned(JsonObject::default())
}
Self::_Custom(c) => Cow::Borrowed(&c.extra),
}
}
}
impl fmt::Debug for AuthData {
    // Forwards to the wrapped value's own `Debug` impl. The variant name is not printed, and
    // each inner type decides which of its fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload: &dyn fmt::Debug = match self {
            Self::Password(password) => password,
            Self::ReCaptcha(recaptcha) => recaptcha,
            Self::EmailIdentity(email) => email,
            Self::Msisdn(msisdn) => msisdn,
            Self::Dummy(dummy) => dummy,
            Self::RegistrationToken(token) => token,
            Self::FallbackAcknowledgement(ack) => ack,
            Self::Terms(terms) => terms,
            Self::_Custom(custom) => custom,
        };
        fmt::Debug::fmt(payload, f)
    }
}
impl<'de> Deserialize<'de> for AuthData {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let json = Box::<RawJsonValue>::deserialize(deserializer)?;
#[derive(Deserialize)]
struct ExtractType<'a> {
#[serde(borrow, rename = "type")]
auth_type: Option<Cow<'a, str>>,
}
let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
.map_err(de::Error::custom)?
.auth_type;
match auth_type.as_deref() {
Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
Some("m.login.registration_token") => {
from_raw_json_value(&json).map(Self::RegistrationToken)
}
Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
Some(_) => from_raw_json_value(&json).map(Self::_Custom),
}
}
}
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
// Variant order matters: it determines the derived `PartialOrd`/`Ord`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
    /// Password-based authentication (`m.login.password`).
    #[ruma_enum(rename = "m.login.password")]
    Password,
    /// reCAPTCHA-based authentication (`m.login.recaptcha`).
    #[ruma_enum(rename = "m.login.recaptcha")]
    ReCaptcha,
    /// Email-based authentication (`m.login.email.identity`).
    #[ruma_enum(rename = "m.login.email.identity")]
    EmailIdentity,
    /// Phone-number (MSISDN) based authentication (`m.login.msisdn`).
    #[ruma_enum(rename = "m.login.msisdn")]
    Msisdn,
    /// Single sign-on (`m.login.sso`).
    ///
    /// There is no dedicated `AuthData` variant for this type.
    #[ruma_enum(rename = "m.login.sso")]
    Sso,
    /// Dummy authentication (`m.login.dummy`).
    #[ruma_enum(rename = "m.login.dummy")]
    Dummy,
    /// Registration-token authentication (`m.login.registration_token`).
    #[ruma_enum(rename = "m.login.registration_token")]
    RegistrationToken,
    /// Terms-of-service acceptance (`m.login.terms`).
    #[ruma_enum(rename = "m.login.terms")]
    Terms,
    // Any other type string.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
/// Data for the password-based stage.
///
/// Serialized with a `type` field set to `m.login.password`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.password")]
pub struct Password {
    /// One of the user's identifiers.
    pub identifier: UserIdentifier,
    /// The user's password. It is omitted from this type's `Debug` output.
    pub password: String,
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
impl Password {
pub fn new(identifier: UserIdentifier, password: String) -> Self {
Self { identifier, password, session: None }
}
}
impl fmt::Debug for Password {
    // The password itself is deliberately left out. Destructuring `Self` exhaustively makes new
    // fields a compile error here, so they cannot leak by accident.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { session, identifier, password: _ } = self;
        let mut builder = f.debug_struct("Password");
        builder.field("identifier", identifier);
        builder.field("session", session);
        builder.finish_non_exhaustive()
    }
}
/// Data for the reCAPTCHA stage.
///
/// Serialized with a `type` field set to `m.login.recaptcha`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.recaptcha")]
pub struct ReCaptcha {
    /// The captcha response. It is omitted from this type's `Debug` output.
    pub response: String,
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
impl ReCaptcha {
pub fn new(response: String) -> Self {
Self { response, session: None }
}
}
impl fmt::Debug for ReCaptcha {
    // Only the session is shown; the captcha response is deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { session, response: _ } = self;
        let mut builder = f.debug_struct("ReCaptcha");
        builder.field("session", session);
        builder.finish_non_exhaustive()
    }
}
/// Data for the email-based stage.
///
/// Serialized with a `type` field set to `m.login.email.identity`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.email.identity")]
pub struct EmailIdentity {
    /// Third-party identifier credentials, (de)serialized under the `threepid_creds` key.
    #[serde(rename = "threepid_creds")]
    pub thirdparty_id_creds: ThirdpartyIdCredentials,
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
/// Data for the phone-number (MSISDN) based stage.
///
/// Serialized with a `type` field set to `m.login.msisdn`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.msisdn")]
pub struct Msisdn {
    /// Third-party identifier credentials, (de)serialized under the `threepid_creds` key.
    #[serde(rename = "threepid_creds")]
    pub thirdparty_id_creds: ThirdpartyIdCredentials,
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
/// Data for the dummy stage, which carries nothing but an optional session.
///
/// Serialized with a `type` field set to `m.login.dummy`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.dummy")]
pub struct Dummy {
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
impl Dummy {
pub fn new() -> Self {
Self::default()
}
}
/// Data for the registration-token stage.
///
/// Serialized with a `type` field set to `m.login.registration_token`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.registration_token")]
pub struct RegistrationToken {
    /// The registration token. It is omitted from this type's `Debug` output.
    pub token: String,
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
impl RegistrationToken {
pub fn new(token: String) -> Self {
Self { token, session: None }
}
}
impl fmt::Debug for RegistrationToken {
    // Only the session is shown; the token itself is deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { session, token: _ } = self;
        let mut builder = f.debug_struct("RegistrationToken");
        builder.field("session", session);
        builder.finish_non_exhaustive()
    }
}
/// Data sent without a `type` field, carrying only a session.
///
/// Unlike the other stage types, this struct has no serde `tag`, so no `type` field is
/// serialized. A JSON object lacking `type` is deserialized into this variant of `AuthData`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
pub struct FallbackAcknowledgement {
    /// The session identifier of the ongoing UIAA flow (required here).
    pub session: String,
}
impl FallbackAcknowledgement {
pub fn new(session: String) -> Self {
Self { session }
}
}
/// Data for the terms-of-service stage, which carries nothing but an optional session.
///
/// Serialized with a `type` field set to `m.login.terms`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.terms")]
pub struct Terms {
    /// The session identifier of the ongoing UIAA flow, if any.
    pub session: Option<String>,
}
impl Terms {
pub fn new() -> Self {
Self::default()
}
}
// Data for an authentication type without a dedicated `AuthData` variant.
#[doc(hidden)]
#[derive(Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub struct CustomAuthData {
    // The authentication type string, (de)serialized as the `type` field.
    #[serde(rename = "type")]
    auth_type: String,
    // The session identifier of the ongoing UIAA flow, if any.
    session: Option<String>,
    // All remaining fields, flattened into / collected from the top-level object.
    #[serde(flatten)]
    extra: JsonObject,
}
impl fmt::Debug for CustomAuthData {
    // Shows the type and session only; the free-form `extra` data is deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { extra: _, auth_type, session } = self;
        let mut builder = f.debug_struct("CustomAuthData");
        builder.field("auth_type", auth_type).field("session", session);
        builder.finish_non_exhaustive()
    }
}
/// Identification information for the user.
///
/// NOTE(review): `Serialize`/`Deserialize` are not derived here; they are presumably implemented
/// in the `user_serde` module — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
// Exhaustive on purpose: mediums without a dedicated variant go through `_CustomThirdParty`.
#[allow(clippy::exhaustive_enums)]
pub enum UserIdentifier {
    /// Either a fully-qualified user ID or just its localpart.
    UserIdOrLocalpart(String),
    /// An email address.
    Email {
        /// The email address.
        address: String,
    },
    /// A phone number in MSISDN form.
    Msisdn {
        /// The phone number.
        number: String,
    },
    /// A phone number given as a country plus a number.
    PhoneNumber {
        /// The country the number belongs to (presumably a country code — confirm).
        country: String,
        /// The phone number.
        phone: String,
    },
    // A third-party identifier for a medium other than email or MSISDN.
    #[doc(hidden)]
    _CustomThirdParty(CustomThirdPartyId),
}
impl UserIdentifier {
pub fn third_party_id(medium: Medium, address: String) -> Self {
match medium {
Medium::Email => Self::Email { address },
Medium::Msisdn => Self::Msisdn { number: address },
_ => Self::_CustomThirdParty(CustomThirdPartyId { medium, address }),
}
}
pub fn as_third_party_id(&self) -> Option<(&Medium, &str)> {
match self {
Self::Email { address } => Some((&Medium::Email, address)),
Self::Msisdn { number } => Some((&Medium::Msisdn, number)),
Self::_CustomThirdParty(CustomThirdPartyId { medium, address }) => {
Some((medium, address))
}
_ => None,
}
}
}
impl From<OwnedUserId> for UserIdentifier {
fn from(id: OwnedUserId) -> Self {
Self::UserIdOrLocalpart(id.into())
}
}
impl From<&OwnedUserId> for UserIdentifier {
fn from(id: &OwnedUserId) -> Self {
Self::UserIdOrLocalpart(id.as_str().to_owned())
}
}
// A third-party identifier whose medium has no dedicated `UserIdentifier` variant.
// Only `Serialize` is derived; the `Deserialize` counterpart is `IncomingCustomThirdPartyId`.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[non_exhaustive]
pub struct CustomThirdPartyId {
    // The medium of the identifier.
    medium: Medium,
    // The identifier's address within that medium.
    address: String,
}
// Deserialization-only mirror of `CustomThirdPartyId`. NOTE(review): it is not referenced in
// this chunk; presumably used by `user_serde` — confirm.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[non_exhaustive]
pub struct IncomingCustomThirdPartyId {
    // The medium of the identifier.
    medium: Medium,
    // The identifier's address within that medium.
    address: String,
}
/// Credentials for a third-party identifier, used by the email and MSISDN stages.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
pub struct ThirdpartyIdCredentials {
    /// The session identifier of the third-party validation session.
    pub sid: OwnedSessionId,
    /// The client secret. It is omitted from this type's `Debug` output.
    pub client_secret: OwnedClientSecret,
    /// The identity server used, if any; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_server: Option<String>,
    /// An access token for the identity server, if any; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_access_token: Option<String>,
}
impl ThirdpartyIdCredentials {
pub fn new(sid: OwnedSessionId, client_secret: OwnedClientSecret) -> Self {
Self { sid, client_secret, id_server: None, id_access_token: None }
}
}
impl fmt::Debug for ThirdpartyIdCredentials {
    // `client_secret` is deliberately left out of the output.
    // NOTE(review): `id_access_token` is printed; consider whether it should be redacted too.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sid, id_server, id_access_token, client_secret: _ } = self;
        let mut builder = f.debug_struct("ThirdpartyIdCredentials");
        builder.field("sid", sid);
        builder.field("id_server", id_server);
        builder.field("id_access_token", id_access_token);
        builder.finish_non_exhaustive()
    }
}
/// Information about the authentication flows available for a request.
///
/// Sent as the body of a `401 Unauthorized` response by `UiaaResponse::AuthResponse`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
pub struct UiaaInfo {
    /// The authentication flows offered.
    pub flows: Vec<AuthFlow>,
    /// Stages already completed; defaults to empty and is omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completed: Vec<AuthType>,
    /// Additional parameters for the stages, kept as raw JSON.
    pub params: Box<RawJsonValue>,
    /// The session identifier of this UIAA flow, if any; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,
    /// An optional error whose fields are flattened into the top-level object; omitted when
    /// `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub auth_error: Option<StandardErrorBody>,
}
impl UiaaInfo {
pub fn new(flows: Vec<AuthFlow>, params: Box<RawJsonValue>) -> Self {
Self { flows, completed: Vec::new(), params, session: None, auth_error: None }
}
}
/// One authentication flow, described as an ordered list of stage types.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
pub struct AuthFlow {
    /// The stages of this flow; defaults to empty and is omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<AuthType>,
}
impl AuthFlow {
pub fn new(stages: Vec<AuthType>) -> Self {
Self { stages }
}
}
/// Error response of an endpoint that may require User-Interactive Authentication.
#[derive(Clone, Debug)]
#[allow(clippy::exhaustive_enums)]
pub enum UiaaResponse {
    /// Further authentication is required. Sent as `401 Unauthorized` with a `UiaaInfo` body.
    AuthResponse(UiaaInfo),
    /// Any other Matrix error.
    MatrixError(MatrixError),
}
impl fmt::Display for UiaaResponse {
    // A fixed message for auth responses; Matrix errors are formatted via their own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Self::MatrixError(matrix_error) = self {
            write!(f, "{matrix_error}")
        } else {
            f.write_str("User-Interactive Authentication required.")
        }
    }
}
impl From<MatrixError> for UiaaResponse {
fn from(error: MatrixError) -> Self {
Self::MatrixError(error)
}
}
impl EndpointError for UiaaResponse {
fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
if response.status() == http::StatusCode::UNAUTHORIZED {
if let Ok(uiaa_info) = from_json_slice(response.body().as_ref()) {
return Self::AuthResponse(uiaa_info);
}
}
Self::MatrixError(MatrixError::from_http_response(response))
}
}
// Uses the default `std::error::Error` methods; the message comes from the `Display` impl.
impl std::error::Error for UiaaResponse {}
impl OutgoingResponse for UiaaResponse {
fn try_into_http_response<T: Default + BufMut>(
self,
) -> Result<http::Response<T>, IntoHttpError> {
match self {
UiaaResponse::AuthResponse(authentication_info) => http::Response::builder()
.header(http::header::CONTENT_TYPE, "application/json")
.status(&http::StatusCode::UNAUTHORIZED)
.body(ruma_common::serde::json_to_buf(&authentication_info)?)
.map_err(Into::into),
UiaaResponse::MatrixError(error) => error.try_into_http_response(),
}
}
}