From 0cf1e888055edcdeb7d07826934da911f42f833a Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 21 Mar 2024 16:33:33 -0400 Subject: [PATCH] Small updates to tests (#642) --- tests/requirements.txt | 1 - .../test_clusters/test_cluster.py | 20 ++++++++++++++----- tests/test_resources/test_resource.py | 14 ++++++++++--- tests/utils.py | 6 ++++++ 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 26b2bc86e..5f56b7477 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -8,7 +8,6 @@ tqdm # packages for local and unit tests boto3 -skypilot[docker] docker s3fs pandas diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index d26c28be7..3ef70f96e 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -15,7 +15,7 @@ import tests.test_resources.test_resource from tests.conftest import init_args -from tests.utils import get_random_str +from tests.utils import get_random_str, remove_config_keys """ TODO: 1) In subclasses, test factory methods create same type as parent @@ -160,7 +160,7 @@ def test_cluster_endpoint(self, cluster): headers=rh.globals.rns_client.request_headers(), ) assert r.status_code == 200 - assert r.json()["resource_subtype"] == "Cluster" + assert r.json()["resource_type"] == "cluster" @pytest.mark.level("local") def test_cluster_objects(self, cluster): @@ -292,14 +292,24 @@ def test_rh_status_stopped(self, cluster): def test_condensed_config_for_cluster(self, cluster): import ast - return_codes = cluster.run_python(["import runhouse as rh", "print(rh.here)"]) + return_codes = cluster.run_python( + ["import runhouse as rh", "print(rh.here)"], stream_logs=False + ) assert return_codes[0][0] == 0 on_cluster_config = ast.literal_eval(return_codes[0][1]) cluster_config = cluster.config() - on_cluster_config.pop("creds", None) - cluster_config.pop("creds", None) + keys_to_skip = ["creds", "client_port", "server_host"] + on_cluster_config = remove_config_keys(on_cluster_config, keys_to_skip) + cluster_config = remove_config_keys(cluster_config, keys_to_skip) + + if cluster_config.get("stable_internal_external_ips", False): + cluster_ips = cluster_config.pop("stable_internal_external_ips", None)[0] + on_cluster_ips = on_cluster_config.pop( + "stable_internal_external_ips", None + )[0] + assert tuple(cluster_ips) == tuple(on_cluster_ips) assert on_cluster_config == cluster_config diff --git a/tests/test_resources/test_resource.py b/tests/test_resources/test_resource.py index 281ddeed7..f767881be 100644 --- a/tests/test_resources/test_resource.py +++ b/tests/test_resources/test_resource.py @@ -4,7 +4,7 @@ import runhouse as rh from tests.conftest import init_args -from tests.utils import friend_account +from tests.utils import friend_account, remove_config_keys def load_shared_resource_config(resource_class_name, address): @@ -74,10 +74,18 @@ def test_from_config(self, resource): def test_save_and_load(self, saved_resource): # Test loading from name loaded_resource = saved_resource.__class__.from_name(saved_resource.rns_address) - assert loaded_resource.config() == saved_resource.config() - # Changing the name doesn't work for OnDemandCluster, because the name won't match the local sky db + if isinstance(saved_resource, rh.OnDemandCluster): + loaded_resource_config = remove_config_keys( + loaded_resource.config(), ["stable_internal_external_ips"] + ) + saved_resource_config = remove_config_keys( + saved_resource.config(), ["stable_internal_external_ips"] + ) + assert loaded_resource_config == saved_resource_config + # Changing the name doesn't work for OnDemandCluster, because the name won't match the local sky db return + assert loaded_resource.config() == saved_resource.config() # Do everything inside a try/finally so we don't leave resources behind if the test fails try: diff --git a/tests/utils.py b/tests/utils.py index 653043459..316be9cf6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -72,3 +72,9 @@ def test_env(logged_in=False): else False, name="base_env", ) + + +def remove_config_keys(config, keys_to_skip): + for key in keys_to_skip: + config.pop(key, None) + return config