Update README.md Zero Shot Code block (#137)
* Update README.md

Changing `facebook/bart-large-mnli` to `cross-encoder/nli-deberta-base` makes the Zero Shot classification explainer work. @cdpierse

* Add files via upload

* Update README.md
Owaiskhan9654 committed Aug 30, 2023
1 parent 30dfe0a commit 7c2f938
Showing 2 changed files with 75 additions and 75 deletions.
150 changes: 75 additions & 75 deletions README.md
This explainer allows for attributions to be calculated for zero-shot classification models.

Let's start by initializing a Transformers sequence classification model and tokenizer trained specifically on an NLI task, and passing it to the `ZeroShotClassificationExplainer`.

For this example we are using `cross-encoder/nli-deberta-base`, a checkpoint for a deberta-base model trained on the
[SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://huggingface.co/datasets/multi_nli) datasets. This model predicts whether a sentence pair is an entailment, neutral, or a contradiction; for zero-shot classification we only look at the entailment label.
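Conceptually, zero-shot classification via NLI pairs the input text with one hypothesis per candidate label (e.g. "This example is about finance.") and compares the model's entailment scores across labels. A minimal sketch of that final comparison step, assuming per-label entailment logits are already available (the `zero_shot_scores` helper and the logit values below are illustrative, not part of transformers-interpret):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def zero_shot_scores(entailment_logits):
    """Normalize per-label entailment logits into a probability per label.

    entailment_logits: dict mapping label -> the entailment logit an NLI model
    produced for the pair (text, "This example is about <label>.").
    """
    labels = list(entailment_logits)
    probs = softmax([entailment_logits[label] for label in labels])
    return dict(zip(labels, probs))

# Hypothetical logits for illustration only (not real model outputs).
scores = zero_shot_scores({"finance": 0.3, "technology": 2.1, "sports": -0.5})
predicted_label = max(scores, key=scores.get)  # -> "technology"
```

Because the softmax is taken across labels rather than across the NLI classes, the label whose hypothesis is most strongly entailed wins.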

Notice that we pass our own custom labels `["finance", "technology", "sports"]` to the class instance. Any number of labels can be passed, including as few as one. Whichever label scores highest for entailment can be accessed via `predicted_label`; however, the attributions themselves are calculated for every label. If you want to see the attributions for a particular label only, it is recommended to pass in just that one label, so the attributions are guaranteed to be calculated with respect to it.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/nli-deberta-base")

model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/nli-deberta-base")


zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)

word_attributions = zero_shot_explainer(
    "Today apple released the new Macbook showing off a range of new features found in the proprietary silicon chip computer.",
    labels=["finance", "technology", "sports"],
)
```

Which will return the following dict of attribution tuple lists for each label:

```python
>>> word_attributions
{'finance': [('[CLS]', 0.0),
('Today', 0.144761198095125),
('apple', 0.05008283286211926),
('released', -0.29790757134109724),
('the', -0.09931162582050683),
('new', -0.151252730475885),
('Mac', 0.19431968978659608),
('book', 0.059431761386793486),
('showing', -0.30754747734942633),
('off', 0.0329034397830471),
('a', 0.04198035048519715),
('range', -0.00413947940202566),
('of', 0.7135069733740484),
('new', 0.2294990755900286),
('features', -0.1523457769188503),
('found', -0.016804346228170633),
('in', 0.1185751939327566),
('the', -0.06990875734316043),
('proprietary', 0.16339657649559983),
('silicon', 0.20461302470245252),
('chip', 0.033304742383885574),
('computer', -0.058821677910955064),
('.', -0.19741292299059068)],
'technology': [('[CLS]', 0.0),
('Today', 0.1261355373492264),
('apple', -0.06735584800073911),
('released', -0.37758515332894504),
('the', -0.16300368060788886),
('new', -0.1698884472100767),
('Mac', 0.41505959302727347),
('book', 0.321276307285395),
('showing', -0.2765988420377037),
('off', 0.19388699112601515),
('a', -0.044676708673846766),
('range', 0.05333370699507288),
('of', 0.3654053610507722),
('new', 0.3143976769670845),
('features', 0.2108588137592185),
('found', 0.004676960337191403),
('in', 0.008026783104605233),
('the', -0.09961358108721637),
('proprietary', 0.18816708356062326),
('silicon', 0.13322691438800874),
('chip', 0.015141805082331294),
('computer', -0.1321895049108681),
('.', -0.17152401596638975)],
'sports': [('[CLS]', 0.0),
('Today', 0.11751821789941418),
('apple', -0.024552367058659215),
('released', -0.44706064525430567),
('the', -0.10163968191086448),
('new', -0.18590036257614642),
('Mac', 0.0021649499897370725),
('book', 0.009141161101058446),
('showing', -0.3073791152936541),
('off', 0.0711051596941137),
('a', 0.04153236257439005),
('range', 0.01598478741712663),
('of', 0.6632118834641558),
('new', 0.2684728052423898),
('features', -0.10249856013919137),
('found', -0.032459999377294144),
('in', 0.11078761617308391),
('the', -0.020530085754695244),
('proprietary', 0.17968209761431955),
('silicon', 0.19997909769476027),
('chip', 0.04447720580439545),
('computer', 0.018515748463790047),
('.', -0.1686603393466192)]}
```
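Since the attributions are plain `(token, score)` tuples, they are easy to post-process. For instance, to pull out the tokens that contributed most (positively or negatively) to a given label — a small sketch, where the `top_attributions` helper is not part of the library and the sample dict is a trimmed subset of the output above:

```python
def top_attributions(word_attributions, label, k=3):
    """Return the k tokens with the largest absolute attribution for a label."""
    pairs = word_attributions[label]
    return sorted(pairs, key=lambda p: abs(p[1]), reverse=True)[:k]

# Trimmed subset of the attribution output shown above.
word_attributions = {
    "technology": [
        ("[CLS]", 0.0),
        ("Mac", 0.41505959302727347),
        ("book", 0.321276307285395),
        ("released", -0.37758515332894504),
    ],
}

top = top_attributions(word_attributions, "technology", k=2)
# top == [("Mac", 0.415...), ("released", -0.377...)]
```

Sorting by absolute value keeps strong negative contributors visible alongside positive ones.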

The highest-scoring label is available via the explainer's `predicted_label` attribute.

For the `ZeroShotClassificationExplainer`, the `visualize()` method returns a table of attributions for every label:
zero_shot_explainer.visualize("zero_shot.html")
```

<a href="https://github.com/cdpierse/transformers-interpret/blob/master/images/zero_shot_example2.png">
<img src="https://github.com/cdpierse/transformers-interpret/blob/master/images/zero_shot_example2.png" width="100%" height="100%" align="center" />
</a>

</details>
Expand Down
Binary file added images/zero_shot_example2.png
