04a6333
chenzhiqing 3 years ago
4 changed file(s) with 1117 addition(s) and 24 deletion(s).
@@ -2,15 +2,15 @@
 {
 "cell_type": "code",
 "execution_count": 1,
-"id": "e468c4d7",
+"id": "3f8becd7",
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"2022-06-22 10:15:14.085329: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
-"2022-06-22 10:15:14.085364: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
+"2022-06-22 16:01:19.315518: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
+"2022-06-22 16:01:19.315554: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
 ]
 }
 ],
@@ -27,13 +27,13 @@
 {
 "cell_type": "code",
 "execution_count": 2,
-"id": "b88c91d8",
+"id": "1e193909",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "2aac41dcbabc4b488d73cbf4efc48ae5",
+"model_id": "d64bc197cae04268b5496da5d322da16",
 "version_major": 2,
 "version_minor": 0
 },
@@ -54,7 +54,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "4ac0b9f1c76841d6a01ecbcec1151547",
+"model_id": "fcab18055e1b47038d1102f8e250b25b",
 "version_major": 2,
 "version_minor": 0
 },
@@ -75,7 +75,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "ede9398befdf4f4aa1da5e73ada4f558",
+"model_id": "b00264e7496f4f0cbae704a96d5f52d1",
 "version_major": 2,
 "version_minor": 0
 },
@@ -96,7 +96,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "1253d25503bc486dbae2516355c1a7d7",
+"model_id": "94212645b9954e439ce10238528af598",
 "version_major": 2,
 "version_minor": 0
 },
@@ -127,7 +127,7 @@
 {
 "cell_type": "code",
 "execution_count": 3,
-"id": "97b3f806",
+"id": "7c8bc1e6",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -144,7 +144,7 @@
 {
 "cell_type": "code",
 "execution_count": 4,
-"id": "8ff0f445",
+"id": "38495b44",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -166,7 +166,7 @@
 {
 "cell_type": "code",
 "execution_count": 5,
-"id": "28558fef",
+"id": "7ce54831",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -181,7 +181,7 @@
 {
 "cell_type": "code",
 "execution_count": 6,
-"id": "f1759896",
+"id": "497772e6",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -197,15 +197,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
-"id": "c8ce3a9a",
+"execution_count": 7,
+"id": "158995c0",
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Runned time: 1.788 s\n",
+"Runned time: 1.17 s\n",
 "Answer : woman and dog\n"
 ]
 },
@@ -215,7 +215,7 @@
 "{'Answer': 'woman and dog'}"
 ]
 },
-"execution_count": 8,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -223,14 +223,6 @@
 "source": [
 "handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "c1668fb5",
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
0 {
1 "cells": [
2 {
3 "cell_type": "code",
4 "execution_count": 3,
5 "id": "feb398ed",
6 "metadata": {},
7 "outputs": [
8 {
9 "name": "stdout",
10 "output_type": "stream",
11 "text": [
12 "Looking in indexes: https://pypi.doubanio.com/simple/\n",
13 "Collecting timm==0.4.12\n",
14 " Downloading https://pypi.doubanio.com/packages/90/fc/606bc5cf46acac3aa9bd179b3954433c026aaf88ea98d6b19f5d14c336da/timm-0.4.12-py3-none-any.whl (376 kB)\n",
15 "\u001b[K |████████████████████████████████| 376 kB 1.7 MB/s eta 0:00:01\n",
16 "\u001b[?25hCollecting transformers==4.15.0\n",
17 " Downloading https://pypi.doubanio.com/packages/4a/7f/f1c28621af0d74794b18cbe5534ec7565ee782ba48257d08ec264bc4aacb/transformers-4.15.0-py3-none-any.whl (3.4 MB)\n",
18 "\u001b[K |████████████████████████████████| 3.4 MB 1.9 MB/s eta 0:00:01\n",
19 "\u001b[?25hCollecting fairscale==0.4.4\n",
20 " Downloading https://pypi.doubanio.com/packages/9f/51/9b8406605333f7d0a2e6f6a4af29ff64cf6c597b056411c1ed43c35e32b8/fairscale-0.4.4.tar.gz (235 kB)\n",
21 "\u001b[K |████████████████████████████████| 235 kB 59.4 MB/s eta 0:00:01\n",
22 "\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n",
23 "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
24 "\u001b[?25h Installing backend dependencies ... \u001b[?25ldone\n",
25 "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n",
26 "\u001b[?25hCollecting pycocoevalcap\n",
27 " Downloading https://pypi.doubanio.com/packages/08/f9/466f289f1628296b5e368940f89e3cfcfb066d15ddc02ff536dc532b1c93/pycocoevalcap-1.2-py3-none-any.whl (104.3 MB)\n",
28 "\u001b[K |████▍ | 14.2 MB 7.9 MB/s eta 0:00:127 kB/s eta 0:01:4544\u001b[31mERROR: Exception:\n",
29 "Traceback (most recent call last):\n",
30 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/urllib3/response.py\", line 438, in _error_catcher\n",
31 " yield\n",
32 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/urllib3/response.py\", line 519, in read\n",
33 " data = self._fp.read(amt) if not fp_closed else b\"\"\n",
34 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/cachecontrol/filewrapper.py\", line 62, in read\n",
35 " data = self.__fp.read(amt)\n",
36 " File \"/usr/lib/python3.7/http/client.py\", line 461, in read\n",
37 " n = self.readinto(b)\n",
38 " File \"/usr/lib/python3.7/http/client.py\", line 505, in readinto\n",
39 " n = self.fp.readinto(b)\n",
40 " File \"/usr/lib/python3.7/socket.py\", line 589, in readinto\n",
41 " return self._sock.recv_into(b)\n",
42 " File \"/usr/lib/python3.7/ssl.py\", line 1071, in recv_into\n",
43 " return self.read(nbytes, buffer)\n",
44 " File \"/usr/lib/python3.7/ssl.py\", line 929, in read\n",
45 " return self._sslobj.read(len, buffer)\n",
46 "socket.timeout: The read operation timed out\n",
47 "\n",
48 "During handling of the above exception, another exception occurred:\n",
49 "\n",
50 "Traceback (most recent call last):\n",
51 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/cli/base_command.py\", line 180, in _main\n",
52 " status = self.run(options, args)\n",
53 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/cli/req_command.py\", line 205, in wrapper\n",
54 " return func(self, options, args)\n",
55 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/commands/install.py\", line 319, in run\n",
56 " reqs, check_supported_wheels=not options.target_dir\n",
57 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/resolver.py\", line 128, in resolve\n",
58 " requirements, max_rounds=try_to_avoid_resolution_too_deep\n",
59 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/resolvelib/resolvers.py\", line 473, in resolve\n",
60 " state = resolution.resolve(requirements, max_rounds=max_rounds)\n",
61 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/resolvelib/resolvers.py\", line 341, in resolve\n",
62 " name, crit = self._merge_into_criterion(r, parent=None)\n",
63 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/resolvelib/resolvers.py\", line 172, in _merge_into_criterion\n",
64 " if not criterion.candidates:\n",
65 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/resolvelib/structs.py\", line 139, in __bool__\n",
66 " return bool(self._sequence)\n",
67 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/found_candidates.py\", line 143, in __bool__\n",
68 " return any(self)\n",
69 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/found_candidates.py\", line 129, in <genexpr>\n",
70 " return (c for c in iterator if id(c) not in self._incompatible_ids)\n",
71 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/found_candidates.py\", line 33, in _iter_built\n",
72 " candidate = func()\n",
73 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/factory.py\", line 205, in _make_candidate_from_link\n",
74 " version=version,\n",
75 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/candidates.py\", line 312, in __init__\n",
76 " version=version,\n",
77 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/candidates.py\", line 151, in __init__\n",
78 " self.dist = self._prepare()\n",
79 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/candidates.py\", line 234, in _prepare\n",
80 " dist = self._prepare_distribution()\n",
81 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/resolution/resolvelib/candidates.py\", line 318, in _prepare_distribution\n",
82 " self._ireq, parallel_builds=True\n",
83 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/operations/prepare.py\", line 508, in prepare_linked_requirement\n",
84 " return self._prepare_linked_requirement(req, parallel_builds)\n",
85 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/operations/prepare.py\", line 552, in _prepare_linked_requirement\n",
86 " self.download_dir, hashes\n",
87 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/operations/prepare.py\", line 243, in unpack_url\n",
88 " hashes=hashes,\n",
89 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/operations/prepare.py\", line 102, in get_http_url\n",
90 " from_path, content_type = download(link, temp_dir.path)\n",
91 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/network/download.py\", line 157, in __call__\n",
92 " for chunk in chunks:\n",
93 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/cli/progress_bars.py\", line 152, in iter\n",
94 " for x in it:\n",
95 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_internal/network/utils.py\", line 86, in response_chunks\n",
96 " decode_content=False,\n",
97 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/urllib3/response.py\", line 576, in stream\n",
98 " data = self.read(amt=amt, decode_content=decode_content)\n",
99 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/urllib3/response.py\", line 541, in read\n",
100 " raise IncompleteRead(self._fp_bytes_read, self.length_remaining)\n",
101 " File \"/usr/lib/python3.7/contextlib.py\", line 130, in __exit__\n",
102 " self.gen.throw(type, value, traceback)\n",
103 " File \"/home/jovyan/work/.localenv/lib/python3.7/site-packages/pip/_vendor/urllib3/response.py\", line 443, in _error_catcher\n",
104 " raise ReadTimeoutError(self._pool, None, \"Read timed out.\")\n",
105 "pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='pypi.doubanio.com', port=443): Read timed out.\u001b[0m\n",
106 "\u001b[33mWARNING: You are using pip version 21.1.3; however, version 22.1.2 is available.\n",
107 "You should consider upgrading via the '/home/jovyan/.virtualenvs/basenv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
108 ]
109 }
110 ],
111 "source": [
112 "!/home/jovyan/.virtualenvs/basenv/bin/pip install -r requirement.txt -i https://pypi.doubanio.com/simple/"
113 ]
114 },
115 {
116 "cell_type": "code",
117 "execution_count": null,
118 "id": "2b93eb8f",
119 "metadata": {},
120 "outputs": [],
121 "source": []
122 }
123 ],
124 "metadata": {
125 "kernelspec": {
126 "display_name": "Python 3 (ipykernel)",
127 "language": "python",
128 "name": "python3"
129 },
130 "language_info": {
131 "codemirror_mode": {
132 "name": "ipython",
133 "version": 3
134 },
135 "file_extension": ".py",
136 "mimetype": "text/x-python",
137 "name": "python",
138 "nbconvert_exporter": "python",
139 "pygments_lexer": "ipython3",
140 "version": "3.7.5"
141 }
142 },
143 "nbformat": 4,
144 "nbformat_minor": 5
145 }
0 '''
1 * Copyright (c) 2022, salesforce.com, inc.
2 * All rights reserved.
3 * SPDX-License-Identifier: BSD-3-Clause
4 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 * By Junnan Li
6 * Based on huggingface code base
7 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
8 '''
9
10 import math
11 import os
12 import warnings
13 from dataclasses import dataclass
14 from typing import Optional, Tuple
15
16 import torch
17 from torch import Tensor, device, dtype, nn
18 import torch.utils.checkpoint
19
20 from torch.nn import CrossEntropyLoss
21 import torch.nn.functional as F
22
23 from transformers.activations import ACT2FN
24 from transformers.file_utils import (
25 ModelOutput,
26 )
27 from transformers.modeling_outputs import (
28 BaseModelOutputWithPastAndCrossAttentions,
29 BaseModelOutputWithPoolingAndCrossAttentions,
30 CausalLMOutputWithCrossAttentions,
31 MaskedLMOutput,
32 MultipleChoiceModelOutput,
33 NextSentencePredictorOutput,
34 QuestionAnsweringModelOutput,
35 SequenceClassifierOutput,
36 TokenClassifierOutput,
37 )
38 from transformers.modeling_utils import (
39 PreTrainedModel,
40 apply_chunking_to_forward,
41 find_pruneable_heads_and_indices,
42 prune_linear_layer,
43 )
44 from transformers.utils import logging
45 from transformers.models.bert.configuration_bert import BertConfig
46
47
48 logger = logging.get_logger(__name__)
49
50
51 class BertEmbeddings(nn.Module):
52 """Construct the embeddings from word and position embeddings."""
53
54 def __init__(self, config):
55 super().__init__()
56 self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
57 self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
58
59 # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
60 # any TensorFlow checkpoint file
61 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
62 self.dropout = nn.Dropout(config.hidden_dropout_prob)
63
64 # position_ids (1, len position emb) is contiguous in memory and exported when serialized
65 self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
66 self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
67
68 self.config = config
69
70 def forward(
71 self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
72 ):
73 if input_ids is not None:
74 input_shape = input_ids.size()
75 else:
76 input_shape = inputs_embeds.size()[:-1]
77
78 seq_length = input_shape[1]
79
80 if position_ids is None:
81 position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
82
83 if inputs_embeds is None:
84 inputs_embeds = self.word_embeddings(input_ids)
85
86 embeddings = inputs_embeds
87
88 if self.position_embedding_type == "absolute":
89 position_embeddings = self.position_embeddings(position_ids)
90 embeddings += position_embeddings
91 embeddings = self.LayerNorm(embeddings)
92 embeddings = self.dropout(embeddings)
93 return embeddings
94
95
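# Illustrative sketch only (the toy sizes below are assumptions): with
# past_key_values_length = k, the slice in forward() selects positions
# [k, k + seq_length), which keeps absolute position embeddings aligned
# with tokens that are already in the decoding cache.
def _position_slice_sketch():
    position_ids = torch.arange(512).expand((1, -1))  # mirrors the registered buffer
    past_key_values_length, seq_length = 3, 2
    # -> tensor([[3, 4]])
    return position_ids[:, past_key_values_length : seq_length + past_key_values_length]
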
96 class BertSelfAttention(nn.Module):
97 def __init__(self, config, is_cross_attention):
98 super().__init__()
99 self.config = config
100 if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
101 raise ValueError(
102 "The hidden size (%d) is not a multiple of the number of attention "
103 "heads (%d)" % (config.hidden_size, config.num_attention_heads)
104 )
105
106 self.num_attention_heads = config.num_attention_heads
107 self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
108 self.all_head_size = self.num_attention_heads * self.attention_head_size
109
110 self.query = nn.Linear(config.hidden_size, self.all_head_size)
111 if is_cross_attention:
112 self.key = nn.Linear(config.encoder_width, self.all_head_size)
113 self.value = nn.Linear(config.encoder_width, self.all_head_size)
114 else:
115 self.key = nn.Linear(config.hidden_size, self.all_head_size)
116 self.value = nn.Linear(config.hidden_size, self.all_head_size)
117
118 self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
119 self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
120 if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
121 self.max_position_embeddings = config.max_position_embeddings
122 self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
123 self.save_attention = False
124
125 def save_attn_gradients(self, attn_gradients):
126 self.attn_gradients = attn_gradients
127
128 def get_attn_gradients(self):
129 return self.attn_gradients
130
131 def save_attention_map(self, attention_map):
132 self.attention_map = attention_map
133
134 def get_attention_map(self):
135 return self.attention_map
136
137 def transpose_for_scores(self, x):
138 new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
139 x = x.view(*new_x_shape)
140 return x.permute(0, 2, 1, 3)
141
142 def forward(
143 self,
144 hidden_states,
145 attention_mask=None,
146 head_mask=None,
147 encoder_hidden_states=None,
148 encoder_attention_mask=None,
149 past_key_value=None,
150 output_attentions=False,
151 ):
152 mixed_query_layer = self.query(hidden_states)
153
154 # If this is instantiated as a cross-attention module, the keys
155 # and values come from an encoder; the attention mask needs to be
156 # such that the encoder's padding tokens are not attended to.
157 is_cross_attention = encoder_hidden_states is not None
158
159 if is_cross_attention:
160 key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
161 value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
162 attention_mask = encoder_attention_mask
163 elif past_key_value is not None:
164 key_layer = self.transpose_for_scores(self.key(hidden_states))
165 value_layer = self.transpose_for_scores(self.value(hidden_states))
166 key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
167 value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
168 else:
169 key_layer = self.transpose_for_scores(self.key(hidden_states))
170 value_layer = self.transpose_for_scores(self.value(hidden_states))
171
172 query_layer = self.transpose_for_scores(mixed_query_layer)
173
174 past_key_value = (key_layer, value_layer)
175
176 # Take the dot product between "query" and "key" to get the raw attention scores.
177 attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
178
179 if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
180 seq_length = hidden_states.size()[1]
181 position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
182 position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
183 distance = position_ids_l - position_ids_r
184 positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
185 positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
186
187 if self.position_embedding_type == "relative_key":
188 relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
189 attention_scores = attention_scores + relative_position_scores
190 elif self.position_embedding_type == "relative_key_query":
191 relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
192 relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
193 attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
194
195 attention_scores = attention_scores / math.sqrt(self.attention_head_size)
196 if attention_mask is not None:
197 # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
198 attention_scores = attention_scores + attention_mask
199
200 # Normalize the attention scores to probabilities.
201 attention_probs = nn.Softmax(dim=-1)(attention_scores)
202
203 if is_cross_attention and self.save_attention:
204 self.save_attention_map(attention_probs)
205 attention_probs.register_hook(self.save_attn_gradients)
206
207 # This is actually dropping out entire tokens to attend to, which might
208 # seem a bit unusual, but is taken from the original Transformer paper.
209 attention_probs_dropped = self.dropout(attention_probs)
210
211 # Mask heads if we want to
212 if head_mask is not None:
213 attention_probs_dropped = attention_probs_dropped * head_mask
214
215 context_layer = torch.matmul(attention_probs_dropped, value_layer)
216
217 context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
218 new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
219 context_layer = context_layer.view(*new_context_layer_shape)
220
221 outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
222
223 outputs = outputs + (past_key_value,)
224 return outputs
225
226
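# A standalone sketch of the computation above (the toy shapes are
# assumptions, not values from this model). It reproduces the scaled
# dot-product and shows why adding -10000.0 to a position effectively
# removes it from the softmax distribution.
def _masked_attention_sketch():
    batch, heads, seq, head_dim = 2, 4, 3, 8
    query_layer = torch.randn(batch, heads, seq, head_dim)
    key_layer = torch.randn(batch, heads, seq, head_dim)
    scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    scores = scores / math.sqrt(head_dim)
    mask = torch.zeros(batch, 1, 1, seq)
    mask[..., -1] = -10000.0  # block every query from seeing the last key position
    probs = nn.Softmax(dim=-1)(scores + mask)
    assert probs[..., -1].max().item() < 1e-3  # the masked column gets ~0 probability
    return probs
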
227 class BertSelfOutput(nn.Module):
228 def __init__(self, config):
229 super().__init__()
230 self.dense = nn.Linear(config.hidden_size, config.hidden_size)
231 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
232 self.dropout = nn.Dropout(config.hidden_dropout_prob)
233
234 def forward(self, hidden_states, input_tensor):
235 hidden_states = self.dense(hidden_states)
236 hidden_states = self.dropout(hidden_states)
237 hidden_states = self.LayerNorm(hidden_states + input_tensor)
238 return hidden_states
239
240
241 class BertAttention(nn.Module):
242 def __init__(self, config, is_cross_attention=False):
243 super().__init__()
244 self.self = BertSelfAttention(config, is_cross_attention)
245 self.output = BertSelfOutput(config)
246 self.pruned_heads = set()
247
248 def prune_heads(self, heads):
249 if len(heads) == 0:
250 return
251 heads, index = find_pruneable_heads_and_indices(
252 heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
253 )
254
255 # Prune linear layers
256 self.self.query = prune_linear_layer(self.self.query, index)
257 self.self.key = prune_linear_layer(self.self.key, index)
258 self.self.value = prune_linear_layer(self.self.value, index)
259 self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
260
261 # Update hyper params and store pruned heads
262 self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
263 self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
264 self.pruned_heads = self.pruned_heads.union(heads)
265
266 def forward(
267 self,
268 hidden_states,
269 attention_mask=None,
270 head_mask=None,
271 encoder_hidden_states=None,
272 encoder_attention_mask=None,
273 past_key_value=None,
274 output_attentions=False,
275 ):
276 self_outputs = self.self(
277 hidden_states,
278 attention_mask,
279 head_mask,
280 encoder_hidden_states,
281 encoder_attention_mask,
282 past_key_value,
283 output_attentions,
284 )
285 attention_output = self.output(self_outputs[0], hidden_states)
286 outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
287 return outputs
288
289
290 class BertIntermediate(nn.Module):
291 def __init__(self, config):
292 super().__init__()
293 self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
294 if isinstance(config.hidden_act, str):
295 self.intermediate_act_fn = ACT2FN[config.hidden_act]
296 else:
297 self.intermediate_act_fn = config.hidden_act
298
299 def forward(self, hidden_states):
300 hidden_states = self.dense(hidden_states)
301 hidden_states = self.intermediate_act_fn(hidden_states)
302 return hidden_states
303
304
305 class BertOutput(nn.Module):
306 def __init__(self, config):
307 super().__init__()
308 self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
309 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
310 self.dropout = nn.Dropout(config.hidden_dropout_prob)
311
312 def forward(self, hidden_states, input_tensor):
313 hidden_states = self.dense(hidden_states)
314 hidden_states = self.dropout(hidden_states)
315 hidden_states = self.LayerNorm(hidden_states + input_tensor)
316 return hidden_states
317
318
319 class BertLayer(nn.Module):
320 def __init__(self, config, layer_num):
321 super().__init__()
322 self.config = config
323 self.chunk_size_feed_forward = config.chunk_size_feed_forward
324 self.seq_len_dim = 1
325 self.attention = BertAttention(config)
326 self.layer_num = layer_num
327 if self.config.add_cross_attention:
328 self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
329 self.intermediate = BertIntermediate(config)
330 self.output = BertOutput(config)
331
332 def forward(
333 self,
334 hidden_states,
335 attention_mask=None,
336 head_mask=None,
337 encoder_hidden_states=None,
338 encoder_attention_mask=None,
339 past_key_value=None,
340 output_attentions=False,
341 mode=None,
342 ):
343 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
344 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
345 self_attention_outputs = self.attention(
346 hidden_states,
347 attention_mask,
348 head_mask,
349 output_attentions=output_attentions,
350 past_key_value=self_attn_past_key_value,
351 )
352 attention_output = self_attention_outputs[0]
353
354 outputs = self_attention_outputs[1:-1]
355 present_key_value = self_attention_outputs[-1]
356
357 if mode == 'multimodal':
358 assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
359
360 cross_attention_outputs = self.crossattention(
361 attention_output,
362 attention_mask,
363 head_mask,
364 encoder_hidden_states,
365 encoder_attention_mask,
366 output_attentions=output_attentions,
367 )
368 attention_output = cross_attention_outputs[0]
369 outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
370 layer_output = apply_chunking_to_forward(
371 self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
372 )
373 outputs = (layer_output,) + outputs
374
375 outputs = outputs + (present_key_value,)
376
377 return outputs
378
379 def feed_forward_chunk(self, attention_output):
380 intermediate_output = self.intermediate(attention_output)
381 layer_output = self.output(intermediate_output, attention_output)
382 return layer_output
383
384
385 class BertEncoder(nn.Module):
386 def __init__(self, config):
387 super().__init__()
388 self.config = config
389 self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
390 self.gradient_checkpointing = False
391
392 def forward(
393 self,
394 hidden_states,
395 attention_mask=None,
396 head_mask=None,
397 encoder_hidden_states=None,
398 encoder_attention_mask=None,
399 past_key_values=None,
400 use_cache=None,
401 output_attentions=False,
402 output_hidden_states=False,
403 return_dict=True,
404 mode='multimodal',
405 ):
406 all_hidden_states = () if output_hidden_states else None
407 all_self_attentions = () if output_attentions else None
408 all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
409
410 next_decoder_cache = () if use_cache else None
411
412 for i in range(self.config.num_hidden_layers):
413 layer_module = self.layer[i]
414 if output_hidden_states:
415 all_hidden_states = all_hidden_states + (hidden_states,)
416
417 layer_head_mask = head_mask[i] if head_mask is not None else None
418 past_key_value = past_key_values[i] if past_key_values is not None else None
419
420 if self.gradient_checkpointing and self.training:
421
422 if use_cache:
423 logger.warning(
424 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
425 )
426 use_cache = False
427
428 def create_custom_forward(module):
429 def custom_forward(*inputs):
430 return module(*inputs, past_key_value, output_attentions)
431
432 return custom_forward
433
434 layer_outputs = torch.utils.checkpoint.checkpoint(
435 create_custom_forward(layer_module),
436 hidden_states,
437 attention_mask,
438 layer_head_mask,
439 encoder_hidden_states,
440 encoder_attention_mask,
441 mode=mode,
442 )
443 else:
444 layer_outputs = layer_module(
445 hidden_states,
446 attention_mask,
447 layer_head_mask,
448 encoder_hidden_states,
449 encoder_attention_mask,
450 past_key_value,
451 output_attentions,
452 mode=mode,
453 )
454
455 hidden_states = layer_outputs[0]
456 if use_cache:
457 next_decoder_cache += (layer_outputs[-1],)
458 if output_attentions:
459 all_self_attentions = all_self_attentions + (layer_outputs[1],)
460
461 if output_hidden_states:
462 all_hidden_states = all_hidden_states + (hidden_states,)
463
464 if not return_dict:
465 return tuple(
466 v
467 for v in [
468 hidden_states,
469 next_decoder_cache,
470 all_hidden_states,
471 all_self_attentions,
472 all_cross_attentions,
473 ]
474 if v is not None
475 )
476 return BaseModelOutputWithPastAndCrossAttentions(
477 last_hidden_state=hidden_states,
478 past_key_values=next_decoder_cache,
479 hidden_states=all_hidden_states,
480 attentions=all_self_attentions,
481 cross_attentions=all_cross_attentions,
482 )
483
484
485 class BertPooler(nn.Module):
486 def __init__(self, config):
487 super().__init__()
488 self.dense = nn.Linear(config.hidden_size, config.hidden_size)
489 self.activation = nn.Tanh()
490
491 def forward(self, hidden_states):
492 # We "pool" the model by simply taking the hidden state corresponding
493 # to the first token.
494 first_token_tensor = hidden_states[:, 0]
495 pooled_output = self.dense(first_token_tensor)
496 pooled_output = self.activation(pooled_output)
497 return pooled_output
498
499
500 class BertPredictionHeadTransform(nn.Module):
501 def __init__(self, config):
502 super().__init__()
503 self.dense = nn.Linear(config.hidden_size, config.hidden_size)
504 if isinstance(config.hidden_act, str):
505 self.transform_act_fn = ACT2FN[config.hidden_act]
506 else:
507 self.transform_act_fn = config.hidden_act
508 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
509
510 def forward(self, hidden_states):
511 hidden_states = self.dense(hidden_states)
512 hidden_states = self.transform_act_fn(hidden_states)
513 hidden_states = self.LayerNorm(hidden_states)
514 return hidden_states
515
516
517 class BertLMPredictionHead(nn.Module):
518 def __init__(self, config):
519 super().__init__()
520 self.transform = BertPredictionHeadTransform(config)
521
522 # The output weights are the same as the input embeddings, but there is
523 # an output-only bias for each token.
524 self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
525
526 self.bias = nn.Parameter(torch.zeros(config.vocab_size))
527
528 # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
529 self.decoder.bias = self.bias
530
531 def forward(self, hidden_states):
532 hidden_states = self.transform(hidden_states)
533 hidden_states = self.decoder(hidden_states)
534 return hidden_states
535
536
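# Minimal sketch of the weight-tying note above (the tiny config values are
# assumptions): decoder.bias and self.bias are the same Parameter object, so
# utilities that resize the output layer, e.g. resize_token_embeddings, keep
# the bias in sync automatically.
def _tied_bias_sketch():
    cfg = BertConfig(vocab_size=10, hidden_size=8, num_attention_heads=2,
                     num_hidden_layers=1, intermediate_size=16)
    head = BertLMPredictionHead(cfg)
    assert head.decoder.bias is head.bias  # one Parameter, two references
    return head
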
537 class BertOnlyMLMHead(nn.Module):
538 def __init__(self, config):
539 super().__init__()
540 self.predictions = BertLMPredictionHead(config)
541
542 def forward(self, sequence_output):
543 prediction_scores = self.predictions(sequence_output)
544 return prediction_scores
545
546
547 class BertPreTrainedModel(PreTrainedModel):
548 """
549 An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
550 models.
551 """
552
553 config_class = BertConfig
554 base_model_prefix = "bert"
555 _keys_to_ignore_on_load_missing = [r"position_ids"]
556
557 def _init_weights(self, module):
558 """ Initialize the weights """
559 if isinstance(module, (nn.Linear, nn.Embedding)):
560 # Slightly different from the TF version which uses truncated_normal for initialization
561 # cf https://github.com/pytorch/pytorch/pull/5617
562 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
563 elif isinstance(module, nn.LayerNorm):
564 module.bias.data.zero_()
565 module.weight.data.fill_(1.0)
566 if isinstance(module, nn.Linear) and module.bias is not None:
567 module.bias.data.zero_()
568
569
570 class BertModel(BertPreTrainedModel):
571 """
572 The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
573 cross-attention is added between the self-attention layers, following the architecture described in `Attention is
574 all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
575 Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder, the model needs to be initialized with the :obj:`is_decoder`
576 argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
577 input to the forward pass.
578 """
579
580 def __init__(self, config, add_pooling_layer=True):
581 super().__init__(config)
582 self.config = config
583
584 self.embeddings = BertEmbeddings(config)
585
586 self.encoder = BertEncoder(config)
587
588 self.pooler = BertPooler(config) if add_pooling_layer else None
589
590 self.init_weights()
591
592
593 def get_input_embeddings(self):
594 return self.embeddings.word_embeddings
595
596 def set_input_embeddings(self, value):
597 self.embeddings.word_embeddings = value
598
599 def _prune_heads(self, heads_to_prune):
600 """
601 Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
602 class PreTrainedModel
603 """
604 for layer, heads in heads_to_prune.items():
605 self.encoder.layer[layer].attention.prune_heads(heads)
606
607
608 def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
609 """
610 Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
611
612 Arguments:
613 attention_mask (:obj:`torch.Tensor`):
614 Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
615 input_shape (:obj:`Tuple[int]`):
616 The shape of the input to the model.
617 device: (:obj:`torch.device`):
618 The device of the input to the model.
619
620 Returns:
621 :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
622 """
623 # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
624 # ourselves in which case we just need to make it broadcastable to all heads.
625 if attention_mask.dim() == 3:
626 extended_attention_mask = attention_mask[:, None, :, :]
627 elif attention_mask.dim() == 2:
628 # Provided a padding mask of dimensions [batch_size, seq_length]
629 # - if the model is a decoder, apply a causal mask in addition to the padding mask
630 # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
631 if is_decoder:
632 batch_size, seq_length = input_shape
633
634 seq_ids = torch.arange(seq_length, device=device)
635 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
636 # in case past_key_values are used we need to add a prefix ones mask to the causal mask
637 # causal and attention masks must have same type with pytorch version < 1.3
638 causal_mask = causal_mask.to(attention_mask.dtype)
639
640 if causal_mask.shape[1] < attention_mask.shape[1]:
641 prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
642 causal_mask = torch.cat(
643 [
644 torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
645 causal_mask,
646 ],
647 axis=-1,
648 )
649
650 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
651 else:
652 extended_attention_mask = attention_mask[:, None, None, :]
653 else:
654 raise ValueError(
655 "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
656 input_shape, attention_mask.shape
657 )
658 )
659
660 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
661 # masked positions, this operation will create a tensor which is 0.0 for
662 # positions we want to attend and -10000.0 for masked positions.
663 # Since we are adding it to the raw scores before the softmax, this is
664 # effectively the same as removing these entirely.
665 extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
666 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
667 return extended_attention_mask
668
669 def forward(
670 self,
671 input_ids=None,
672 attention_mask=None,
673 position_ids=None,
674 head_mask=None,
675 inputs_embeds=None,
676 encoder_embeds=None,
677 encoder_hidden_states=None,
678 encoder_attention_mask=None,
679 past_key_values=None,
680 use_cache=None,
681 output_attentions=None,
682 output_hidden_states=None,
683 return_dict=None,
684 is_decoder=False,
685 mode='multimodal',
686 ):
687 r"""
688 encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
689 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
690 the model is configured as a decoder.
691 encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
692 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
693 the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
694 - 1 for tokens that are **not masked**,
695 - 0 for tokens that are **masked**.
696 past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
697 Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
698 If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
699 (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
700 instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
701 use_cache (:obj:`bool`, `optional`):
702 If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
703 decoding (see :obj:`past_key_values`).
704 """
705 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
706 output_hidden_states = (
707 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
708 )
709 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
710
711 if is_decoder:
712 use_cache = use_cache if use_cache is not None else self.config.use_cache
713 else:
714 use_cache = False
715
716 if input_ids is not None and inputs_embeds is not None:
717 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
718 elif input_ids is not None:
719 input_shape = input_ids.size()
720 batch_size, seq_length = input_shape
721 device = input_ids.device
722 elif inputs_embeds is not None:
723 input_shape = inputs_embeds.size()[:-1]
724 batch_size, seq_length = input_shape
725 device = inputs_embeds.device
726 elif encoder_embeds is not None:
727 input_shape = encoder_embeds.size()[:-1]
728 batch_size, seq_length = input_shape
729 device = encoder_embeds.device
730 else:
731 raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
732
733 # past_key_values_length
734 past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
735
736 if attention_mask is None:
737 attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
738
739 # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
740 # ourselves in which case we just need to make it broadcastable to all heads.
741 extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
742 device, is_decoder)
743
744 # If a 2D or 3D attention mask is provided for the cross-attention
745 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
746 if encoder_hidden_states is not None:
747 if type(encoder_hidden_states) == list:
748 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
749 else:
750 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
751 encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
752
753 if type(encoder_attention_mask) == list:
754 encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
755 elif encoder_attention_mask is None:
756 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
757 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
758 else:
759 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
760 else:
761 encoder_extended_attention_mask = None
762
763 # Prepare head mask if needed
764 # 1.0 in head_mask indicate we keep the head
765 # attention_probs has shape bsz x n_heads x N x N
766 # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
767 # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
768 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
769
770 if encoder_embeds is None:
771 embedding_output = self.embeddings(
772 input_ids=input_ids,
773 position_ids=position_ids,
774 inputs_embeds=inputs_embeds,
775 past_key_values_length=past_key_values_length,
776 )
777 else:
778 embedding_output = encoder_embeds
779
780 encoder_outputs = self.encoder(
781 embedding_output,
782 attention_mask=extended_attention_mask,
783 head_mask=head_mask,
784 encoder_hidden_states=encoder_hidden_states,
785 encoder_attention_mask=encoder_extended_attention_mask,
786 past_key_values=past_key_values,
787 use_cache=use_cache,
788 output_attentions=output_attentions,
789 output_hidden_states=output_hidden_states,
790 return_dict=return_dict,
791 mode=mode,
792 )
793 sequence_output = encoder_outputs[0]
794 pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
795
796 if not return_dict:
797 return (sequence_output, pooled_output) + encoder_outputs[1:]
798
799 return BaseModelOutputWithPoolingAndCrossAttentions(
800 last_hidden_state=sequence_output,
801 pooler_output=pooled_output,
802 past_key_values=encoder_outputs.past_key_values,
803 hidden_states=encoder_outputs.hidden_states,
804 attentions=encoder_outputs.attentions,
805 cross_attentions=encoder_outputs.cross_attentions,
806 )
807
808
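# A sketch of the decoder branch of get_extended_attention_mask for a fully
# visible 1 x 4 input (an assumed toy size, no padding and no cache). In the
# result, rows are query positions and columns are key positions: 0.0 allows
# attention and -10000.0 blocks it, so position i only attends to positions <= i.
def _causal_mask_sketch():
    seq_length = 4
    attention_mask = torch.ones(1, seq_length)
    seq_ids = torch.arange(seq_length)
    causal_mask = (seq_ids[None, None, :].repeat(1, seq_length, 1)
                   <= seq_ids[None, :, None]).float()
    extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    extended = (1.0 - extended) * -10000.0
    return extended  # shape (1, 1, 4, 4); the strictly upper triangle is -10000.0
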
809
810 class BertLMHeadModel(BertPreTrainedModel):
811
812 _keys_to_ignore_on_load_unexpected = [r"pooler"]
813 _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
814
815 def __init__(self, config):
816 super().__init__(config)
817
818 self.bert = BertModel(config, add_pooling_layer=False)
819 self.cls = BertOnlyMLMHead(config)
820
821 self.init_weights()
822
823 def get_output_embeddings(self):
824 return self.cls.predictions.decoder
825
826 def set_output_embeddings(self, new_embeddings):
827 self.cls.predictions.decoder = new_embeddings
828
829 def forward(
830 self,
831 input_ids=None,
832 attention_mask=None,
833 position_ids=None,
834 head_mask=None,
835 inputs_embeds=None,
836 encoder_hidden_states=None,
837 encoder_attention_mask=None,
838 labels=None,
839 past_key_values=None,
840 use_cache=None,
841 output_attentions=None,
842 output_hidden_states=None,
843 return_dict=None,
844 return_logits=False,
845 is_decoder=True,
846 reduction='mean',
847 mode='multimodal',
848 ):
849 r"""
850 encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
851 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
852 the model is configured as a decoder.
853 encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
854 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
855 the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
856 - 1 for tokens that are **not masked**,
857 - 0 for tokens that are **masked**.
858 labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
859 Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
860 ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
861 ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
862 past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
863 Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
864 If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
865 (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
866 instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
867 use_cache (:obj:`bool`, `optional`):
868 If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
869 decoding (see :obj:`past_key_values`).
870 Returns:
871 Example::
872 >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
873 >>> import torch
874 >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
875 >>> config = BertConfig.from_pretrained("bert-base-cased")
876 >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
877 >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
878 >>> outputs = model(**inputs)
879 >>> prediction_logits = outputs.logits
880 """
881 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
882 if labels is not None:
883 use_cache = False
884
885 outputs = self.bert(
886 input_ids,
887 attention_mask=attention_mask,
888 position_ids=position_ids,
889 head_mask=head_mask,
890 inputs_embeds=inputs_embeds,
891 encoder_hidden_states=encoder_hidden_states,
892 encoder_attention_mask=encoder_attention_mask,
893 past_key_values=past_key_values,
894 use_cache=use_cache,
895 output_attentions=output_attentions,
896 output_hidden_states=output_hidden_states,
897 return_dict=return_dict,
898 is_decoder=is_decoder,
899 mode=mode,
900 )
901
902 sequence_output = outputs[0]
903 prediction_scores = self.cls(sequence_output)
904
905 if return_logits:
906 return prediction_scores[:, :-1, :].contiguous()
907
908 lm_loss = None
909 if labels is not None:
910 # we are doing next-token prediction; shift prediction scores and input ids by one
911 shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
912 labels = labels[:, 1:].contiguous()
913 loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
914 lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
915 if reduction == 'none':
916 lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
917
918 if not return_dict:
919 output = (prediction_scores,) + outputs[2:]
920 return ((lm_loss,) + output) if lm_loss is not None else output
921
922 return CausalLMOutputWithCrossAttentions(
923 loss=lm_loss,
924 logits=prediction_scores,
925 past_key_values=outputs.past_key_values,
926 hidden_states=outputs.hidden_states,
927 attentions=outputs.attentions,
928 cross_attentions=outputs.cross_attentions,
929 )
930
931 def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
932 input_shape = input_ids.shape
933 # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
934 if attention_mask is None:
935 attention_mask = input_ids.new_ones(input_shape)
936
937 # cut decoder_input_ids if past is used
938 if past is not None:
939 input_ids = input_ids[:, -1:]
940
941 return {
942 "input_ids": input_ids,
943 "attention_mask": attention_mask,
944 "past_key_values": past,
945 "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
946 "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
947 "is_decoder": True,
948 }
949
950 def _reorder_cache(self, past, beam_idx):
951 reordered_past = ()
952 for layer_past in past:
953 reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
954 return reordered_past
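
# A minimal smoke test for the decoder path (a sketch only: the tiny config,
# random token ids, and random "image" features below are assumptions, not the
# original BLIP setup). It drives BertLMHeadModel with cross-attention enabled,
# the way BLIP feeds visual encoder outputs into this module.
if __name__ == "__main__":
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        add_cross_attention=True,  # adds the crossattention sub-layer to each BertLayer
        is_decoder=True,
        encoder_width=32,          # width of the visual features projected to K/V
    )
    model = BertLMHeadModel(config)
    input_ids = torch.randint(0, 100, (2, 5))
    image_embeds = torch.randn(2, 10, 32)  # stand-in for ViT patch features
    outputs = model(
        input_ids,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=torch.ones(2, 10, dtype=torch.long),
        labels=input_ids,  # next-token loss on the shifted sequence
        mode='multimodal',
    )
    print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([2, 5, 100])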