Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions driver/src/action/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
bson::Timestamp,
client::options::{SessionOptions, TransactionOptions},
error::Result,
Client,
Expand Down
19 changes: 13 additions & 6 deletions driver/src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::{
time::Duration,
};

use crate::bson::UuidRepresentation;
use derive_where::derive_where;
use macro_magic::export_tokens;
use serde::{de::Unexpected, Deserialize, Deserializer, Serialize, Serializer};
Expand All @@ -35,7 +34,7 @@ use crate::options::Compressor;
#[cfg(test)]
use crate::srv::LookupHosts;
use crate::{
bson::{doc, Bson, Document},
bson::{doc, Bson, Document, Timestamp, UuidRepresentation},
client::auth::{AuthMechanism, Credential},
concern::{Acknowledgment, ReadConcern, WriteConcern},
error::{Error, ErrorKind, Result},
Expand Down Expand Up @@ -3044,19 +3043,27 @@ pub struct SessionOptions {
/// If true, all read operations performed using this client session will share the same
/// snapshot. Defaults to false.
pub snapshot: Option<bool>,

/// The snapshot time to use for a snapshot session. This option can only be set if `snapshot`
/// is set to true.
pub snapshot_time: Option<Timestamp>,
}

impl SessionOptions {
pub(crate) fn validate(&self) -> Result<()> {
if let (Some(causal_consistency), Some(snapshot)) = (self.causal_consistency, self.snapshot)
{
if causal_consistency && snapshot {
return Err(ErrorKind::InvalidArgument {
message: "snapshot and causal consistency are mutually exclusive".to_string(),
}
.into());
return Err(Error::invalid_argument(
"snapshot and causal consistency are mutually exclusive",
));
}
}
if self.snapshot_time.is_some() && self.snapshot != Some(true) {
return Err(Error::invalid_argument(
"cannot set snapshot_time without setting snapshot to true",
));
}
Ok(())
}
}
Expand Down
21 changes: 20 additions & 1 deletion driver/src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use uuid::Uuid;
use crate::{
bson::{doc, spec::BinarySubtype, Binary, Bson, Document, Timestamp},
cmap::conn::PinnedConnectionHandle,
error::{Error, Result},
operation::Retryability,
options::{SessionOptions, TransactionOptions},
sdam::ServerInfo,
Expand Down Expand Up @@ -236,6 +237,7 @@ impl ClientSession {
) -> Self {
let timeout = client.inner.topology.watcher().logical_session_timeout();
let server_session = client.inner.session_pool.check_out(timeout).await;
let snapshot_time = options.as_ref().and_then(|o| o.snapshot_time);
Self {
drop_token: client.register_async_drop(),
client,
Expand All @@ -244,7 +246,7 @@ impl ClientSession {
is_implicit,
options,
transaction: Default::default(),
snapshot_time: None,
snapshot_time,
operation_time: None,
#[cfg(test)]
convenient_transaction_timeout: None,
Expand Down Expand Up @@ -306,6 +308,23 @@ impl ClientSession {
self.operation_time
}

/// The snapshot time for a snapshot session. This will return `None` if `snapshot_time` was not
/// provided and the server has not yet responded with a snapshot time. It is an error to call
/// this method on a non-snapshot session.
pub fn snapshot_time(&self) -> Result<Option<Timestamp>> {
if !self
.options
.as_ref()
.and_then(|o| o.snapshot)
.unwrap_or(false)
{
return Err(Error::invalid_argument(
"cannot access snapshot time on a non-snapshot session",
));
}
Ok(self.snapshot_time)
}

pub(crate) fn causal_consistency(&self) -> bool {
self.options()
.and_then(|opts| opts.causal_consistency)
Expand Down
49 changes: 48 additions & 1 deletion driver/src/test/spec/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use futures::TryStreamExt;
use futures_util::{future::try_join_all, FutureExt};

use crate::{
bson::{doc, Document},
bson::{doc, Document, Timestamp},
error::{ErrorKind, Result},
event::{
command::{CommandEvent, CommandStartedEvent},
Expand All @@ -21,8 +21,10 @@ use crate::{
get_client_options,
log_uncaptured,
server_version_gte,
server_version_lt,
spec::unified_runner::run_unified_tests,
topology_is_load_balanced,
topology_is_replica_set,
topology_is_sharded,
Event,
},
Expand Down Expand Up @@ -281,3 +283,48 @@ async fn no_cluster_time_in_sdam() {
// Assert that the cluster time hasn't changed
assert_eq!(cluster_time.as_ref(), start.command.get("$clusterTime"));
}

// Sessions prose test 21
#[tokio::test]
async fn snapshot_time_and_snapshot_false_disallowed() {
if server_version_lt(5, 0).await
|| !(topology_is_replica_set().await || topology_is_sharded().await)
{
log_uncaptured(
"skipping snapshot_time_and_snapshot_false_disallowed: requires 5.0+ replica set or \
sharded cluster",
);
return;
}

let client = Client::for_test().await;
let error = client
.start_session()
.snapshot(false)
.snapshot_time(Timestamp {
time: 0,
increment: 0,
})
.await
.unwrap_err();
assert!(matches!(*error.kind, ErrorKind::InvalidArgument { .. }));
}

// Sessions prose test 22
#[tokio::test]
async fn cannot_call_snapshot_time_on_non_snapshot_session() {
if server_version_lt(5, 0).await
|| !(topology_is_replica_set().await || topology_is_sharded().await)
{
log_uncaptured(
"skipping cannot_call_snapshot_time_on_non_snapshot_session: requires 5.0+ replica \
set or sharded cluster",
);
return;
}

let client = Client::for_test().await;
let session = client.start_session().snapshot(false).await.unwrap();
let error = session.snapshot_time().unwrap_err();
assert!(matches!(*error.kind, ErrorKind::InvalidArgument { .. }));
}
73 changes: 22 additions & 51 deletions driver/src/test/spec/unified_runner/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,74 +25,44 @@ mod wait;

use std::{fmt::Debug, ops::Deref};

use collection::{
Aggregate,
AssertCollectionExists,
AssertCollectionNotExists,
CreateCollection,
DropCollection,
};
use command::{CreateCommandCursor, RunCommand, RunCursorCommand};
use connection::{AssertNumberConnectionsCheckedOut, Close};
use count::{AssertEventCount, CountDocuments, Distinct, EstimatedDocumentCount};
use delete::{DeleteMany, DeleteOne};
use failpoint::{FailPointCommand, TargetedFailPoint};
use find::{
CreateFindCursor,
Find,
FindOne,
FindOneAndDelete,
FindOneAndReplace,
FindOneAndUpdate,
};
use futures::{future::BoxFuture, FutureExt};
use gridfs::{Delete, DeleteByName, Download, DownloadByName, RenameByName, Upload};
use index::{
AssertIndexExists,
AssertIndexNotExists,
CreateIndex,
DropIndex,
DropIndexes,
ListIndexNames,
ListIndexes,
};
use insert::{InsertMany, InsertOne};
use iteration::{IterateOnce, IterateUntilDocumentOrError};
use list::{ListCollectionNames, ListCollections, ListDatabaseNames, ListDatabases};
use rename::Rename;
use serde::{
de::{DeserializeOwned, Deserializer},
Deserialize,
};
use session::{
AssertDifferentLsidOnLastTwoCommands,
AssertSameLsidOnLastTwoCommands,
AssertSessionDirty,
AssertSessionNotDirty,
AssertSessionPinned,
AssertSessionTransactionState,
AssertSessionUnpinned,
EndSession,
};
use thread::{RunOnThread, WaitForThread};
use tokio::sync::Mutex;
use topology::{AssertTopologyType, RecordTopologyDescription};
use transaction::{AbortTransaction, CommitTransaction, StartTransaction, WithTransaction};
use update::{ReplaceOne, UpdateMany, UpdateOne};
use wait::{Wait, WaitForEvent, WaitForPrimaryChange};

use super::{results_match, Entity, ExpectError, TestCursor, TestFileEntity, TestRunner};

use crate::{
bson::{doc, Bson, Document},
error::{ErrorKind, Result},
options::ChangeStreamOptions,
};

use super::{results_match, Entity, ExpectError, TestCursor, TestFileEntity, TestRunner};

use bulk_write::*;
use collection::*;
use command::*;
use connection::*;
use count::*;
#[cfg(feature = "in-use-encryption")]
use csfle::*;
use delete::*;
use failpoint::*;
use find::*;
use gridfs::*;
use index::*;
use insert::*;
use iteration::*;
use list::*;
use rename::*;
use search_index::*;
use session::*;
use thread::*;
use topology::*;
use transaction::*;
use update::*;
use wait::*;

pub(crate) trait TestOperation: Debug + Send + Sync {
fn execute_test_runner_operation<'a>(
Expand Down Expand Up @@ -439,6 +409,7 @@ impl<'de> Deserialize<'de> for Operation {
"decrypt" => deserialize_op::<Decrypt>(definition.arguments),
"dropIndex" => deserialize_op::<DropIndex>(definition.arguments),
"dropIndexes" => deserialize_op::<DropIndexes>(definition.arguments),
"getSnapshotTime" => deserialize_op::<GetSnapshotTime>(definition.arguments),
s => Ok(Box::new(UnimplementedOperation {
_name: s.to_string(),
}) as Box<dyn TestOperation>),
Expand Down
25 changes: 25 additions & 0 deletions driver/src/test/spec/unified_runner/operation/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
bson::Bson,
client::session::TransactionState,
error::Result,
test::spec::unified_runner::{
Expand Down Expand Up @@ -201,3 +202,27 @@ impl TestOperation for AssertSessionNotDirty {
.boxed()
}
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub(super) struct GetSnapshotTime {}

impl TestOperation for GetSnapshotTime {
fn execute_entity_operation<'a>(
&'a self,
id: &'a str,
test_runner: &'a TestRunner,
) -> BoxFuture<'a, Result<Option<Entity>>> {
async move {
with_mut_session!(test_runner, id, |session| {
async move {
session
.snapshot_time()
.map(|option| option.map(|ts| Bson::Timestamp(ts).into()))
}
})
.await
}
.boxed()
}
}
39 changes: 37 additions & 2 deletions driver/src/test/spec/unified_runner/test_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use semver::Version;
use serde::{Deserialize, Deserializer};
use tokio::sync::oneshot;

use super::{results_match, ExpectedEvent, ObserveEvent, Operation};
use super::{results_match, EntityMap, ExpectedEvent, ObserveEvent, Operation};

#[cfg(feature = "bson-3")]
use crate::bson_compat::RawDocumentBufExt;
Expand All @@ -28,6 +28,8 @@ use crate::{
ReadConcern,
ReadPreference,
SelectionCriteria,
SessionOptions,
TransactionOptions,
WriteConcern,
},
serde_util,
Expand Down Expand Up @@ -375,7 +377,40 @@ pub(crate) struct Collection {
pub(crate) struct Session {
pub(crate) id: String,
pub(crate) client: String,
pub(crate) session_options: Option<crate::client::options::SessionOptions>,
pub(crate) session_options: Option<TestFileSessionOptions>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub(crate) struct TestFileSessionOptions {
default_transaction_options: Option<TransactionOptions>,
causal_consistency: Option<bool>,
snapshot: Option<bool>,
snapshot_time: Option<String>, // the id of the entity to be retrieved from the map
}

impl TestFileSessionOptions {
pub(crate) fn as_session_options(&self, entities: &EntityMap) -> SessionOptions {
let snapshot_time = match self.snapshot_time {
Some(ref id) => {
let entity = entities
.get(id)
.unwrap_or_else(|| panic!("missing entity for id {id}"));
let bson = entity.as_bson();
let timestamp = bson
.as_timestamp()
.unwrap_or_else(|| panic!("expected timestamp for id {id}, got {bson}"));
Some(timestamp)
}
None => None,
};
SessionOptions {
default_transaction_options: self.default_transaction_options.clone(),
causal_consistency: self.causal_consistency,
snapshot: self.snapshot,
snapshot_time,
}
}
}

#[derive(Debug, Deserialize)]
Expand Down
14 changes: 9 additions & 5 deletions driver/src/test/spec/unified_runner/test_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,15 @@ impl TestRunner {
TestFileEntity::Session(session) => {
let id = session.id.clone();
let client = self.get_client(&session.client).await;
let mut client_session = client
.start_session()
.with_options(session.session_options.clone())
.await
.unwrap();
let options = match session.session_options {
Some(ref options) => {
let entities = self.entities.read().await;
Some(options.as_session_options(&entities))
}
None => None,
};
let mut client_session =
client.start_session().with_options(options).await.unwrap();
if let Some(time) = &*self.cluster_time.read().await {
client_session.advance_cluster_time(time);
}
Expand Down
Loading