diff --git a/CMakeLists.txt b/CMakeLists.txt index aa76ba0..287c545 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,7 +61,7 @@ add_library( # Sets the name of the library. ${SOURCES} ) SET(GCC_COVERAGE_COMPILE_FLAGS "-Wall -fprofile-arcs -ftest-coverage -g -O0") -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GCC_COVERAGE_COMPILE_FLAGS}") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GCC_COVERAGE_COMPILE_FLAGS} -pthread") SET(GCC_COVERAGE_LINK_FLAGS "-lgcov --coverage") #SET(GCC_COVERAGE_LINK_FLAGS "-lclang_rt.profile_osx -L/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/10.0.0/lib/darwin") diff --git a/datastore.cpp b/datastore.cpp index 2f5b0b2..8869289 100644 --- a/datastore.cpp +++ b/datastore.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include "datastore.hpp" #include "utils.hpp" #include "vendor/nlohmann/json.hpp" @@ -33,10 +34,11 @@ using namespace error; static string FilePath(const string& file_root, const string& suffix); static Result FileLoad(const string& file_path); -static Error FileStore(bool paused, const string& file_path, const json& json); +static Error FileStore(int transaction_depth, const string& file_path, const json& json); Datastore::Datastore() - : initialized_(false), json_(json::object()), paused_(false) { + : initialized_(false), explicit_lock_(mutex_, std::defer_lock), + transaction_depth_(0), json_(json::object()) { } Error Datastore::Init(const string& file_root, const string& suffix) { @@ -55,8 +57,8 @@ Error Datastore::Init(const string& file_root, const string& suffix) { Error Datastore::Reset(const string& file_path, json new_value) { SYNCHRONIZE(mutex_); - paused_ = false; - if (auto err = FileStore(paused_, file_path, new_value)) { + transaction_depth_ = 0; + if (auto err = FileStore(transaction_depth_, file_path, new_value)) { return PassError(err); } json_ = new_value; @@ -73,25 +75,40 @@ Error Datastore::Reset(json new_value) { return PassError(Reset(file_path_, new_value)); } -bool Datastore::PauseWrites() { +void Datastore::BeginTransaction() { + // We only acquire a non-local lock if we're starting an outermost transaction. SYNCHRONIZE(mutex_); - auto was_paused = paused_; - paused_ = true; - return !was_paused; + // We got a local lock, so we know there's no transaction in progress in any other thread. + if (transaction_depth_ == 0) { + explicit_lock_.lock(); + } + transaction_depth_++; } -Error Datastore::UnpauseWrites(bool commit) { +Error Datastore::EndTransaction(bool commit) { SYNCHRONIZE(mutex_); MUST_BE_INITIALIZED; - if (!paused_) { + if (transaction_depth_ <= 0) { + assert(false); return nullerr; } - paused_ = false; + + transaction_depth_--; + + if (transaction_depth_ > 0) { + // This was an inner transaction and there's nothing more to do. + return nullerr; + } + + // We need to release the explicit lock on exit from ths function, no matter what. + // We will "adopt" the lock into this lock_guard to ensure the unlock happens when it goes out of scope. + std::lock_guard> lock_releaser(explicit_lock_, std::adopt_lock); + if (commit) { - return PassError(FileStore(paused_, file_path_, json_)); + return PassError(FileStore(transaction_depth_, file_path_, json_)); } - // Revert to what's on disk + // We're rolling back -- revert to what's on disk auto res = FileLoad(file_path_); if (!res) { return PassError(res.error()); @@ -110,7 +127,7 @@ Error Datastore::Set(const json::json_pointer& p, json v) { SYNCHRONIZE(mutex_); MUST_BE_INITIALIZED; json_[p] = v; - return PassError(FileStore(paused_, file_path_, json_)); + return PassError(FileStore(transaction_depth_, file_path_, json_)); } static string FilePath(const string& file_root, const string& suffix) { @@ -186,8 +203,8 @@ static Result FileLoad(const string& file_path) { return json; } -static Error FileStore(bool paused, const string& file_path, const json& json) { - if (paused) { +static Error FileStore(int transaction_depth, const string& file_path, const json& json) { + if (transaction_depth > 0) { return nullerr; } diff --git a/datastore.hpp b/datastore.hpp index 6d49d30..2182833 100644 --- a/datastore.hpp +++ b/datastore.hpp @@ -35,7 +35,7 @@ class Datastore { using json = nlohmann::json; public: - enum DatastoreGetError { + enum class DatastoreGetError { kNotFound = 1, kTypeMismatch, kDatastoreUninitialized @@ -62,12 +62,16 @@ class Datastore { /// Init() must have already been called, successfully. error::Error Reset(json new_value); - /// Stops writing of updates to disk until UnpauseWrites is called. - /// Returns false if writing was already paused (so this call did nothing). - bool PauseWrites(); - /// Unpauses writing. If commit is true, it writes the changes immediately; if false - /// it discards the changes. - error::Error UnpauseWrites(bool commit); + /// Locks the read/write mutex and stops writing of updates to disk until + /// EndTransaction is called. Transactions are re-enterable, but not nested. + /// NOTE: Failing to call EndTransaction will result in undefined behaviour. + void BeginTransaction(); + /// Ends an ongoing transaction writing. If commit is true, it writes the changes + /// immediately; if false it discards the changes. + /// Committing or rolling back inner transactions does nothing. Any errors during + /// inner transactions that require the outermost transaction to be rolled back must + /// be handled by the caller. + error::Error EndTransaction(bool commit); /// Returns the value, or an error indicating the failure reason. template @@ -79,11 +83,11 @@ class Datastore { SYNCHRONIZE_BLOCK(mutex_) { // Not using MUST_BE_INITIALIZED so we don't need it in the header. if (!initialized_) { - return nonstd::make_unexpected(kDatastoreUninitialized); + return nonstd::make_unexpected(DatastoreGetError::kDatastoreUninitialized); } if (p.empty() || !json_.contains(p)) { - return nonstd::make_unexpected(kNotFound); + return nonstd::make_unexpected(DatastoreGetError::kNotFound); } val = json_.at(p).get(); @@ -91,11 +95,11 @@ class Datastore { return val; } catch (json::type_error&) { - return nonstd::make_unexpected(kTypeMismatch); + return nonstd::make_unexpected(DatastoreGetError::kTypeMismatch); } catch (json::out_of_range&) { // This should be avoided by the explicit check above. But we'll be safe. - return nonstd::make_unexpected(kNotFound); + return nonstd::make_unexpected(DatastoreGetError::kNotFound); } } @@ -112,11 +116,14 @@ class Datastore { error::Error Reset(const std::string& file_path, json new_value); private: - mutable std::recursive_mutex mutex_; bool initialized_; + + mutable std::recursive_mutex mutex_; + std::unique_lock explicit_lock_; + int transaction_depth_; + std::string file_path_; json json_; - bool paused_; }; } // namespace psicash diff --git a/datastore_test.cpp b/datastore_test.cpp index 599f812..8aa3c3f 100644 --- a/datastore_test.cpp +++ b/datastore_test.cpp @@ -19,6 +19,7 @@ #include #include +#include #include "gtest/gtest.h" #include "test_helpers.hpp" @@ -142,7 +143,7 @@ TEST_F(TestDatastore, Reset) // First Get without calling Init; should get "not initialized" error got = ds->Get(k); ASSERT_FALSE(got); - ASSERT_EQ(got.error(), psicash::Datastore::kDatastoreUninitialized); + ASSERT_EQ(got.error(), psicash::Datastore::DatastoreGetError::kDatastoreUninitialized); // Then initialize and try again err = ds->Init(temp_dir.c_str(), ds_suffix); @@ -151,7 +152,7 @@ TEST_F(TestDatastore, Reset) // Key should not be found, since we haven't set it got = ds->Get(k); ASSERT_FALSE(got); - ASSERT_EQ(got.error(), psicash::Datastore::kNotFound); + ASSERT_EQ(got.error(), psicash::Datastore::DatastoreGetError::kNotFound); // Set it err = ds->Set(k, want); @@ -178,7 +179,7 @@ TEST_F(TestDatastore, Reset) got = ds->Get(k); ASSERT_FALSE(got); - ASSERT_EQ(got.error(), psicash::Datastore::kNotFound); + ASSERT_EQ(got.error(), psicash::Datastore::DatastoreGetError::kNotFound); delete ds; @@ -202,7 +203,7 @@ TEST_F(TestDatastore, Reset) } -TEST_F(TestDatastore, WritePause) +TEST_F(TestDatastore, Transactions) { auto temp_dir = GetTempDir(); @@ -211,38 +212,38 @@ TEST_F(TestDatastore, WritePause) ASSERT_FALSE(err); // This should persist - auto pause_want1 = "/pause_want1"_json_pointer; - err = ds->Set(pause_want1, pause_want1.to_string()); + auto trans_want1 = "/trans_want1"_json_pointer; + err = ds->Set(trans_want1, trans_want1.to_string()); ASSERT_FALSE(err); // This should persist, as we're committing - ds->PauseWrites(); - auto pause_want2 = "/pause_want2"_json_pointer; - err = ds->Set(pause_want2, pause_want2.to_string()); + ds->BeginTransaction(); + auto trans_want2 = "/trans_want2"_json_pointer; + err = ds->Set(trans_want2, trans_want2.to_string()); ASSERT_FALSE(err); - err = ds->UnpauseWrites(/*commit=*/true); + err = ds->EndTransaction(/*commit=*/true); ASSERT_FALSE(err); // This should NOT persist, as we're rolling back - ds->PauseWrites(); - auto pause_want3 = "/pause_want3"_json_pointer; - err = ds->Set(pause_want3, pause_want3.to_string()); + ds->BeginTransaction(); + auto trans_want3 = "/trans_want3"_json_pointer; + err = ds->Set(trans_want3, trans_want3.to_string()); ASSERT_FALSE(err); - err = ds->UnpauseWrites(/*commit=*/false); + err = ds->EndTransaction(/*commit=*/false); ASSERT_FALSE(err); // Another committed value, to make sure the order of things doesn't matter - ds->PauseWrites(); - auto pause_want4 = "/pause_want4"_json_pointer; - err = ds->Set(pause_want4, pause_want4.to_string()); + ds->BeginTransaction(); + auto trans_want4 = "/trans_want4"_json_pointer; + err = ds->Set(trans_want4, trans_want4.to_string()); ASSERT_FALSE(err); - err = ds->UnpauseWrites(/*commit=*/true); + err = ds->EndTransaction(/*commit=*/true); ASSERT_FALSE(err); // This should also NOT persist, since we're hitting the dtor - ds->PauseWrites(); - auto pause_want5 = "/pause_want5"_json_pointer; - err = ds->Set(pause_want5, pause_want5.to_string()); + ds->BeginTransaction(); + auto trans_want5 = "/trans_want5"_json_pointer; + err = ds->Set(trans_want5, trans_want5.to_string()); ASSERT_FALSE(err); // Close @@ -253,27 +254,160 @@ TEST_F(TestDatastore, WritePause) err = ds->Init(temp_dir.c_str(), ds_suffix); ASSERT_FALSE(err); - auto got = ds->Get(pause_want1); + auto got = ds->Get(trans_want1); ASSERT_TRUE(got); - ASSERT_EQ(*got, pause_want1.to_string()); + ASSERT_EQ(*got, trans_want1.to_string()); - got = ds->Get(pause_want2); + got = ds->Get(trans_want2); ASSERT_TRUE(got); - ASSERT_EQ(*got, pause_want2.to_string()); + ASSERT_EQ(*got, trans_want2.to_string()); - got = ds->Get(pause_want3); + got = ds->Get(trans_want3); ASSERT_FALSE(got); - got = ds->Get(pause_want4); + got = ds->Get(trans_want4); ASSERT_TRUE(got); - ASSERT_EQ(*got, pause_want4.to_string()); + ASSERT_EQ(*got, trans_want4.to_string()); - got = ds->Get(pause_want5); + got = ds->Get(trans_want5); ASSERT_FALSE(got); delete ds; } +TEST_F(TestDatastore, NestedTransactions) +{ + auto temp_dir = GetTempDir(); + + auto ds = new Datastore(); + auto err = ds->Init(temp_dir.c_str(), ds_suffix); + ASSERT_FALSE(err); + + // Nest then commit + + ds->BeginTransaction(); + auto trans_want1 = "/trans_want1"_json_pointer; + err = ds->Set(trans_want1, trans_want1.to_string()); + ASSERT_FALSE(err); + + ds->BeginTransaction(); + auto trans_want2 = "/trans_want2"_json_pointer; + err = ds->Set(trans_want2, trans_want2.to_string()); + ASSERT_FALSE(err); + + err = ds->EndTransaction(/*commit=*/true); + ASSERT_FALSE(err); + + auto trans_want3 = "/trans_want3"_json_pointer; + err = ds->Set(trans_want3, trans_want3.to_string()); + ASSERT_FALSE(err); + + err = ds->EndTransaction(/*commit=*/true); + ASSERT_FALSE(err); + + auto got = ds->Get(trans_want1); + ASSERT_TRUE(got); + ASSERT_EQ(*got, trans_want1.to_string()); + got = ds->Get(trans_want2); + ASSERT_TRUE(got); + ASSERT_EQ(*got, trans_want2.to_string()); + got = ds->Get(trans_want3); + ASSERT_TRUE(got); + ASSERT_EQ(*got, trans_want3.to_string()); + + // Nest then roll back + + ds->BeginTransaction(); + auto trans_want4 = "/trans_want4"_json_pointer; + err = ds->Set(trans_want4, trans_want4.to_string()); + ASSERT_FALSE(err); + + // Also set one of the previous set values + err = ds->Set(trans_want3, "nope"); + ASSERT_FALSE(err); + + ds->BeginTransaction(); + auto trans_want5 = "/trans_want5"_json_pointer; + err = ds->Set(trans_want5, trans_want5.to_string()); + ASSERT_FALSE(err); + + // We're going to commit the inner transaction... + err = ds->EndTransaction(/*commit=*/true); + ASSERT_FALSE(err); + + auto trans_want6 = "/trans_want6"_json_pointer; + err = ds->Set(trans_want6, trans_want6.to_string()); + ASSERT_FALSE(err); + + // ...and roll back the outer transaction + err = ds->EndTransaction(/*commit=*/false); + ASSERT_FALSE(err); + + got = ds->Get(trans_want3); + ASSERT_TRUE(got); + ASSERT_EQ(*got, trans_want3.to_string()); + got = ds->Get(trans_want4); + ASSERT_FALSE(got) << *got; + got = ds->Get(trans_want5); + ASSERT_FALSE(got); + got = ds->Get(trans_want6); + ASSERT_FALSE(got); +} + +TEST_F(TestDatastore, TransactionRaceConditionBug) +{ + // Before we added "transactions" to the datastore, it only provided the ability to + // "pause" writing. But this still left the possibility that a set of updates that + // are only coherent together (like setting "is account" and the "account username") + // could be read when only partially updated (each individual Set is threadsafe, but + // multiple of them could be read separately). + + Datastore ds; + auto err = ds.Init(GetTempDir().c_str(), ds_suffix); + ASSERT_FALSE(err); + + const auto k = "/k"_json_pointer; + const string k_want1 = "kv"; + err = ds.Set(k, k_want1); + ASSERT_FALSE(err); + + const auto j = "/j"_json_pointer; + const string j_want1 = "jv"; + err = ds.Set(j, j_want1); + ASSERT_FALSE(err); + + const string k_want2 = "kv2", j_want2 = "jv2"; + + std::thread t([&](){ + ds.BeginTransaction(); + + err = ds.Set(k, k_want2); + ASSERT_FALSE(err); + + // Give the main-thread code a chance to read k + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + err = ds.Set(j, j_want2); + ASSERT_FALSE(err); + + ds.EndTransaction(true); + }); + + // Give the thread a chance to start and set k + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // With proper transaction isolation, this Get should wait on the mutex until the thread is done + auto k_got = ds.Get(k); + ASSERT_TRUE(k_got); + auto j_got = ds.Get(j); + ASSERT_TRUE(j_got) << (int)j_got.error(); + // If we have a race condition, k will have been updated, but j won't be + ASSERT_EQ(*k_got, k_want2); + ASSERT_EQ(*j_got, j_want2) << "transaction isolation fail!"; + + t.join(); +} + TEST_F(TestDatastore, SetSimple) { Datastore ds; @@ -416,15 +550,15 @@ TEST_F(TestDatastore, TypeMismatch) // It's an error to set one type and then try to get another auto got_fail_1 = ds.Get(wantStringKey); ASSERT_FALSE(got_fail_1); - ASSERT_EQ(got_fail_1.error(), psicash::Datastore::kTypeMismatch); + ASSERT_EQ(got_fail_1.error(), psicash::Datastore::DatastoreGetError::kTypeMismatch); auto got_fail_2 = ds.Get(wantIntKey); ASSERT_FALSE(got_fail_2); - ASSERT_EQ(got_fail_2.error(), psicash::Datastore::kTypeMismatch); + ASSERT_EQ(got_fail_2.error(), psicash::Datastore::DatastoreGetError::kTypeMismatch); auto got_fail_3 = ds.Get(wantStringKey); ASSERT_FALSE(got_fail_3); - ASSERT_EQ(got_fail_3.error(), psicash::Datastore::kTypeMismatch); + ASSERT_EQ(got_fail_3.error(), psicash::Datastore::DatastoreGetError::kTypeMismatch); auto got_fail_4 = ds.Get(wantBoolKey); //ASSERT_FALSE(got_fail_4); // NOTE: This doesn't actually fail. There must be a successful implicit conversion. @@ -493,7 +627,7 @@ TEST_F(TestDatastore, GetNotFound) // Bad key auto nope = ds.Get("/nope"_json_pointer); ASSERT_FALSE(nope); - ASSERT_EQ(nope.error(), psicash::Datastore::kNotFound); + ASSERT_EQ(nope.error(), psicash::Datastore::DatastoreGetError::kNotFound); } TEST_F(TestDatastore, GetFullDS) diff --git a/psicash.cpp b/psicash.cpp index 8ac6e02..fa4050a 100644 --- a/psicash.cpp +++ b/psicash.cpp @@ -146,12 +146,12 @@ Error PsiCash::MigrateTrackerTokens(const map& tokens) { // leave expiry null } - UserData::WritePauser pauser(*user_data_); + UserData::Transaction transaction(*user_data_); // Ignoring return values while writing is paused. // Blow away any user state, as the newly migrated tokens are overwriting it. (void)ResetUser(); (void)user_data_->SetAuthTokens(auth_tokens, /*is_account=*/false, /*account_username=*/""); - if (auto err = pauser.Commit()) { + if (auto err = transaction.Commit()) { return WrapError(err, "user data write failed"); } return nullerr; @@ -710,11 +710,11 @@ Result PsiCash::NewTracker() { } // Set our new data in a single write. - UserData::WritePauser pauser(*user_data_); + UserData::Transaction transaction(*user_data_); (void)user_data_->SetIsLoggedOutAccount(false); (void)user_data_->SetAuthTokens(auth_tokens, /*is_account=*/false, /*account_username=*/""); (void)user_data_->SetBalance(0); - if (auto err = pauser.Commit()) { + if (auto err = transaction.Commit()) { return WrapError(err, "user data write failed"); } @@ -846,7 +846,7 @@ Result PsiCash::RefreshState( try { // We're going to be setting a bunch of UserData values, so let's wait until we're done // to write them all to disk. - UserData::WritePauser pauser(*user_data_); + UserData::Transaction transaction(*user_data_); auto j = json::parse(result->body); @@ -926,7 +926,7 @@ Result PsiCash::RefreshState( (void)user_data_->DeleteUserData(true); } - if (auto err = pauser.Commit()) { + if (auto err = transaction.Commit()) { return WrapError(err, "UserData write failed"); } } @@ -1010,7 +1010,7 @@ Result PsiCash::NewExpiringPurchase( // Set our new data in a single write. // Note that any early return will cause updates to roll back. - UserData::WritePauser pauser(*user_data_); + UserData::Transaction transaction(*user_data_); // Balance is present for all non-error responses if (j.at("Balance").is_number_integer()) { @@ -1039,7 +1039,7 @@ Result PsiCash::NewExpiringPurchase( } - if (auto err = pauser.Commit()) { + if (auto err = transaction.Commit()) { return WrapError(err, "UserData write failed"); } } @@ -1207,10 +1207,10 @@ error::Result PsiCash::AccountLogin( } // Set our new data in a single write. - UserData::WritePauser pauser(*user_data_); + UserData::Transaction transaction(*user_data_); (void)user_data_->SetIsLoggedOutAccount(false); (void)user_data_->SetAuthTokens(auth_tokens, /*is_account=*/true, /*utf8_username=*/utf8_username); - if (auto err = pauser.Commit()) { + if (auto err = transaction.Commit()) { return WrapError(err, "user data write failed"); } diff --git a/psicash_test.cpp b/psicash_test.cpp index b60f35c..0671882 100644 --- a/psicash_test.cpp +++ b/psicash_test.cpp @@ -1937,7 +1937,7 @@ TEST_F(TestPsiCash, RefreshStateOffline) { request_attempted = true; return HTTPResult(); }; - const MakeHTTPRequestFn errorHTTPRequester = [&request_attempted](const HTTPParams&) -> HTTPResult { + const MakeHTTPRequestFn errorHTTPRequester = [](const HTTPParams&) -> HTTPResult { auto res = HTTPResult(); res.code = HTTPResult::RECOVERABLE_ERROR; res.error = "test"; @@ -2133,10 +2133,10 @@ TEST_F(TestPsiCash, NewExpiringPurchase) { ASSERT_EQ(pc.Balance(), initial_balance); } -TEST_F(TestPsiCash, NewExpiringPurchasePauserCommitBug) { +TEST_F(TestPsiCash, NewExpiringPurchaseTransactionCommitBug) { // Bug test: When a kHTTPStatusTooManyRequests response (or any non-success, but // especially that one) was received, the updated balance received in the response - // would be written to the datastore, but the WritePauser would not be committed, so + // would be written to the datastore, but the Transaction would not be committed, so // the change would be lost and the UI wouldn't update until a RefreshState request // was made. diff --git a/userdata.cpp b/userdata.cpp index aab45b2..d55aa2f 100644 --- a/userdata.cpp +++ b/userdata.cpp @@ -31,7 +31,6 @@ namespace psicash { constexpr int kCurrentDatastoreVersion = 2; // Datastore keys -static constexpr const char* VERSION = "v"; static auto kVersionPtr = "/v"_json_pointer; // // Instance-specific data keys @@ -148,11 +147,11 @@ error::Error UserData::Clear() { } error::Error UserData::DeleteUserData(bool isLoggedOutAccount) { - WritePauser pauser(*this); + Transaction transaction(*this); // Not checking return values, since writing is paused. (void)datastore_.Set(kUserPtr, json::object()); (void)SetIsLoggedOutAccount(isLoggedOutAccount); - return PassError(pauser.Commit()); + return PassError(transaction.Commit()); } std::string UserData::GetInstanceID() const { @@ -257,14 +256,14 @@ AuthTokens UserData::GetAuthTokens() const { } error::Error UserData::SetAuthTokens(const AuthTokens& v, bool is_account, const std::string& utf8_username) { - WritePauser pauser(*this); + Transaction transaction(*this); // Not checking errors while paused, as there's no error that can occur. json json_tokens; to_json(json_tokens, v); (void)datastore_.Set(kAuthTokensPtr, json_tokens); (void)datastore_.Set(kIsAccountPtr, is_account); (void)datastore_.Set(kAccountUsernamePtr, utf8_username); - return PassError(pauser.Commit()); // write + return PassError(transaction.Commit()); // write } error::Error UserData::CullAuthTokens(const std::map& valid_tokens) { @@ -418,12 +417,12 @@ error::Error UserData::AddPurchase(const Purchase& v) { } } - // Pause to set Purchases and LastTransactionID in one write - WritePauser pauser(*this); + // Use transaction to set Purchases and LastTransactionID in one write + Transaction transaction(*this); // These don't write, so have no meaningful return (void)SetPurchases(purchases); (void)SetLastTransactionID(v.id); - return PassError(pauser.Commit()); // write + return PassError(transaction.Commit()); // write } void UserData::UpdatePurchaseLocalTimeExpiry(Purchase& purchase) const { diff --git a/userdata.hpp b/userdata.hpp index c3dc8b4..77e765b 100644 --- a/userdata.hpp +++ b/userdata.hpp @@ -67,19 +67,24 @@ class UserData { /// Init() must have already been called, successfully. error::Error Clear(); - /// Used to pause and resume datastore file writing. - /// WritePausers can be nested -- inner instances will do nothing. - class WritePauser { + /// Used to wrap datastore "transactions" (paused writing, mutexed access). + /// Transaction can be nested -- inner instances will do nothing. + class Transaction { public: - WritePauser(UserData& user_data) : actually_paused_(false), user_data_( - user_data) { actually_paused_ = user_data_.datastore_.PauseWrites(); }; - ~WritePauser() { if (actually_paused_) { (void)Rollback(); } } - error::Error Commit() { return Unpause(true); } - error::Error Rollback() { return Unpause(false); } + Transaction(UserData& user_data) : user_data_(user_data), in_transaction_(false) + { user_data_.datastore_.BeginTransaction(); in_transaction_ = true; }; + ~Transaction() { if (in_transaction_) { (void)Rollback(); } } + error::Error Commit() { return End(true); } + error::Error Rollback() { return End(false); } private: - error::Error Unpause(bool commit) { auto p = actually_paused_; actually_paused_ = false; if (p) { return user_data_.datastore_.UnpauseWrites(commit); } return error::nullerr; } - bool actually_paused_; + error::Error End(bool commit) { + if (in_transaction_) { + in_transaction_ = false; return user_data_.datastore_.EndTransaction(commit); + } + return error::nullerr; + } UserData& user_data_; + bool in_transaction_; }; public: diff --git a/userdata_test.cpp b/userdata_test.cpp index f71142a..1661afd 100644 --- a/userdata_test.cpp +++ b/userdata_test.cpp @@ -609,3 +609,108 @@ TEST_F(TestUserData, Locale) v = ud.GetLocale(); ASSERT_EQ(v, ""); } + +TEST_F(TestUserData, Transaction) +{ + UserData ud; + auto err = ud.Init(GetTempDir().c_str(), dev); + ASSERT_FALSE(err); + + // We're only using specific accessors for easy testing + err = ud.SetLocale("1"); + ASSERT_FALSE(err); + auto v = ud.GetLocale(); + ASSERT_EQ(v, "1"); + + { + UserData::Transaction udt(ud); + err = ud.SetLocale("2"); + ASSERT_FALSE(err); + auto v = ud.GetLocale(); + ASSERT_EQ(v, "2"); + // dtor rollback + } + v = ud.GetLocale(); + ASSERT_EQ(v, "1"); + + { + UserData::Transaction udt(ud); + err = ud.SetLocale("3"); + ASSERT_FALSE(err); + auto err = udt.Commit(); + ASSERT_FALSE(err); + } + v = ud.GetLocale(); + ASSERT_EQ(v, "3"); + + { + UserData::Transaction udt(ud); + err = ud.SetLocale("4"); + ASSERT_FALSE(err); + auto err = udt.Rollback(); + ASSERT_FALSE(err); + } + v = ud.GetLocale(); + ASSERT_EQ(v, "3"); + + { + UserData::Transaction udt(ud); + err = ud.SetLocale("5"); + ASSERT_FALSE(err); + auto err = udt.Commit(); + ASSERT_FALSE(err); + err = udt.Rollback(); // does nothing + ASSERT_FALSE(err); + err = udt.Commit(); // does nothing + ASSERT_FALSE(err); + } + v = ud.GetLocale(); + ASSERT_EQ(v, "5"); + + { + UserData::Transaction udt(ud); + auto err = udt.Rollback(); + ASSERT_FALSE(err); + // Modify _after_ rollback + err = ud.SetLocale("6"); + ASSERT_FALSE(err); + // Extra rollback should do nothing + err = udt.Rollback(); + ASSERT_FALSE(err); + } + v = ud.GetLocale(); + ASSERT_EQ(v, "6"); + + { + UserData::Transaction udt(ud); + + { + UserData::Transaction udt(ud); + + { + UserData::Transaction udt(ud); + err = ud.SetLocale("7"); + ASSERT_FALSE(err); + // inner commit does nothing + auto err = udt.Commit(); + ASSERT_FALSE(err); + } + + auto v = ud.GetIsAccount(); + ASSERT_EQ(v, false); + err = ud.SetIsAccount(true); + ASSERT_FALSE(err); + // inner rollback does nothing + auto err = udt.Rollback(); + ASSERT_FALSE(err); + } + + // We have committed one inner transaction and rolled back the another, but they have no effect on the outer. + // Now we're commit the outer transaction. + udt.Commit(); + } + v = ud.GetLocale(); + ASSERT_EQ(v, "7"); + auto b = ud.GetIsAccount(); + ASSERT_EQ(b, true); +} diff --git a/utils.hpp b/utils.hpp index 59653aa..4567573 100644 --- a/utils.hpp +++ b/utils.hpp @@ -50,7 +50,7 @@ std::string Stringer(const T& value, const Args& ... args) { /// } #define SYNCHRONIZE_BLOCK(m) for(std::unique_lock lk(m); lk; lk.unlock()) /// Synchronize the current scope using the given mutex. -#define SYNCHRONIZE(m) std::unique_lock synchronize_lock(m) +#define SYNCHRONIZE(m) std::lock_guard synchronize_lock(m) /// Tests if the given filepath+name exists. bool FileExists(const std::string& filename);