Please add support for torch.tensor_split #2226

Open
mallman opened this issue May 21, 2024 · 3 comments
Labels
missing layer type (unable to convert a layer type from the relevant framework), PyTorch (traced), triaged (reviewed and examined; release has been assigned if applicable)

Comments


mallman commented May 21, 2024

  • Name of layer type: torch.tensor_split
  • Is this a PyTorch or a TensorFlow layer type: PyTorch
  • Your version of coremltools: 7.2
  • Your version of PyTorch/TensorFlow: 2.3.0
  • Impact of supporting this layer type. Why is adding support for this layer type important? Is it necessary to support a popular model or use case?

This layer/op is used by EVA-02, a model for image classification, segmentation and object detection. Personally, I'm interested in using it for image classification in a Mac app.

As of this writing (May 21st, 2024), various sizes of pre-trained EVA and EVA-02 models dominate the leaderboard for image classification on ImageNet 1k among the models curated by the PyTorch Image Models (timm) Hugging Face org. See https://huggingface.co/collections/timm/timm-top-20-imagenet-1k-models-655d78909af37bae32381f61

FYI, it looks like this is (essentially) the same op as tf.split from TensorFlow.

mallman added the "missing layer type" label May 21, 2024

mallman commented May 21, 2024

Oh, and here's an example of a failing conversion. This is from a script I've written for converting timm models:

import coremltools as ct
import timm
import torch

model_name = "eva02_tiny_patch14_224.mim_in22k"
print(f"Creating model {model_name}")
timm_model = timm.create_model(
  model_name,
  pretrained=True,
  scriptable=False,
  exportable=True)

model = torch.nn.Sequential(
  timm_model,
  torch.nn.Softmax(1)
).eval()

input_size = timm_model.default_cfg.get("input_size")
input_shape = (1,) + input_size

print("Tracing model")
example_input = torch.randn(input_shape)
jit_model = torch.jit.trace(model, example_input)

labels_filename = "imagenet21k_wordnet_lemmas.txt"

with open(labels_filename, "r") as labels_file:
  labels = [line.strip() for line in labels_file.readlines()]

classifier_config = ct.ClassifierConfig(labels)

print("Converting model")
# Scale and bias calculations taken from Core ML Tools documentation on
# preprocessing for PyTorch
mean = list(timm_model.default_cfg.get("mean"))
std = list(timm_model.default_cfg.get("std"))
import statistics
mean_std = statistics.mean(std)
scale = 1 / (mean_std * 255)
bias = [-m / s for m, s in zip(mean, std)]
input_type = ct.ImageType(
      name="image",
      shape=input_shape,
      scale=scale,
      bias=bias)

coreml_model = ct.convert(
  jit_model,
  convert_to="mlprogram",
  inputs=[input_type],
  classifier_config=classifier_config,
  skip_model_load=True
)

coreml_model.user_defined_metadata["com.apple.coreml.model.preview.type"] = "imageClassifier"

coreml_model_file_name = f"{model_name}.mlpackage"
print(f"Saving model to {coreml_model_file_name}")

coreml_model.save(coreml_model_file_name)
print("Done!")

I believe a pip install with the timm, torch and coremltools packages will give you the right environment for running this.

You will also need a labels file, imagenet21k_wordnet_lemmas.txt, in your working directory. I'm attaching that file.
imagenet21k_wordnet_lemmas.txt

TobyRoseman (Collaborator) commented

Here is a more concise way to reproduce the issue:

import torch
import coremltools as ct

class M(torch.nn.Module):
    def forward(self, x):
        return torch.tensor_split(x, 3)

x = torch.arange(8)
traced_model = torch.jit.trace(M(), x)
ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)])

I think we should be able to use the split MIL ops at least for simple cases.
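To make the "simple cases" concrete, here is a minimal sketch of how the chunk sizes produced by torch.tensor_split could be computed ahead of time, following the sizing rule in the PyTorch documentation. The helper name tensor_split_sizes is hypothetical and is not part of coremltools or PyTorch. It assumes the split dimension has a static length and, as far as I understand the MIL split op, that explicit per-chunk sizes can be passed to it; that last part is an assumption, not something confirmed in this thread.

```python
from typing import List, Sequence, Union


def tensor_split_sizes(
    dim_len: int, indices_or_sections: Union[int, Sequence[int]]
) -> List[int]:
    """Hypothetical helper: chunk sizes torch.tensor_split would produce.

    Only handles a static dimension length and, for the indices form,
    non-negative, non-decreasing indices.
    """
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        if sections <= 0:
            raise ValueError("number of sections must be positive")
        # The first (dim_len % sections) chunks get one extra element.
        base, extra = divmod(dim_len, sections)
        return [base + 1] * extra + [base] * (sections - extra)

    # Indices form: each index is a boundary between consecutive chunks.
    sizes = []
    prev = 0
    for idx in indices_or_sections:
        cur = min(idx, dim_len)  # boundaries past the end are clamped
        if idx < 0 or cur < prev:
            raise ValueError("only non-negative, non-decreasing indices are handled")
        sizes.append(cur - prev)
        prev = cur
    sizes.append(dim_len - prev)  # the last chunk takes whatever remains
    return sizes


# Assumption: these lists could then be passed as explicit split sizes
# to a MIL split op when the converter sees tensor_split.
print(tensor_split_sizes(8, 3))       # sections form, as in the repro above
print(tensor_split_sizes(8, [1, 6]))  # indices form
```

Negative or out-of-order indices, and a tensor-valued indices_or_sections argument, are left out on purpose, since the resulting chunks may not partition the dimension.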

TobyRoseman added the "triaged" and "PyTorch (traced)" labels May 22, 2024
teelrabbit (Contributor) commented

It looks like this can be worked around by combining torch.split with torch.unbind.

An example is shown below and in this paste (https://pastes.dev/kkaPViedJ7):

import torch
import coremltools as ct

class M(torch.nn.Module):
    def forward(self, x):
        splits = torch.split(x, x.size(0) // 3)
        return torch.unbind(torch.stack(splits))

x = torch.arange(9)  
traced_model = torch.jit.trace(M(), x)
ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)])
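One caveat with the snippet above: it matches torch.tensor_split(x, 3) only when the length is divisible by 3. For a length of 8, x.size(0) // 3 is 2, so torch.split returns four chunks of two elements, whereas tensor_split returns chunks of sizes 3, 3 and 2. A possible variant, sketched below, passes explicit chunk sizes to torch.split instead. The module name SplitWorkaround is made up for illustration, the input shape is assumed to be fixed, and I have not verified that this converts with coremltools 7.2.

```python
import torch


class SplitWorkaround(torch.nn.Module):
    """Illustrative stand-in for torch.tensor_split(x, sections, dim)."""

    def __init__(self, sections: int, dim: int = 0):
        super().__init__()
        self.sections = sections
        self.dim = dim

    def forward(self, x):
        length = x.size(self.dim)
        # Same rule as tensor_split: the first (length % sections) chunks
        # get one extra element, the rest get length // sections.
        base, extra = divmod(length, self.sections)
        sizes = [base + 1] * extra + [base] * (self.sections - extra)
        # torch.split accepts a list of per-chunk sizes.
        return torch.split(x, sizes, dim=self.dim)


x = torch.arange(8)
chunks = SplitWorkaround(3)(x)
reference = torch.tensor_split(x, 3)
print([c.tolist() for c in chunks])
print([r.tolist() for r in reference])
```

Under torch.jit.trace the sizes are computed from the example input, so they would be baked into the traced graph for that shape.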
