
Pairwise postprocessing validation is very slow - possible optimizations #599

Closed
korotaS opened this issue Jun 17, 2024 · 21 comments

@korotaS
Contributor

korotaS commented Jun 17, 2024

Hi! I've been trying to train a Reranking pairwise model (using this guide and OML version 3.1.0) and it seems to train ok, but validation takes too much time: one training epoch takes about 5 minutes, while one validation cycle takes about 4 hours (the dataset is pretty big). I used cProfile to profile a part of the training; here are the top 100 slowest functions: train_rerank_cprofile.txt. The whole training took about 63 hours (229k seconds), with validation every 10 epochs.

There are two things I've noticed that can definitely be optimized:

First - look at the line ...48866.403...oml/datasets/images.py:281(get_query_ids) and, below it, ...48784.334...images.py:284(get_gallery_ids) (line numbers might not match because I made some other minor changes): the two methods get_query_ids and get_gallery_ids that are called in oml/retrieval/postprocessors/pairwise.py:121-122 take 97k seconds. query_ids and gallery_ids don't change during training, so we can calculate them once inside ImageQueryGalleryLabeledDataset.__init__ and save about 97k seconds of training time in my case.
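A minimal sketch of what I mean (the class and column names here are hypothetical, not the actual OML code): compute the id tensors once in __init__ and return the cached values afterwards.

import pandas as pd
import torch


class CachedQueryGalleryDataset:  # hypothetical sketch, the real dataset has many more fields
    def __init__(self, df: pd.DataFrame):
        self.df = df
        # computed once: these masks never change during training
        self._query_ids = torch.nonzero(torch.tensor(df["is_query"].values)).squeeze(1)
        self._gallery_ids = torch.nonzero(torch.tensor(df["is_gallery"].values)).squeeze(1)

    def get_query_ids(self) -> torch.Tensor:
        return self._query_ids  # no recomputation on every call

    def get_gallery_ids(self) -> torch.Tensor:
        return self._gallery_ids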

Second - it is not obvious from the cProfile txt file, but from the ClearML logs it seemed interesting that the PL validation step takes about 40 minutes, although the validation step looks like this:

def validation_step(self, batch: Dict[str, Any], batch_idx: int, *_: Any) -> Dict[str, Any]:
    # We simply accumulate batches here since we apply postprocessor during metrics calculation
    return batch

We don't use the images that are read from disk directly after this step (we use only embeddings); however, we load the images again in the PairwiseDataset:
self.input_tensors_key_1: self.base_dataset[i1][key],
self.input_tensors_key_2: self.base_dataset[i2][key],

I think we can pass some kind of load_images parameter to the dataset's __init__ method and load images only when needed, for example in the training dataset and in the pairs dataset, but not in the validation dataset (or, as an alternative, make a wrapper around the __getitem__ method and pass this parameter directly there). It could save 40 minutes in each validation cycle in my case.
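For illustration, a rough sketch of how such a flag could look (all names here are hypothetical, this is not the current OML API):

from typing import Any, Callable, Dict, List

import torch
from PIL import Image


class LazyImagesDataset:  # hypothetical sketch of the load_images idea
    def __init__(self, paths: List[str], embeddings: torch.Tensor, transforms: Callable, load_images: bool = True):
        self.paths = paths
        self.embeddings = embeddings
        self.transforms = transforms
        self.load_images = load_images

    def __getitem__(self, i: int) -> Dict[str, Any]:
        item = {"embeddings": self.embeddings[i], "paths": self.paths[i]}
        if self.load_images:
            # only the training / pairs datasets pay for reading and transforming images;
            # the validation dataset, which only accumulates embeddings, skips this branch
            item["input_tensors"] = self.transforms(Image.open(self.paths[i]))
        return item

    def __len__(self) -> int:
        return len(self.paths)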

There were some other issues during training but they aren't related to optimization so I will create another issue.

@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

Hey, @korotaS, thank you for your interest in OML.

A general note on post-processing: it's expected that validation takes longer than a training "epoch":

  • first, a training "epoch" is not a full epoch if you use samplers like BalanceSampler, because it iterates over all labels (not over all instances) during the epoch
  • second, we have more inferences in validation, and every single one of them is slower because images are concatenated
  • thus, I recommend performing X training cycles, then 1 validation cycle (which you already did with X = 10, but X may be even bigger)

By the way, what is the dataset size and what task are you solving?

@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

get_query_ids, get_gallery_ids

Agreed, even outside the post-processing context it would be a good optimization.

But are you really sure that we spent 26h out of 63h doing this operation? I'm very skeptical... You can make this optimization locally first and check whether the numbers are really like this. Anyway, it may be a good contribution, if you would like to contribute.

@AlekseySh AlekseySh self-assigned this Jun 17, 2024
@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

As for the second point: you are right, there is duplicated image loading when using Lightning.

Basically, we need the first loader (built on LabeledDataset) in the validation step only to deliver and accumulate embeddings. So, if we want to remove the duplicated loading, we should do it here. The second loader (built on PairsDataset) is okay, because it indeed needs to load images.

I would say there are a few root problems.

  • First, Lightning is not flexible enough, so we need to provide it with some dataset and loader to perform its validation steps.
  • As a result, we used ImageDataset. It delivers embeddings from the extra_data container, but it also causes extra image loading.

If you take a look at the pure Python examples, you will see that there is no duplicated image reading in the validation. (Although we create val_dataset, we don't iterate over all __getitem__ calls to build the original retrieval results; we simply use pre-computed embeddings.)

See the next comment for the possible solution.

@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

Ideally, we should use EmbeddingsQueryGalleryLabeledDataset for the purpose of delivering and accumulating embeddings. We can move the existing datasets from tests/test_integrations/utils.py to the library core and update them a little bit.

The update is needed in order not to lose the visualisation functionality which happens in validation. (When we used the image dataset we had a visualise() method.) We could do something like this:


from functools import partial
from typing import Callable, Optional


def vis(dataset, i):
    # this dataset is used only to serve visualisation, there is no loop over __getitem__
    return dataset.visualise(i)


class EmbeddingsQueryGalleryLabeledDataset:

    def __init__(self, ..., visualisation_fn: Optional[Callable] = vis):
        self.vis = visualisation_fn

    def visualise(self, i):
        return self.vis(i)


# the image dataset is needed to compute embeddings and to visualise some items (usually dozens)
paths = ...
img_dataset = ImgDataset(paths)
embeddings = inference(img_dataset, ...)
vis = partial(vis, dataset=img_dataset)

# for the loop we use the dataset of embeddings:
emb_dataset = EmbDataset(embeddings, vis=vis)

If you wish, you can contribute. The plan is:

  • Move and update embeddings datasets
  • Update the get_loaders_with_embeddings accordingly

@korotaS
Contributor Author

korotaS commented Jun 17, 2024

Thanks for the quick reply!

By the way, what is the dataset size and what task are you solving?

I have a trained embedder with cmc@1 at about 0.96, and I want to train a reranker on "hard" items to further improve the metrics. This training run was a kind of PoC (I didn't pick "hard" items, I used all of them just to see the training process); the query size was about 37k and the gallery size was about 110k.

But are you really sure that we spent 26h out of 63h doing this operation?

I am not 100% sure, but I tend to believe the numbers that cProfile is outputting, and it says that out of the 229k seconds of the whole training process, 48k+48k=96k seconds were taken by get_query_ids and get_gallery_ids. I will make this optimization locally and then check whether the time of those methods decreased to ~0 seconds.

@AlekseySh
Contributor

@korotaS got it, what is the domain? retail?

I will make this optimization locally and then check whether the time of those methods decreased to ~0 seconds.

yep, it should be quick and informative

see my other comments above and have a nice day :)

@korotaS
Contributor Author

korotaS commented Jun 17, 2024

The domain is indeed retail. In fact, we were at the same meeting with Epoch8 last Tuesday and discussed this same problem after your presentation of the new OML features 😁

@AlekseySh
Contributor

I thought so, nice to meet you here!

Don't hesitate to join us, especially if you are going to contribute: https://t.me/+lqsKu2af8xcyMjEy

@korotaS
Contributor Author

korotaS commented Jun 17, 2024

As for your proposed solution with the embeddings dataset - maybe I misunderstood something, but I think it won't work that way. If we create an EmbeddingsQueryGalleryLabeledDataset as a validation dataset, it will break at the pairwise_inference function because it won't know how to load an image. And the problem is that internally in the pairwise_inference function the PairDataset is created as a wrapper around the validation dataset, so one depends on the other.

We can use EmbeddingsQueryGalleryLabeledDataset as the validation dataset, but at the same time PairDataset needs to be created on top of some dataset that can load images. Maybe we can instantiate an ImageQueryGalleryLabeledDataset, wrap it in PairDataset, pass it to PairwiseReranker.__init__ and then pass this dataset to the pairwise_inference function (however, we will need to update pair_ids every time).

@AlekseySh
Contributor

If we create an EmbeddingsQueryGalleryLabeledDataset as a validation dataset, it will break at the pairwise_inference function because it won't know how to load an image.

Oh, you are right. So, the problem is that we want to use the same dataset first to deliver embeddings and second to serve as the base for PairsDataset to load pairs of images.

@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

Your solution significantly changes the signatures, but we have another easy option to start with. Our datasets support a cache for reading images (so, we cache bytes). Unfortunately, the cache is not parsed in the postprocessing config the way it is in the main training config. I can easily add it in a few lines of code. It's not a full solution (we still need some time to decode and stack images), but it should speed up the process a lot.
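For reference, byte-level caching of image reading looks roughly like this (a minimal sketch with a hypothetical cache size, not the exact OML code):

from functools import lru_cache
from io import BytesIO

from PIL import Image


@lru_cache(maxsize=100_000)  # hypothetical cache_size from the config; only raw bytes are cached
def read_image_bytes(path: str) -> bytes:
    with open(path, "rb") as f:
        return f.read()


def load_image(path: str) -> Image.Image:
    # decoding (and the transforms applied later) still run on every call, only the disk read is cached
    return Image.open(BytesIO(read_image_bytes(path)))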

P.S. I implemented and merged the proposed changes here: #604. I will publish it to PyPI soon (after the tests are green). Could you please check how much it helps?

@AlekseySh
Contributor

As for the first issue, the solution is already in the main branch and will be released to PyPI soon. Thank you for the report!

@AlekseySh AlekseySh added this to To do in OML-planning via automation Jun 17, 2024
@korotaS
Contributor Author

korotaS commented Jun 18, 2024

Could you please check how much it helps?

I've fetched your updates and set cache_size to a very big number (roughly equal to the val dataset size), so that ~all of the PairsDataset.__getitem__ calls in the val loop are cached, but the validation loop itself (without PairwiseReranker.process) didn't speed up... Maybe the cache_size was too big and it can't work as expected with such a big number?

However, I looked in the profiling logs and I don't see lru_cache or anything similar, so maybe I messed up somewhere.

@korotaS
Contributor Author

korotaS commented Jun 18, 2024

Or maybe the bottleneck is not in reading images but in transforming them:
[screenshot of the profiling output]

As you can see, of the 2726 seconds that __getitem__ takes, 2600 are taken by the transforms.

@AlekseySh
Contributor

Correct me if I'm wrong, but a Least Recently Used cache will not help you even if its size is just a bit smaller than the dataset size. If your dataset's size is 1000 and the cache size is 999, then when you start the second epoch (which you do without shuffle) the very first element is the oldest one and has already been evicted from the cache.

Our situation is a bit more complicated, but anyway, could you try to set the cache size bigger than the dataset size?
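To illustrate the eviction effect, here is a tiny self-contained experiment (not OML code): with 1000 items and a cache of 999, sequential passes never hit the cache.

from functools import lru_cache

calls = 0


@lru_cache(maxsize=999)
def load(i: int) -> int:
    global calls
    calls += 1  # count real "disk reads"
    return i


for _ in range(2):  # two sequential "epochs" over 1000 items, no shuffle
    for i in range(1000):
        load(i)

print(calls)  # 2000: item i is evicted just before it is needed again, so every access in the second epoch is a miss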

@AlekseySh
Contributor

Anyway, @korotaS, what percentage of speedup are we talking about if we avoid loading the images the second time? (You can replace the __getitem__ method in a hacky way to measure it, or use your load_images parameter.)

@korotaS
Contributor Author

korotaS commented Jun 18, 2024

When I ran validation as is, the embedding accumulation part (before pairwise_inference) took about 40 minutes out of ~2h20m of full validation. When I used load_images and loaded images only in PairsDataset, the 40 minutes decreased to 25 seconds, so the speedup is about 28% (40m / 140m).

I will now try increasing cache_size to be bigger than the val dataset size (if the val dataset size is 100, then 100 images will be loaded during embedding accumulation and the same 100 images will be loaded again in pairwise_inference, so we need cache_size to be bigger than the val dataset size) and report on any speedup.

@AlekseySh
Contributor

AlekseySh commented Jun 18, 2024

@korotaS got it!
So, if we can save ~28% of validation time, and validation takes 25% of the full training cycle (val + train), it means we can save 0.25 * 0.28 ≈ 7% of the total time. Not that much. Let's see if the bigger cache helps.

The profit from fixing the ids handling was much higher :)

Yep, in that case a cache size of 101 seems enough.

@AlekseySh
Contributor

@korotaS what should we do with this issue? Have you tried a bigger cache size in order to speed up the process?

@korotaS
Contributor Author

korotaS commented Jul 4, 2024

I have tested the caching functionality; unfortunately I don't have the numbers now, but it didn't show a good speedup. I think that is because of the transforms, as I said earlier:

Or maybe the bottleneck is not in reading images but in transforming them: [screenshot of the profiling output]

As you can see, of the 2726 seconds that __getitem__ takes, 2600 are taken by the transforms.

So loading images is pretty fast, but transforming them takes much longer.
My approach with load_images isn't perfect, and it saves only about 7% of training time, as you calculated, so maybe this optimization is not really necessary.

@AlekseySh
Contributor

Agreed, sounds like too many changes for 7%.

OML-planning automation moved this from To do to Done Jul 4, 2024