Skip to content

Commit

Permalink
Enable multiple neuron views per notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jessevig committed Apr 2, 2022
1 parent c08bb4f commit 9c94349
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 74 deletions.
27 changes: 14 additions & 13 deletions bertviz/neuron_view.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* 01/16/21 Jesse Vig Dark mode
* 02/06/21 Jesse Vig Move require config from separate jupyter notebook step
* 03/23/22 Daniel SC Update requirement URLs for d3 and jQuery (source of bug not allowing end result to be displayed on browsers)
* 04/02/22 Jesse Vig Enable multiple neuron views per notebook
**/

require.config({
Expand Down Expand Up @@ -83,16 +84,16 @@ requirejs(['jquery', 'd3'],
var keys = attnData.keys[config.layer][config.head];
var att = attnData.attn[config.layer][config.head];

$("#bertviz #vis").empty();
$(`#${config.rootDivId} #vis`).empty();
var height = config.initialTextLength * BOXHEIGHT + HEIGHT_PADDING;
var svg = d3.select("#bertviz #vis")
var svg = d3.select(`#${config.rootDivId} #vis`)
.append('svg')
.attr("width", "100%")
.attr("height", height + "px");

d3.select("#bertviz")
d3.select(`#${config.rootDivId}`)
.style("background-color", getColor('background'));
d3.selectAll("#bertviz .dropdown-label")
d3.selectAll(`#${config.rootDivId} .dropdown-label`)
.style("color", getColor('dropdown'))

renderVisExpanded(svg, leftText, rightText, queries, keys);
Expand Down Expand Up @@ -942,11 +943,11 @@ requirejs(['jquery', 'd3'],

function showCollapsed() {
if (config.index != null) {
var svg = d3.select("#bertviz #vis");
var svg = d3.select(`#${config.rootDivId} #vis`);
highlightSelection(svg, config.index);
}
d3.select("#bertviz #expanded").attr("visibility", "hidden");
d3.select("#bertviz #collapsed").attr("visibility", "visible");
d3.select(`#${config.rootDivId} #expanded`).attr("visibility", "hidden");
d3.select(`#${config.rootDivId} #collapsed`).attr("visibility", "visible");
}

function showExpanded() {
Expand All @@ -955,8 +956,8 @@ requirejs(['jquery', 'd3'],
highlightSelection(svg, config.index);
showComputation(svg, config.index);
}
d3.select("#bertviz #expanded").attr("visibility", "visible");
d3.select("#bertviz #collapsed").attr("visibility", "hidden")
d3.select(`#${config.rootDivId} #expanded`).attr("visibility", "visible");
d3.select(`#${config.rootDivId} #collapsed`).attr("visibility", "hidden")
}

function getColor(name) {
Expand All @@ -977,9 +978,9 @@ requirejs(['jquery', 'd3'],
config.mode = params['display_mode'];
config.layer = (params['layer'] == null ? 0 : params['layer'])
config.head = (params['head'] == null ? 0 : params['head'])
config.rootDivId = params['root_div_id'];


const layerSelect = $("#bertviz #layer");
const layerSelect = $(`#${config.rootDivId} #layer`);
layerSelect.empty();
for (var i = 0; i < config.nLayers; i++) {
layerSelect.append($("<option />").val(i).text(i));
Expand All @@ -990,7 +991,7 @@ requirejs(['jquery', 'd3'],
render();
});

const headSelect = $("#bertviz #att_head");
const headSelect = $(`#${config.rootDivId} #att_head`);
headSelect.empty();
for (var i = 0; i < config.nHeads; i++) {
headSelect.append($("<option />").val(i).text(i));
Expand All @@ -1001,7 +1002,7 @@ requirejs(['jquery', 'd3'],
render();
});

$("#bertviz #filter").on('change', function (e) {
$(`#${config.rootDivId} #filter`).on('change', function (e) {
config.filter = e.currentTarget.value;
render();
});
Expand Down
101 changes: 40 additions & 61 deletions bertviz/neuron_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,87 +34,66 @@
from IPython.core.display import display, HTML, Javascript


def show(model, model_type, tokenizer, sentence_a, sentence_b=None, display_mode='dark', layer=None, head=None, html_action='view'):

# Generate unique div id to enable multiple visualizations in one notebook
def show(model, model_type, tokenizer, sentence_a, sentence_b=None, display_mode='dark', layer=None, head=None,
html_action='view'):

if sentence_b:
vis_html = """
<div id="bertviz" style="padding:8px;font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
<span style="user-select:none">
<span class="dropdown-label">Layer: </span><select id="layer"></select>
<span class="dropdown-label">Head: </span><select id="att_head"></select>
attn_dropdown = """
<span class="dropdown-label">Attention: </span><select id="filter">
<option value="all">All</option>
<option value="aa">Sentence A -> Sentence A</option>
<option value="ab">Sentence A -> Sentence B</option>
<option value="ba">Sentence B -> Sentence A</option>
<option value="bb">Sentence B -> Sentence B</option>
</select>
"""
else:
attn_dropdown = ""

# Generate unique div id to enable multiple visualizations in one notebook
vis_id = 'bertviz-%s' % (uuid.uuid4().hex)
vis_html = f"""
<div id={vis_id} style="padding:8px;font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
<span style="user-select:none">
<span class="dropdown-label">Layer: </span><select id="layer"></select>
<span class="dropdown-label">Head: </span> <select id="att_head"></select>
{attn_dropdown}
</span>
<div id='vis'></div>
</div>
"""
"""

__location__ = os.path.realpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
attn_data = get_attention(model, model_type, tokenizer, sentence_a, sentence_b, include_queries_and_keys=True)
if model_type == 'gpt2':
bidirectional = False
else:
vis_html = """
<div id="bertviz" style="padding:8px;font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
<span style="user-select:none">
<span class="dropdown-label">Layer: </span><select id="layer"></select>
<span class="dropdown-label">Head: </span> <select id="att_head"></select>
</span>
<div id='vis'></div>
</div>
"""
bidirectional = True
params = {
'attention': attn_data,
'default_filter': "all",
'root_div_id': vis_id,
'bidirectional': bidirectional,
'display_mode': display_mode,
'layer': layer,
'head': head
}
vis_js = open(os.path.join(__location__, 'neuron_view.js')).read()
html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')
html2 = HTML(vis_html)

# require.js must be imported for Colab or JupyterLab:
if html_action == 'view':
display(HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>'))
display(HTML(vis_html))
__location__ = os.path.realpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
attn_data = get_attention(model, model_type, tokenizer, sentence_a, sentence_b, include_queries_and_keys=True)
if model_type == 'gpt2':
bidirectional = False
else:
bidirectional = True
params = {
'attention': attn_data,
'default_filter': "all",
'bidirectional': bidirectional,
'display_mode': display_mode,
'layer': layer,
'head': head
}
vis_js = open(os.path.join(__location__, 'neuron_view.js')).read()
display(html1)
display(html2)
display(Javascript('window.bertviz_params = %s' % json.dumps(params)))
display(Javascript(vis_js))

elif html_action == 'return':
html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')
html2 = HTML(vis_html)
__location__ = os.path.realpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
attn_data = get_attention(model, model_type, tokenizer, sentence_a, sentence_b, include_queries_and_keys=True)
if model_type == 'gpt2':
bidirectional = False
else:
bidirectional = True
params = {
'attention': attn_data,
'default_filter': "all",
'bidirectional': bidirectional,
'display_mode': display_mode,
'layer': layer,
'head': head
}
vis_js = open(os.path.join(__location__, 'neuron_view.js')).read()

script1 = '\n<script type="text/javascript">\n' + Javascript('window.bertviz_params = %s' % json.dumps(params)).data + '\n</script>\n'
script2= '\n<script type="text/javascript">\n' + Javascript(vis_js).data + '\n</script>\n'

script1 = '\n<script type="text/javascript">\n' + Javascript(
'window.bertviz_params = %s' % json.dumps(params)).data + '\n</script>\n'
script2 = '\n<script type="text/javascript">\n' + Javascript(vis_js).data + '\n</script>\n'
neuron_html = HTML(html1.data + html2.data + script1 + script2)
return neuron_html

else:
raise ValueError("'html_action' parameter must be 'view' or 'return")

Expand Down

0 comments on commit 9c94349

Please sign in to comment.