From c2cbb4e41843aff7b87ed9a35fb8a84fdb7d9795 Mon Sep 17 00:00:00 2001 From: sparqet <37338401+sparqet@users.noreply.github.com> Date: Mon, 25 Sep 2023 23:41:10 +0200 Subject: [PATCH 1/2] Feat/position pricing utils (#412) * implemented first function * added _get_price_impact_usd function * corrected build fails * fixed coding style * added third function * implemented all functions * fix use i128 * fix tests * remove previous test folder * add tests * implemented all tests skeleton * corrected coding style * fix internal function name * uncomment referral storage tests * fix requests * fixed merge bugs --- src/lib.cairo | 3 +- src/market/market_utils.cairo | 97 +++- src/pricing/error.cairo | 12 + src/pricing/position_pricing_utils.cairo | 498 +++++++++++++----- .../pricing/test_position_pricing_utils.cairo | 216 ++++++++ 5 files changed, 671 insertions(+), 155 deletions(-) create mode 100644 src/tests/pricing/test_position_pricing_utils.cairo diff --git a/src/lib.cairo b/src/lib.cairo index e9696c19..927d60a4 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -212,10 +212,10 @@ mod position { // `pricing` contains pricing utils mod pricing { + mod error; mod position_pricing_utils; mod pricing_utils; mod swap_pricing_utils; - mod error; } // `referral` contains referral logic. @@ -314,6 +314,7 @@ mod tests { mod test_price; } mod pricing { + mod test_position_pricing_utils; mod test_swap_pricing_utils; } mod reader { diff --git a/src/market/market_utils.cairo b/src/market/market_utils.cairo index 8e160d42..c13c90be 100644 --- a/src/market/market_utils.cairo +++ b/src/market/market_utils.cairo @@ -21,6 +21,7 @@ use satoru::oracle::oracle::{IOracleDispatcher, IOracleDispatcherTrait}; use satoru::price::price::{Price, PriceTrait}; use satoru::utils::calc; use satoru::utils::span32::Span32; +use satoru::position::position::Position; use satoru::utils::i128::{I128Store, I128Serde, I128Div, I128Mul, I128Default}; /// Struct to store the prices of tokens of a market. @@ -973,6 +974,17 @@ fn market_token_amount_to_usd( 0 } +/// Get the virtual inventory for positions +/// # Arguments +/// * `dataStore` - DataStore +/// * `token` - the token to check +/// TODO internal function +fn get_virtual_inventory_for_positions( + dataStore: IDataStoreDispatcher, token: ContractAddress +) -> (bool, i128) { /// TODO + (true, 0) +} + /// Get the borrowing factor per second. /// # Arguments /// * `data_store` - The data store to use. @@ -988,6 +1000,75 @@ fn get_borrowing_factor_per_second( 0 } +fn get_adjusted_position_impact_factors( + data_store: IDataStoreDispatcher, market: ContractAddress +) -> (u128, u128) { // TODO + (0, 0) +} + +/// Get the borrowing fees for a position, assumes that cumulativeBorrowingFactor +/// has already been updated to the latest value +/// # Arguments +/// * `dataStore` - DataStore +/// * `position` - Position +/// * `dataStore` - DataStore +/// # Returns +/// The borrowing fees for a position +fn get_borrowing_fees(dataStore: IDataStoreDispatcher, position: Position) -> u128 { + 0 +} + +/// Get the funding fee amount per size for a market +/// # Arguments +/// * `dataStore` - DataStore +/// * `market` - the market to check +/// * `collateral_token` - the collateralToken to check +/// * `is_long` - whether to check the long or short size +/// # Returns +/// The funding fee amount per size for a market based on collateralToken +fn get_funding_fee_amount_per_size( + dataStore: IDataStoreDispatcher, + market: ContractAddress, + collateral_token: ContractAddress, + is_long: bool +) -> u128 { + 0 +} + +/// Get the claimable funding amount per size for a market +/// # Arguments +/// * `dataStore` - DataStore +/// * `market` - the market to check +/// * `collateral_token` - the collateralToken to check +/// * `is_long` - whether to check the long or short size +/// # Returns +/// The claimable funding amount per size for a market based on collateralToken +fn get_claimable_funding_amount_per_size( + dataStore: IDataStoreDispatcher, + market: ContractAddress, + collateral_token: ContractAddress, + is_long: bool +) -> u128 { + 0 +} + +/// Get the funding amount to be deducted or distributed +/// # Arguments +/// * `latestFundingAmountPerSize` - the latest funding amount per size +/// * `dataSpositionFundingAmountPerSizetore` - the funding amount per size for the position +/// * `positionSizeInUsd` - the position size in USD +/// * `roundUpMagnitude` - whether the round up the result +/// # Returns +/// fundingAmount +fn get_funding_amount( + latest_funding_amount_per_size: u128, + position_funding_amount_per_size: u128, + position_size_in_usd: u128, + round_up_magnitude: bool +) -> u128 { + 0 +} + /// Get the borrowing factor per second. /// # Arguments /// * `data_store` - The `DataStore` contract dispatcher. @@ -1098,17 +1179,9 @@ fn get_adjusted_swap_impact_factors( (positive_impact_factor, negative_impact_factor) } - -/// Get the virtual inventory for positions -/// # Arguments -/// * `data_store` - The data store to use. -/// * `token` - The token to check. -/// # Returns -/// has virtual inventory, virtual inventory -fn get_virtual_inventory_for_positions( - data_store: IDataStoreDispatcher, token: ContractAddress, -) -> (bool, i128) { +fn get_adjusted_position_impact_factor( + data_store: IDataStoreDispatcher, market: ContractAddress, isPositive: bool +) -> u128 { // TODO - (false, 0) + 0 } - diff --git a/src/pricing/error.cairo b/src/pricing/error.cairo index 26d03c87..fbeb08ae 100644 --- a/src/pricing/error.cairo +++ b/src/pricing/error.cairo @@ -1,3 +1,15 @@ mod PricingError { + fn USD_DELTA_EXCEEDS_LONG_OPEN_INTEREST(usd_delta: i128, long_open_interest: u128) { + let mut data = array!['usd delta exceeds long interest']; + data.append(usd_delta.into()); + data.append(long_open_interest.into()); + panic(data) + } + fn USD_DELTA_EXCEEDS_SHORT_OPEN_INTEREST(usd_delta: i128, short_open_interest: u128) { + let mut data = array!['usd delta exceed short interest']; + data.append(usd_delta.into()); + data.append(short_open_interest.into()); + panic(data) + } const USD_DELTA_EXCEEDS_POOL_VALUE: felt252 = 'usd_delta_exceeds_pool_value'; } diff --git a/src/pricing/position_pricing_utils.cairo b/src/pricing/position_pricing_utils.cairo index 05cf15df..cd30161c 100644 --- a/src/pricing/position_pricing_utils.cairo +++ b/src/pricing/position_pricing_utils.cairo @@ -13,8 +13,16 @@ use satoru::data::data_store::{IDataStoreDispatcher, IDataStoreDispatcherTrait}; use satoru::market::market::Market; use satoru::price::price::Price; use satoru::position::position::Position; + + +use satoru::market::market_utils; +use satoru::pricing::pricing_utils; +use satoru::data::keys; use satoru::mock::referral_storage::{IReferralStorageDispatcher, IReferralStorageDispatcherTrait}; -use satoru::utils::i128::{I128Store, I128Serde,}; +use satoru::utils::{calc, precision}; +use satoru::pricing::error::PricingError; +use satoru::referral::referral_utils; +use satoru::utils::i128::{I128Store, I128Serde, I128Div, I128Mul, I128Default}; /// Struct used in get_position_fees. #[derive(Drop, starknet::Store, Serde)] struct GetPositionFeesParams { @@ -39,7 +47,7 @@ struct GetPositionFeesParams { } /// Struct used in get_price_impact_usd. -#[derive(Drop, starknet::Store, Serde)] +#[derive(Drop, Copy, starknet::Store, Serde)] struct GetPriceImpactUsdParams { /// The `DataStore` contract dispatcher. data_store: IDataStoreDispatcher, @@ -137,7 +145,7 @@ struct PositionBorrowingFees { } /// Struct used to store position funding fees. -#[derive(Default, Drop, starknet::Store, Serde)] +#[derive(Default, Copy, Drop, starknet::Store, Serde)] struct PositionFundingFees { /// The amount of funding fees in tokens. funding_fee_amount: u128, @@ -165,19 +173,116 @@ struct PositionUiFees { } /// Get the price impact in USD for a position increase / decrease. +/// # Arguments +/// * `params` - GetPriceImpactUsdParams +/// # Returns +/// Price impact usd fn get_price_impact_usd(params: GetPriceImpactUsdParams) -> i128 { - // TODO - 0 + let open_interest_params: OpenInterestParams = get_next_open_interest(params); + let price_impact_usd = get_price_impact_usd_internal( + params.data_store, params.market.market_token, open_interest_params + ); + + /// the virtual price impact calculation is skipped if the price impact + /// is positive since the action is helping to balance the pool + /// + /// in case two virtual pools are unbalanced in a different direction + /// e.g. pool0 has more longs than shorts while pool1 has less longs + /// than shorts + /// not skipping the virtual price impact calculation would lead to + /// a negative price impact for any trade on either pools and would + /// disincentivise the balancing of pools + + if (price_impact_usd >= 0) { + return price_impact_usd; + } + + let (has_virtual_inventory, virtual_inventory) = + market_utils::get_virtual_inventory_for_positions( + params.data_store, params.market.index_token + ); + + if (!has_virtual_inventory) { + return price_impact_usd; + } + + let open_interest_params_for_virtual_inventory: OpenInterestParams = + get_next_open_interest_for_virtual_inventory( + params, virtual_inventory + ); + let price_impact_usd_for_virtual_inventory = get_price_impact_usd_internal( + params.data_store, params.market.market_token, open_interest_params_for_virtual_inventory + ); + + if (price_impact_usd_for_virtual_inventory < price_impact_usd) { + return price_impact_usd_for_virtual_inventory; + } + + price_impact_usd } /// Called internally by get_price_impact_params(). -fn get_price_impact_usd_( - params: GetPriceImpactUsdParams, +/// # Arguments +/// * `data_store` - DataStore +/// * `market` - the trading market +/// * `openInterestParams` - OpenInterestParams +/// # Returns +/// Price impact usd +fn get_price_impact_usd_internal( + data_store: IDataStoreDispatcher, market: ContractAddress, open_interest_params: OpenInterestParams, ) -> i128 { - // TODO - 0 + let initial_diff_usd = calc::diff( + open_interest_params.long_open_interest, open_interest_params.short_open_interest + ); + let next_diff_usd = calc::diff( + open_interest_params.next_long_open_interest, open_interest_params.next_short_open_interest + ); + + /// check whether an improvement in balance comes from causing the balance to switch sides + /// for example, if there is $2000 of ETH and $1000 of USDC in the pool + /// adding $1999 USDC into the pool will reduce absolute balance from $1000 to $999 but it does not + /// help rebalance the pool much, the isSameSideRebalance value helps avoid gaming using this case + let is_same_side_rebalance_first = open_interest_params + .long_open_interest <= open_interest_params + .short_open_interest; + let is_same_side_rebalance_second = open_interest_params + .short_open_interest <= open_interest_params + .next_long_open_interest; + let is_same_side_rebalance_third = open_interest_params + .next_long_open_interest <= open_interest_params + .next_short_open_interest; + let is_same_side_rebalance = is_same_side_rebalance_first + && is_same_side_rebalance_second + && is_same_side_rebalance_third; + + let impact_exponent_factor = data_store + .get_u128(keys::position_impact_exponent_factor_key(market)); + + if (is_same_side_rebalance) { + let has_positive_impact = next_diff_usd < initial_diff_usd; + let impact_factor = market_utils::get_adjusted_position_impact_factor( + data_store, market, has_positive_impact + ); + + return pricing_utils::get_price_impact_usd_for_same_side_rebalance( + initial_diff_usd, next_diff_usd, impact_factor, impact_exponent_factor + ); + } else { + let (positive_impact_factor, negative_impact_factor) = + market_utils::get_adjusted_position_impact_factors( + data_store, market + ); + + return pricing_utils::get_price_impact_usd_for_crossover_rebalance( + initial_diff_usd, + next_diff_usd, + positive_impact_factor, + negative_impact_factor, + impact_exponent_factor + ); + } } /// Compute new open interest. @@ -186,13 +291,15 @@ fn get_price_impact_usd_( /// # Returns /// New open interest. fn get_next_open_interest(params: GetPriceImpactUsdParams) -> OpenInterestParams { - // TODO - OpenInterestParams { - long_open_interest: 0, - short_open_interest: 0, - next_long_open_interest: 0, - next_short_open_interest: 0, - } + let long_open_interest = market_utils::get_open_interest_for_market_is_long( + params.data_store, @params.market, true + ); + + let short_open_interest = market_utils::get_open_interest_for_market_is_long( + params.data_store, @params.market, false + ); + + return get_next_open_interest_params(params, long_open_interest, short_open_interest); } /// Compute new open interest for virtual inventory. @@ -204,32 +311,69 @@ fn get_next_open_interest(params: GetPriceImpactUsdParams) -> OpenInterestParams fn get_next_open_interest_for_virtual_inventory( params: GetPriceImpactUsdParams, virtual_inventory: i128, ) -> OpenInterestParams { - // TODO - OpenInterestParams { - long_open_interest: 0, - short_open_interest: 0, - next_long_open_interest: 0, - next_short_open_interest: 0, + let mut long_open_interest = 0; + let mut short_open_interest = 0; + + /// if virtualInventory is more than zero it means that + /// tokens were virtually sold to the pool, so set shortOpenInterest + /// to the virtualInventory value + /// if virtualInventory is less than zero it means that + /// tokens were virtually bought from the pool, so set longOpenInterest + /// to the virtualInventory value + + if (virtual_inventory > 0) { + short_open_interest = calc::to_unsigned(virtual_inventory); + } else { + long_open_interest = calc::to_unsigned(-virtual_inventory); } + + /// the virtual long and short open interest is adjusted by the usdDelta + /// to prevent an underflow in getNextOpenInterestParams + /// price impact depends on the change in USD balance, so offsetting both + /// values equally should not change the price impact calculation + if (params.usd_delta < 0) { + let offset = calc::to_unsigned(-params.usd_delta); + long_open_interest += offset; + short_open_interest += offset; + } + + return get_next_open_interest_params(params, long_open_interest, short_open_interest); } /// Compute new open interest. /// # Arguments /// * `params` - Price impact in usd. /// * `long_open_interest` - Long positions open interest. -/// * `long_open_interest` - Short positions open interest. +/// * `short _open_interest` - Short positions open interest. /// # Returns /// New open interest. fn get_next_open_interest_params( params: GetPriceImpactUsdParams, long_open_interest: u128, short_open_interest: u128 ) -> OpenInterestParams { - // TODO - OpenInterestParams { - long_open_interest: 0, - short_open_interest: 0, - next_long_open_interest: 0, - next_short_open_interest: 0, + let mut next_long_open_interest = long_open_interest; + let mut next_short_open_interest = short_open_interest; + + if (params.is_long) { + if (params.usd_delta < 0 && calc::to_unsigned(-params.usd_delta) > long_open_interest) { + PricingError::USD_DELTA_EXCEEDS_LONG_OPEN_INTEREST(params.usd_delta, long_open_interest) + } + + next_long_open_interest = calc::sum_return_uint_128(long_open_interest, params.usd_delta); + } else { + if (params.usd_delta < 0 && calc::to_unsigned(-params.usd_delta) > short_open_interest) { + PricingError::USD_DELTA_EXCEEDS_SHORT_OPEN_INTEREST( + params.usd_delta, short_open_interest + ) + } + + next_short_open_interest = calc::sum_return_uint_128(short_open_interest, params.usd_delta); } + + let open_interest_params = OpenInterestParams { + long_open_interest, short_open_interest, next_long_open_interest, next_short_open_interest + }; + + open_interest_params } /// Compute position fees. @@ -237,53 +381,78 @@ fn get_next_open_interest_params( /// * `params` - parameters to compute position fees. /// # Returns /// Position fees. -fn get_position_fees(params: GetPositionFeesParams,) -> PositionFees { - // TODO - let address_zero: ContractAddress = 0.try_into().unwrap(); - let position_referral_fees = PositionReferralFees { - referral_code: 0, - affiliate: address_zero, - trader: address_zero, - total_rebate_factor: 0, - trader_discount_factor: 0, - total_rebate_amount: 0, - trader_discount_amount: 0, - affiliate_reward_amount: 0, - }; - let position_funding_fees = PositionFundingFees { - funding_fee_amount: 0, - claimable_long_token_amount: 0, - claimable_short_token_amount: 0, - latest_funding_fee_amount_per_size: 0, - latest_long_token_claimable_funding_amount_per_size: 0, - latest_short_token_claimable_funding_amount_per_size: 0, - }; - let position_borrowing_fees = PositionBorrowingFees { - borrowing_fee_usd: 0, - borrowing_fee_amount: 0, - borrowing_fee_receiver_factor: 0, - borrowing_fee_amount_for_fee_receiver: 0, - }; - let position_ui_fees = PositionUiFees { - ui_fee_receiver: address_zero, ui_fee_receiver_factor: 0, ui_fee_amount: 0, - }; - let price = Price { min: 0, max: 0, }; - PositionFees { - referral: position_referral_fees, - funding: position_funding_fees, - borrowing: position_borrowing_fees, - ui: position_ui_fees, - collateral_token_price: price, - position_fee_factor: 0, - protocol_fee_amount: 0, - position_fee_receiver_factor: 0, - fee_receiver_amount: 0, - fee_amount_for_pool: 0, - position_fee_amount_for_pool: 0, - position_fee_amount: 0, - total_cost_amount_excluding_funding: 0, - total_cost_amount: 0, - } +fn get_position_fees(params: GetPositionFeesParams) -> PositionFees { + let mut fees = get_position_fees_after_referral( + params.data_store, + params.referral_storage, + params.collateral_token_price, + params.for_positive_impact, + params.position.account, + params.position.market, + params.size_delta_usd + ); + + let borrowing_fee_usd = market_utils::get_borrowing_fees(params.data_store, params.position); + + fees + .borrowing = + get_borrowing_fees(params.data_store, params.collateral_token_price, borrowing_fee_usd); + + fees.fee_amount_for_pool = fees.position_fee_amount_for_pool + + fees.borrowing.borrowing_fee_amount + - fees.borrowing.borrowing_fee_amount_for_fee_receiver; + fees.fee_receiver_amount += fees.borrowing.borrowing_fee_amount_for_fee_receiver; + + fees + .funding + .latest_funding_fee_amount_per_size = + market_utils::get_funding_fee_amount_per_size( + params.data_store, + params.position.market, + params.position.collateral_token, + params.position.is_long + ); + + fees + .funding + .latest_long_token_claimable_funding_amount_per_size = + market_utils::get_claimable_funding_amount_per_size( + params.data_store, + params.position.market, + params.long_token, + params.position.is_long + ); + + fees + .funding + .latest_short_token_claimable_funding_amount_per_size = + market_utils::get_claimable_funding_amount_per_size( + params.data_store, + params.position.market, + params.short_token, + params.position.is_long + ); + + fees.funding = get_funding_fees(fees.funding, params.position); + + fees + .ui = + get_ui_fees( + params.data_store, + params.collateral_token_price, + params.size_delta_usd, + params.ui_fee_receiver + ); + + fees.total_cost_amount_excluding_funding = fees.position_fee_amount + + fees.borrowing.borrowing_fee_amount + + fees.ui.ui_fee_amount + - fees.referral.trader_discount_amount; + + fees.total_cost_amount = fees.total_cost_amount_excluding_funding + + fees.funding.funding_fee_amount; + + fees } /// Compute borrowing fees data. @@ -296,12 +465,15 @@ fn get_position_fees(params: GetPositionFeesParams,) -> PositionFees { fn get_borrowing_fees( data_store: IDataStoreDispatcher, collateral_token_price: Price, borrowing_fee_usd: u128, ) -> PositionBorrowingFees { - // TODO + let borrowing_fee_amount = borrowing_fee_usd / collateral_token_price.min; + let borrowing_fee_receiver_factor = data_store.get_u128(keys::borrowing_fee_receiver_factor()); PositionBorrowingFees { - borrowing_fee_usd: 0, - borrowing_fee_amount: 0, - borrowing_fee_receiver_factor: 0, - borrowing_fee_amount_for_fee_receiver: 0, + borrowing_fee_usd, + borrowing_fee_amount, + borrowing_fee_receiver_factor, + borrowing_fee_amount_for_fee_receiver: precision::apply_factor_u128( + borrowing_fee_amount, borrowing_fee_receiver_factor + ) } } @@ -310,35 +482,68 @@ fn get_borrowing_fees( /// * `funding_fees` - The position funding fees struct to store fees. /// * `position` - The position to compute funding fees for. /// # Returns -/// Borrowing fees. -fn get_funding_fees(funding_fees: PositionFundingFees, position: Position,) -> PositionFundingFees { - // TODO - PositionFundingFees { - funding_fee_amount: 0, - claimable_long_token_amount: 0, - claimable_short_token_amount: 0, - latest_funding_fee_amount_per_size: 0, - latest_long_token_claimable_funding_amount_per_size: 0, - latest_short_token_claimable_funding_amount_per_size: 0, - } +/// Funding fees. +fn get_funding_fees( + mut funding_fees: PositionFundingFees, position: Position +) -> PositionFundingFees { + funding_fees + .funding_fee_amount = + market_utils::get_funding_amount( + funding_fees.latest_funding_fee_amount_per_size, + position.funding_fee_amount_per_size, + position.size_in_usd, + true // roundUpMagnitude + ); + + funding_fees + .claimable_long_token_amount = + market_utils::get_funding_amount( + funding_fees.latest_long_token_claimable_funding_amount_per_size, + position.long_token_claimable_funding_amount_per_size, + position.size_in_usd, + false // roundUpMagnitude + ); + + funding_fees + .claimable_short_token_amount = + market_utils::get_funding_amount( + funding_fees.latest_short_token_claimable_funding_amount_per_size, + position.short_token_claimable_funding_amount_per_size, + position.size_in_usd, + false // roundUpMagnitude + ); + + funding_fees } /// Compute ui fees. /// # Arguments /// * `data_store` - The `DataStore` contract dispatcher. /// * `collateral_token_price` - The price of the collateral token. +/// * `size_delta_usd` - Size delta usd /// * `ui_fee receiver` - The ui fee receiver address. /// # Returns -/// Borrowing fees. +/// Ui fees. fn get_ui_fees( data_store: IDataStoreDispatcher, collateral_token_price: Price, size_delta_usd: u128, ui_fee_receiver: ContractAddress ) -> PositionUiFees { - // TODO - let address_zero: ContractAddress = 0.try_into().unwrap(); - PositionUiFees { ui_fee_receiver: address_zero, ui_fee_receiver_factor: 0, ui_fee_amount: 0, } + let mut ui_fees: PositionUiFees = Default::default(); + + if (ui_fee_receiver == 0.try_into().unwrap()) { + return ui_fees; + } + + ui_fees.ui_fee_receiver = ui_fee_receiver; + ui_fees.ui_fee_receiver_factor = market_utils::get_ui_fee_factor(data_store, ui_fee_receiver); + ui_fees + .ui_fee_amount = + precision::apply_factor_u128(size_delta_usd, ui_fees.ui_fee_receiver_factor) + / collateral_token_price.min; + + ui_fees } /// Get position fees after applying referral rebates / discounts. @@ -360,50 +565,59 @@ fn get_position_fees_after_referral( market: ContractAddress, size_delta_usd: u128, ) -> PositionFees { - // TODO - let address_zero: ContractAddress = 0.try_into().unwrap(); - let position_referral_fees = PositionReferralFees { - referral_code: 0, - affiliate: address_zero, - trader: address_zero, - total_rebate_factor: 0, - trader_discount_factor: 0, - total_rebate_amount: 0, - trader_discount_amount: 0, - affiliate_reward_amount: 0, - }; - let position_funding_fees = PositionFundingFees { - funding_fee_amount: 0, - claimable_long_token_amount: 0, - claimable_short_token_amount: 0, - latest_funding_fee_amount_per_size: 0, - latest_long_token_claimable_funding_amount_per_size: 0, - latest_short_token_claimable_funding_amount_per_size: 0, - }; - let position_borrowing_fees = PositionBorrowingFees { - borrowing_fee_usd: 0, - borrowing_fee_amount: 0, - borrowing_fee_receiver_factor: 0, - borrowing_fee_amount_for_fee_receiver: 0, - }; - let position_ui_fees = PositionUiFees { - ui_fee_receiver: address_zero, ui_fee_receiver_factor: 0, ui_fee_amount: 0, - }; - let price = Price { min: 0, max: 0, }; - PositionFees { - referral: position_referral_fees, - funding: position_funding_fees, - borrowing: position_borrowing_fees, - ui: position_ui_fees, - collateral_token_price: price, - position_fee_factor: 0, - protocol_fee_amount: 0, - position_fee_receiver_factor: 0, - fee_receiver_amount: 0, - fee_amount_for_pool: 0, - position_fee_amount_for_pool: 0, - position_fee_amount: 0, - total_cost_amount_excluding_funding: 0, - total_cost_amount: 0, - } + let mut fees: PositionFees = Default::default(); + + fees.collateral_token_price = collateral_token_price; + + fees.referral.trader = account; + + let (referral_code, affiliate, total_rebate_factor, trader_discount_factor) = + referral_utils::get_referral_info( + referral_storage, account + ); + + fees.referral.referral_code = referral_code; + fees.referral.affiliate = affiliate; + fees.referral.total_rebate_factor = total_rebate_factor; + fees.referral.trader_discount_factor = trader_discount_factor; + + /// note that since it is possible to incur both positive and negative price impact values + /// and the negative price impact factor may be larger than the positive impact factor + /// it is possible for the balance to be improved overall but for the price impact to still be negative + /// in this case the fee factor for the negative price impact would be charged + /// a user could split the order into two, to incur a smaller fee, reducing the fee through this should not be a large issue + fees + .position_fee_factor = data_store + .get_u128(keys::position_fee_factor_key(market, for_positive_impact)); + fees + .position_fee_amount = + precision::apply_factor_u128(size_delta_usd, fees.position_fee_factor) + / collateral_token_price.min; + + fees + .referral + .total_rebate_amount = + precision::apply_factor_u128( + fees.position_fee_amount, fees.referral.total_rebate_factor + ); + fees + .referral + .trader_discount_amount = + precision::apply_factor_u128( + fees.referral.total_rebate_amount, fees.referral.trader_discount_factor + ); + fees.referral.affiliate_reward_amount = fees.referral.total_rebate_amount + - fees.referral.trader_discount_amount; + + fees.protocol_fee_amount = fees.position_fee_amount - fees.referral.total_rebate_amount; + + fees.position_fee_receiver_factor = data_store.get_u128(keys::position_fee_receiver_factor()); + fees + .fee_receiver_amount = + precision::apply_factor_u128( + fees.protocol_fee_amount, fees.position_fee_receiver_factor + ); + fees.position_fee_amount_for_pool = fees.protocol_fee_amount - fees.fee_receiver_amount; + + fees } diff --git a/src/tests/pricing/test_position_pricing_utils.cairo b/src/tests/pricing/test_position_pricing_utils.cairo new file mode 100644 index 00000000..855c9cc1 --- /dev/null +++ b/src/tests/pricing/test_position_pricing_utils.cairo @@ -0,0 +1,216 @@ +use starknet::{ContractAddress, contract_address_const}; +use satoru::price::price::Price; +use satoru::position::position::Position; +use satoru::pricing::position_pricing_utils; +use satoru::mock::referral_storage::{IReferralStorageDispatcher, IReferralStorageDispatcherTrait}; +use satoru::event::event_emitter::{IEventEmitterDispatcher, IEventEmitterDispatcherTrait}; +use satoru::mock::governable::{IGovernableDispatcher, IGovernableDispatcherTrait}; +use satoru::data::data_store::{IDataStoreDispatcher, IDataStoreDispatcherTrait}; +use satoru::role::role_store::{IRoleStoreDispatcher, IRoleStoreDispatcherTrait}; +use satoru::role::role; +use satoru::market::market::Market; +use satoru::pricing::position_pricing_utils::{ + GetPositionFeesParams, PositionFundingFees, GetPriceImpactUsdParams +}; +use snforge_std::{declare, start_prank, stop_prank, ContractClassTrait}; + +// TODO add asserts for each test when possible + +#[test] +fn given_normal_conditions_when_get_price_impact_usd_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + + let get_price_impact_params = create_get_price_impact_usd_params(data_store); + position_pricing_utils::get_price_impact_usd(get_price_impact_params); +} + +#[test] +fn given_normal_conditions_when_get_next_open_interest_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + + let get_price_impact_params = create_get_price_impact_usd_params(data_store); + position_pricing_utils::get_next_open_interest(get_price_impact_params); +} + +#[test] +fn given_normal_conditions_when_get_next_open_interest_for_virtual_inventory_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + + let get_price_impact_params = create_get_price_impact_usd_params(data_store); + position_pricing_utils::get_next_open_interest_for_virtual_inventory( + get_price_impact_params, 50 + ); +} + +#[test] +fn given_normal_conditions_when_get_next_open_interest_params_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + + let get_price_impact_params = create_get_price_impact_usd_params(data_store); + position_pricing_utils::get_next_open_interest_params(get_price_impact_params, 100, 20); +} + +#[test] +fn given_normal_conditions_when_get_position_fees_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + + let position = Position { + key: 1, + account: contract_address_const::<'account'>(), + market: contract_address_const::<'market'>(), + collateral_token: contract_address_const::<'collateral_token'>(), + size_in_usd: 100, + size_in_tokens: 1, + collateral_amount: 2, + borrowing_factor: 3, + funding_fee_amount_per_size: 4, + long_token_claimable_funding_amount_per_size: 5, + short_token_claimable_funding_amount_per_size: 6, + increased_at_block: 15000, + decreased_at_block: 15001, + is_long: false + }; + + let collateral_token_price = Price { min: 5, max: 10, }; + + GetPositionFeesParams { + data_store, + referral_storage, + position, + collateral_token_price, + for_positive_impact: true, + long_token: contract_address_const::<'long_token'>(), + short_token: contract_address_const::<'short_token'>(), + size_delta_usd: 10, + ui_fee_receiver: contract_address_const::<'ui_fee_receiver'>() + }; +} + +#[test] +fn given_normal_conditions_when_get_borrowing_fees_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + let price = Price { min: 5, max: 10 }; + + position_pricing_utils::get_borrowing_fees(data_store, price, 3); +} + +#[test] +fn given_normal_conditions_when_get_funding_fees_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + let position_funding_fees = PositionFundingFees { + funding_fee_amount: 10, + claimable_long_token_amount: 100, + claimable_short_token_amount: 50, + latest_funding_fee_amount_per_size: 15, + latest_long_token_claimable_funding_amount_per_size: 15, + latest_short_token_claimable_funding_amount_per_size: 15, + }; + + let position = Position { + key: 1, + account: contract_address_const::<'account'>(), + market: contract_address_const::<'market'>(), + collateral_token: contract_address_const::<'collateral_token'>(), + size_in_usd: 100, + size_in_tokens: 1, + collateral_amount: 2, + borrowing_factor: 3, + funding_fee_amount_per_size: 4, + long_token_claimable_funding_amount_per_size: 5, + short_token_claimable_funding_amount_per_size: 6, + increased_at_block: 15000, + decreased_at_block: 15001, + is_long: false + }; + + position_pricing_utils::get_funding_fees(position_funding_fees, position); +} + +#[test] +fn given_normal_conditions_when_get_ui_fees_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + let price = Price { min: 5, max: 10 }; + let ui_fee_receiver = contract_address_const::<'ui_fee_receiver'>(); + position_pricing_utils::get_ui_fees(data_store, price, 10, ui_fee_receiver); +} + +#[test] +fn given_normal_conditions_when_get_position_fees_after_referral_then_works() { + let (caller_address, data_store, referral_storage) = setup(); + let price = Price { min: 5, max: 10 }; + let account = contract_address_const::<'account'>(); + let market = contract_address_const::<'market'>(); + position_pricing_utils::get_position_fees_after_referral( + data_store, referral_storage, price, true, account, market, 10 + ); +} + +fn deploy_data_store(role_store_address: ContractAddress) -> ContractAddress { + let contract = declare('DataStore'); + let constructor_calldata = array![role_store_address.into()]; + contract.deploy(@constructor_calldata).unwrap() +} + +fn deploy_role_store() -> ContractAddress { + let contract = declare('RoleStore'); + contract.deploy(@array![]).unwrap() +} + +fn deploy_referral_storage(event_emitter_address: ContractAddress) -> ContractAddress { + let contract = declare('ReferralStorage'); + let constructor_calldata = array![event_emitter_address.into()]; + contract.deploy(@constructor_calldata).unwrap() +} + +fn deploy_governable(event_emitter_address: ContractAddress) -> ContractAddress { + let contract = declare('Governable'); + let constructor_calldata = array![event_emitter_address.into()]; + contract.deploy(@constructor_calldata).unwrap() +} + +fn deploy_event_emitter() -> ContractAddress { + let contract = declare('EventEmitter'); + contract.deploy(@array![]).unwrap() +} + +fn setup() -> (ContractAddress, IDataStoreDispatcher, IReferralStorageDispatcher) { + let caller_address: ContractAddress = 0x101.try_into().unwrap(); + + let role_store_address = deploy_role_store(); + let role_store = IRoleStoreDispatcher { contract_address: role_store_address }; + + let data_store_address = deploy_data_store(role_store_address); + let data_store = IDataStoreDispatcher { contract_address: data_store_address }; + + let event_emitter_address = deploy_event_emitter(); + let event_emitter = IEventEmitterDispatcher { contract_address: event_emitter_address }; + + let referral_storage_address = deploy_referral_storage(event_emitter_address); + let referral_storage = IReferralStorageDispatcher { + contract_address: referral_storage_address + }; + let governable_address = deploy_governable(event_emitter_address); + let governable = IGovernableDispatcher { contract_address: governable_address }; + + start_prank(event_emitter_address, caller_address); + start_prank(data_store_address, caller_address); + start_prank(referral_storage_address, caller_address); + start_prank(governable_address, caller_address); + + start_prank(role_store_address, caller_address); + role_store.grant_role(caller_address, role::CONTROLLER); + start_prank(data_store_address, caller_address); + (caller_address, data_store, referral_storage) +} + + +fn create_get_price_impact_usd_params(data_store: IDataStoreDispatcher) -> GetPriceImpactUsdParams { + let market = Market { + market_token: contract_address_const::<'market_token'>(), + index_token: contract_address_const::<'index_token'>(), + long_token: contract_address_const::<'long_token'>(), + short_token: contract_address_const::<'short_token'>() + }; + + GetPriceImpactUsdParams { data_store, market, usd_delta: 50, is_long: true } +} From ac48e57e82a5f607a45f98f8e481f12ad561fd2a Mon Sep 17 00:00:00 2001 From: Alex Metelli Date: Mon, 25 Sep 2023 15:26:55 -0700 Subject: [PATCH 2/2] Feat: Completed Implementation of `role_store` contract. (#443) completed role_store + tests Co-authored-by: zarboq <37303126+zarboq@users.noreply.github.com> --- src/role/role_store.cairo | 249 +++++++++++++++++++-------- src/tests/role/test_role_store.cairo | 217 ++++++++++++----------- 2 files changed, 285 insertions(+), 181 deletions(-) diff --git a/src/role/role_store.cairo b/src/role/role_store.cairo index 7ee4f6a7..eb78928a 100644 --- a/src/role/role_store.cairo +++ b/src/role/role_store.cairo @@ -43,32 +43,30 @@ trait IRoleStore { /// Returns the number of roles stored in the contract. /// # Return /// The number of roles. - fn get_role_count(self: @TContractState) -> u128; -/// # [TO FIX] -/// Returns the keys of the roles stored in the contract. -/// # Arguments -/// `start` - The starting index of the range of roles to return. -/// `end` - The ending index of the range of roles to return. -/// # Return -/// The keys of the roles. -/// fn get_roles(self: @TContractState, start: u32, end: u32) -> Array; - -/// # [TO DO] -/// Returns the number of members of the specified role. -/// # Arguments -/// `role_key` - The key of the role. -/// # Return -/// The number of members of the role. -/// fn get_role_member_count(self: @TContractState, role_key: felt252) -> u128; - -/// Returns the members of the specified role. -/// # Arguments -/// `role_key` - The key of the role. -/// `start` - The start index, the value for this index will be included. -/// `end` - The end index, the value for this index will not be included. -/// # Return -/// The members of the role. -/// fn get_role_members(self: @TContractState, role_key: felt252, start: u128, end: u128) -> Array; + fn get_role_count(self: @TContractState) -> u32; + /// Returns the keys of the roles stored in the contract. + /// # Arguments + /// `start` - The starting index of the range of roles to return. + /// `end` - The ending index of the range of roles to return. + /// # Return + /// The keys of the roles. + fn get_roles(self: @TContractState, start: u32, end: u32) -> Array; + /// Returns the number of members of the specified role. + /// # Arguments + /// `role_key` - The key of the role. + /// # Return + /// The number of members of the role. + fn get_role_member_count(self: @TContractState, role_key: felt252) -> u32; + /// Returns the members of the specified role. + /// # Arguments + /// `role_key` - The key of the role. + /// `start` - The start index, the value for this index will be included. + /// `end` - The end index, the value for this index will not be included. + /// # Return + /// The members of the role. + fn get_role_members( + self: @TContractState, role_key: felt252, start: u32, end: u32 + ) -> Array; } #[starknet::contract] @@ -78,8 +76,8 @@ mod RoleStore { // ************************************************************************* // Core lib imports. - use starknet::{ContractAddress, get_caller_address}; - //use array::ArrayTrait; + use core::zeroable::Zeroable; + use starknet::{ContractAddress, get_caller_address, contract_address_const}; // Local imports. use satoru::role::{role, error::RoleError}; @@ -90,13 +88,17 @@ mod RoleStore { #[storage] struct Storage { /// Maps accounts to their roles. - role_members: LegacyMap::<(felt252, ContractAddress), bool>, + has_role: LegacyMap::<(felt252, ContractAddress), bool>, + /// Stores the number of the indexes used to a specific role. + role_members_count: LegacyMap::, + /// Stores all the account that have a specific role. + role_members: LegacyMap::<(felt252, u32), ContractAddress>, /// Stores unique role names. role_names: LegacyMap::, - /// Store the number of unique roles. - role_count: u128, - /// List of all role keys. - ///role_keys: Array, + /// Store the number of indexes of the roles. + roles_count: u32, + /// List of all role keys. + roles: LegacyMap::, } // ************************************************************************* @@ -140,8 +142,7 @@ mod RoleStore { let caller = get_caller_address(); // Grant the caller admin role. self._grant_role(caller, role::ROLE_ADMIN); - // Initialize the role_count to 1 due to the line just above. - self.role_count.write(1); + // Initialize the role_count to 1 due to the line just above. } // ************************************************************************* @@ -154,17 +155,15 @@ mod RoleStore { } fn grant_role(ref self: ContractState, account: ContractAddress, role_key: felt252) { - let caller = get_caller_address(); // Check that the caller has the admin role. - self._assert_only_role(caller, role::ROLE_ADMIN); + self._assert_only_role(get_caller_address(), role::ROLE_ADMIN); // Grant the role. self._grant_role(account, role_key); } fn revoke_role(ref self: ContractState, account: ContractAddress, role_key: felt252) { - let caller = get_caller_address(); // Check that the caller has the admin role. - self._assert_only_role(caller, role::ROLE_ADMIN); + self._assert_only_role(get_caller_address(), role::ROLE_ADMIN); // Revoke the role. self._revoke_role(account, role_key); } @@ -173,29 +172,78 @@ mod RoleStore { self._assert_only_role(account, role_key); } - fn get_role_count(self: @ContractState) -> u128 { - return self.role_count.read(); + fn get_role_count(self: @ContractState) -> u32 { + let mut count = 0; + let mut i = 1; + loop { + if i > self.roles_count.read() { + break; + } + if !self.roles.read(i).is_zero() { + count += 1; + } + i += 1; + }; + count + } + + fn get_roles(self: @ContractState, start: u32, mut end: u32) -> Array { + let mut arr = array![]; + let roles_count = self.roles_count.read(); + if end > roles_count { + end = roles_count; + } + let mut i = start; + loop { + if i > end { + break; + } + let role = self.roles.read(i); + if !role.is_zero() { + arr.append(role); + } + i += 1; + }; + arr + } + + fn get_role_member_count(self: @ContractState, role_key: felt252) -> u32 { + let mut count = 0; + let mut i = 1; + loop { + if i > self.role_members_count.read(role_key) { + break; + } + if !(self.role_members.read((role_key, i)) == contract_address_const::<0>()) { + count += 1; + } + i += 1; + }; + count + } + + fn get_role_members( + self: @ContractState, role_key: felt252, start: u32, mut end: u32 + ) -> Array { + let mut arr: Array = array![]; + let mut i = start; + loop { + if i > end || i > self.role_members_count.read(role_key) { + break; + } + let role_member = self.role_members.read((role_key, i)); + // Since some role members will have indexes with zero address if a zero address + // is found end increase by 1 to mock array behaviour. + if role_member.is_zero() { + end += 1; + } + if !(role_member == contract_address_const::<0>()) { + arr.append(role_member); + } + i += 1; + }; + arr } - //fn get_roles(self: @ContractState, start: u32, end: u32) -> Array { - // Create a new array to store the result. - //let mut result = ArrayTrait::::new(); - //let role_keys_length = self.role_keys.read().len(); - // Ensure the range is valid. - //assert(start < end, "InvalidRange"); - //assert(end <= role_keys_length, "EndOutOfBounds"); - //let mut current_index = start; - //loop { - // Check if we've reached the end of the specified range. - //if current_index >= end { - //break; - //} - //let key = *self.role_keys.read().at(current_index); - //result.append(key); - // Increment the index. - //current_index += 1; - //}; - //return result; - //} } // ************************************************************************* @@ -205,7 +253,7 @@ mod RoleStore { impl InternalFunctions of InternalFunctionsTrait { #[inline(always)] fn _has_role(self: @ContractState, account: ContractAddress, role_key: felt252) -> bool { - self.role_members.read((role_key, account)) + self.has_role.read((role_key, account)) } #[inline(always)] @@ -216,31 +264,80 @@ mod RoleStore { fn _grant_role(ref self: ContractState, account: ContractAddress, role_key: felt252) { // Only grant the role if the account doesn't already have it. if !self._has_role(account, role_key) { - let caller: ContractAddress = get_caller_address(); - self.role_members.write((role_key, account), true); + self.has_role.write((role_key, account), true); + // Iterates through indexes for role members, if an index has zero ContractAddress + // it writes the account to that index for the role. + let roles_members_count = self.role_members_count.read(role_key); + let current_roles_count = self.roles_count.read(); + let mut i = 1; + loop { + let stored_role_member = self.role_members.read((role_key, i)); + if stored_role_member.is_zero() { + self.role_members.write((role_key, i), account); + self.role_members_count.write(role_key, roles_members_count + 1); + break; + } + i += 1; + }; + // Store the role name if it's not already stored. if self.role_names.read(role_key) == false { self.role_names.write(role_key, true); - let mut current_count: u128 = self.role_count.read(); - self.role_count.write(current_count + 1); - // Read the current state of role_keys into a local variable. - // let mut local_role_keys = self.role_keys.read(); - // Modify the local variable. - // local_role_keys.append(role_key); - // Write back the modified local variable to the contract state. - // self.role_keys.write(local_role_keys); + self.roles_count.write(current_roles_count + 1); } - self.emit(RoleGranted { role_key, account, sender: caller }); + self.emit(RoleGranted { role_key, account, sender: get_caller_address() }); } + // Iterates through indexes in stored_roles and if a value for the index is zero + // it writes the role_key to that index. + let mut i = 1; + loop { + let stored_role = self.roles.read(i); + if stored_role.is_zero() { + self.roles.write(i, role_key); + break; + } + i += 1; + }; } fn _revoke_role(ref self: ContractState, account: ContractAddress, role_key: felt252) { + let current_roles_count = self.roles_count.read(); // Only revoke the role if the account has it. if self._has_role(account, role_key) { - let caller: ContractAddress = get_caller_address(); - self.role_members.write((role_key, account), false); - self.emit(RoleRevoked { role_key, account, sender: caller }); + self.has_role.write((role_key, account), false); + self.emit(RoleRevoked { role_key, account, sender: get_caller_address() }); + let current_role_members_count = self.role_members_count.read(role_key); + let mut i = 1; + loop { + let stored_role_member = self.role_members.read((role_key, i)); + if stored_role_member == account { + self.role_members.write((role_key, i), contract_address_const::<0>()); + break; + } + i += 1; + }; + // If the role has no members remove the role from roles. + if self.get_role_member_count(role_key).is_zero() { + let role_index = self._find_role_index(role_key); + self.roles.write(role_index, Zeroable::zero()); + } } } + + fn _find_role_index(ref self: ContractState, role_key: felt252) -> u32 { + let mut index = 0; + let mut i = 1; + loop { + if i > self.roles_count.read() { + break; + } + if self.roles.read(i) == role_key { + index = i; + break; + } + i += 1; + }; + index + } } } diff --git a/src/tests/role/test_role_store.cairo b/src/tests/role/test_role_store.cairo index 0fe91617..1e23d0b5 100644 --- a/src/tests/role/test_role_store.cairo +++ b/src/tests/role/test_role_store.cairo @@ -2,7 +2,7 @@ use result::ResultTrait; use traits::TryInto; use starknet::{ContractAddress, contract_address_const}; use starknet::Felt252TryIntoContractAddress; -use snforge_std::{declare, start_prank, ContractClassTrait}; +use snforge_std::{declare, start_prank, ContractClassTrait, PrintTrait}; //use array::ArrayTrait; use satoru::role::role::ROLE_ADMIN; @@ -12,158 +12,165 @@ use satoru::role::role_store::IRoleStoreDispatcher; use satoru::role::role_store::IRoleStoreDispatcherTrait; #[test] -fn given_normal_conditions_when_grant_role_then_works() { - // ********************************************************************************************* - // * SETUP * - // ********************************************************************************************* +fn given_normal_conditions_when_has_role_after_grant_then_works() { let role_store = setup(); - // ********************************************************************************************* - // * TEST LOGIC * - // ********************************************************************************************* // Use the address that has been used to deploy role_store. - let caller_address: ContractAddress = 0x101.try_into().unwrap(); - start_prank(role_store.contract_address, caller_address); - - let account_address: ContractAddress = contract_address_const::<1>(); + start_prank(role_store.contract_address, admin()); // Check that the account address does not have the admin role. - assert(!role_store.has_role(account_address, ROLE_ADMIN), 'Invalid role'); + assert(!role_store.has_role(account_1(), ROLE_ADMIN), 'Invalid role'); // Grant admin role to account address. - role_store.grant_role(account_address, ROLE_ADMIN); + role_store.grant_role(account_1(), ROLE_ADMIN); // Check that the account address has the admin role. - assert(role_store.has_role(account_address, ROLE_ADMIN), 'Invalid role'); - - // ********************************************************************************************* - // * TEARDOWN * - // ********************************************************************************************* - teardown(); + assert(role_store.has_role(account_1(), ROLE_ADMIN), 'Invalid role'); } #[test] -fn given_normal_conditions_when_revoke_role_then_works() { - // ********************************************************************************************* - // * SETUP * - // ********************************************************************************************* +fn given_normal_conditions_when_has_role_after_revoke_then_works() { let role_store = setup(); - // ********************************************************************************************* - // * TEST LOGIC * - // ********************************************************************************************* - // Use the address that has been used to deploy role_store. - let caller_address: ContractAddress = 0x101.try_into().unwrap(); - start_prank(role_store.contract_address, caller_address); - - let account_address: ContractAddress = contract_address_const::<1>(); + start_prank(role_store.contract_address, admin()); // Grant admin role to account address. - role_store.grant_role(account_address, ROLE_ADMIN); + role_store.grant_role(account_1(), ROLE_ADMIN); // Check that the account address has the admin role. - assert(role_store.has_role(account_address, ROLE_ADMIN), 'Invalid role'); + assert(role_store.has_role(account_1(), ROLE_ADMIN), 'Invalid role'); // Revoke admin role from account address. - role_store.revoke_role(account_address, ROLE_ADMIN); + role_store.revoke_role(account_1(), ROLE_ADMIN); // Check that the account address does not have the admin role. - assert(!role_store.has_role(account_address, ROLE_ADMIN), 'Invalid role'); - - // ********************************************************************************************* - // * TEARDOWN * - // ********************************************************************************************* - teardown(); + assert(!role_store.has_role(account_1(), ROLE_ADMIN), 'Invalid role'); } #[test] -fn test_get_role_count() { - // ********************************************************************************************* - // * SETUP * - // ********************************************************************************************* +fn given_normal_conditions_when_get_role_count_then_works() { let role_store = setup(); - // ********************************************************************************************* - // * TEST LOGIC * - // ********************************************************************************************* - // Use the address that has been used to deploy role_store. - let caller_address: ContractAddress = 0x101.try_into().unwrap(); - start_prank(role_store.contract_address, caller_address); - - let account_address: ContractAddress = contract_address_const::<1>(); + start_prank(role_store.contract_address, admin()); // Here, we will test the role count. Initially, it should be 1. assert(role_store.get_role_count() == 1, 'Initial role count should be 1'); // Grant CONTROLLER role to account address. - role_store.grant_role(account_address, CONTROLLER); + role_store.grant_role(account_1(), CONTROLLER); // After granting the role CONTROLLER, the count should be 2. assert(role_store.get_role_count() == 2, 'Role count should be 2'); // Grant MARKET_KEEPER role to account address. - role_store.grant_role(account_address, MARKET_KEEPER); + role_store.grant_role(account_1(), MARKET_KEEPER); // After granting the role MARKET_KEEPER, the count should be 3. assert(role_store.get_role_count() == 3, 'Role count should be 3'); // The ROLE_ADMIN role is already assigned, let's try to reassign it to see if duplicates are managed. // Grant ROLE_ADMIN role to account address. - role_store.grant_role(account_address, ROLE_ADMIN); + role_store.grant_role(account_1(), ROLE_ADMIN); // Duplicates, the count should be 3. assert(role_store.get_role_count() == 3, 'Role count should be 3'); - // ********************************************************************************************* - // * TEARDOWN * - // ********************************************************************************************* - teardown(); + // Revoke a MARKET_KEEPER role, since the role has now no members the roles count + // is decreased. + role_store.revoke_role(account_1(), MARKET_KEEPER); + assert(role_store.get_role_count() == 2, 'Role count should be 2'); +} + +#[test] +fn given_normal_conditions_when_get_roles_then_works() { + let role_store = setup(); + + // Use the address that has been used to deploy role_store. + start_prank(role_store.contract_address, admin()); + + // Grant CONTROLLER role to account address. + role_store.grant_role(account_1(), CONTROLLER); + + // Grant MARKET_KEEPER role to account address. + role_store.grant_role(account_1(), MARKET_KEEPER); + + // Get roles from index 1 to 2 (should return ROLE_ADMIN and CONTROLLER). + // Note: Starknet's storage starts at 1, for this reason storage index 0 will + // always be empty. + let roles_0_to_2 = role_store.get_roles(1, 2); + let first_role = roles_0_to_2.at(0); + let second_role = roles_0_to_2.at(1); + assert(*first_role == ROLE_ADMIN, '1 should be ROLE_ADMIN'); + assert(*second_role == CONTROLLER, '2 should be CONTROLLER'); + + // Get roles from index 2 to 3 (should return CONTROLLER and MARKET_KEEPER). + let roles_1_to_3 = role_store.get_roles(2, 3); + let first_role = roles_1_to_3.at(0); + let second_role = roles_1_to_3.at(1); + assert(*first_role == CONTROLLER, '3 should be CONTROLLER'); + assert(*second_role == MARKET_KEEPER, '4 should be MARKET_KEEPER'); +} + +#[test] +fn given_normal_conditions_when_get_role_member_count_then_works() { + let role_store = setup(); + + // Use the address that has been used to deploy role_store. + start_prank(role_store.contract_address, admin()); + // Grant CONTROLLER role to account address. + role_store.grant_role(account_1(), CONTROLLER); + role_store.grant_role(account_2(), CONTROLLER); + role_store.grant_role(account_3(), CONTROLLER); + + assert(role_store.get_role_member_count(CONTROLLER) == 3, 'members count != 3'); + + role_store.revoke_role(account_3(), CONTROLLER); + assert(role_store.get_role_member_count(CONTROLLER) == 2, 'members count != 2'); + + role_store.revoke_role(account_2(), CONTROLLER); + assert(role_store.get_role_member_count(CONTROLLER) == 1, 'members count != 1'); } -//#[test] -//fn test_get_roles() { -// ********************************************************************************************* -// * SETUP * -// ********************************************************************************************* -//let role_store = setup(); - -// ********************************************************************************************* -// * TEST LOGIC * -// ********************************************************************************************* - -// Use the address that has been used to deploy role_store. -//let caller_address: ContractAddress = 0x101.try_into().unwrap(); -//start_prank(role_store.contract_address, caller_address); - -//let account_address: ContractAddress = contract_address_const::<1>(); - -// Grant CONTROLLER role to account address. -//role_store.grant_role(account_address, CONTROLLER); - -// Grant MARKET_KEEPER role to account address. -//role_store.grant_role(account_address, MARKET_KEEPER); - -// Get roles from index 0 to 2 (should return ROLE_ADMIN and CONTROLLER). -//let roles_0_to_2 = role_store.get_roles(0, 2); -//let first_role = roles_0_to_2.at(0); -//let second_role = roles_0_to_2.at(1); -//assert(*first_role == ROLE_ADMIN, '1 should be ROLE_ADMIN'); -//assert(*second_role == CONTROLLER, '2 should be CONTROLLER'); - -// Get roles from index 1 to 3 (should return CONTROLLER and MARKET_KEEPER). -//let roles_1_to_3 = role_store.get_roles(1, 3); -//let first_role = roles_1_to_3.at(0); -//let second_role = roles_1_to_3.at(1); -//assert(*first_role == CONTROLLER, '3 should be CONTROLLER'); -//assert(*second_role == MARKET_KEEPER, '4 should be MARKET_KEEPER'); - -// ********************************************************************************************* -// * TEARDOWN * -// ********************************************************************************************* -//teardown(); -//} +#[test] +fn given_normal_conditions_when_get_role_members_then_works() { + let role_store = setup(); + + // Use the address that has been used to deploy role_store. + start_prank(role_store.contract_address, admin()); + // Grant CONTROLLER role to accounts. + role_store.grant_role(account_1(), CONTROLLER); + role_store.grant_role(account_2(), CONTROLLER); + role_store.grant_role(account_3(), CONTROLLER); + + let members = role_store.get_role_members(CONTROLLER, 1, 3); + assert(*members.at(0) == account_1(), 'should be acc_1'); + assert(*members.at(1) == account_2(), 'should be acc_2'); + assert(*members.at(2) == account_3(), 'should be acc_3'); + + role_store.revoke_role(account_2(), CONTROLLER); + let members = role_store.get_role_members(CONTROLLER, 1, 2); + assert(*members.at(0) == account_1(), 'should be acc_1'); + assert(*members.at(1) == account_3(), 'should be acc_3'); + + role_store.revoke_role(account_1(), CONTROLLER); + let members = role_store.get_role_members(CONTROLLER, 1, 2); + assert(*members.at(0) == account_3(), 'should be acc_3'); +} + +fn admin() -> ContractAddress { + contract_address_const::<0x101>() +} + +fn account_1() -> ContractAddress { + contract_address_const::<1>() +} + +fn account_2() -> ContractAddress { + contract_address_const::<2>() +} + +fn account_3() -> ContractAddress { + contract_address_const::<3>() +} /// Utility function to setup the test environment. fn setup() -> IRoleStoreDispatcher { IRoleStoreDispatcher { contract_address: deploy_role_store() } } -/// Utility function to teardown the test environment. -fn teardown() {} - // Utility function to deploy a role store contract and return its address. fn deploy_role_store() -> ContractAddress { let contract = declare('RoleStore');