feat: formatting and type hints
Browse files- modeling_lora.py +122 -35
modeling_lora.py
CHANGED
|
@@ -1,23 +1,24 @@
|
|
|
|
|
| 1 |
from functools import partial
|
| 2 |
-
from typing import Iterator, Tuple
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
import torch.nn.utils.parametrize as parametrize
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
from torch.nn import Parameter
|
| 10 |
|
| 11 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 12 |
|
| 13 |
|
| 14 |
-
def initialized_weights(
|
|
|
|
|
|
|
| 15 |
weight_data = []
|
| 16 |
for _ in range(num_adaptions):
|
| 17 |
new_adaption = torch.zeros(shape)
|
| 18 |
-
if init ==
|
| 19 |
nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
|
| 20 |
-
elif init ==
|
| 21 |
nn.init.normal_(new_adaption)
|
| 22 |
else:
|
| 23 |
raise NotImplementedError
|
|
@@ -26,27 +27,48 @@ def initialized_weights(shape, num_adaptions, init='kaiming'):
|
|
| 26 |
|
| 27 |
|
| 28 |
class LoRAParametrization(nn.Module):
|
| 29 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
super().__init__()
|
| 31 |
# if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
|
| 32 |
# otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
|
| 33 |
-
fan_in_fan_out =
|
| 34 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
| 35 |
|
| 36 |
-
if layer_type ==
|
| 37 |
-
self.lora_A = nn.Parameter(
|
|
|
|
|
|
|
| 38 |
self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
|
| 39 |
-
elif layer_type ==
|
| 40 |
self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
|
| 41 |
-
self.lora_B = nn.Parameter(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
else:
|
| 43 |
raise NotImplementedError
|
| 44 |
|
| 45 |
self.lora_alpha, self.rank = lora_alpha, rank
|
| 46 |
self.scaling = lora_alpha / rank
|
| 47 |
-
self.lora_dropout =
|
|
|
|
|
|
|
| 48 |
self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
|
| 49 |
-
self.register_buffer(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
self.forward_fn = lambda x: x
|
| 51 |
self.current_task = None
|
| 52 |
|
|
@@ -56,7 +78,18 @@ class LoRAParametrization(nn.Module):
|
|
| 56 |
|
| 57 |
def lora_forward(self, X):
|
| 58 |
assert self.current_task is not None
|
| 59 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def forward(self, X):
|
| 62 |
return self.forward_fn(X)
|
|
@@ -69,28 +102,73 @@ class LoRAParametrization(nn.Module):
|
|
| 69 |
self.forward_fn = self.lora_forward
|
| 70 |
|
| 71 |
@classmethod
|
| 72 |
-
def from_linear(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
fan_out, fan_in = layer.weight.shape
|
| 74 |
return cls(
|
| 75 |
-
fan_in,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
|
| 78 |
@classmethod
|
| 79 |
-
def from_embedding(
|
|
|
|
|
|
|
|
|
|
| 80 |
fan_in, fan_out = layer.weight.shape
|
| 81 |
return cls(
|
| 82 |
-
fan_in,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
@classmethod
|
| 86 |
-
def add_to_layer(
|
|
|
|
|
|
|
| 87 |
if isinstance(layer, nn.Linear):
|
| 88 |
-
parametrize.register_parametrization(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
elif isinstance(layer, nn.Embedding):
|
| 90 |
-
parametrize.register_parametrization(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
@classmethod
|
| 93 |
-
def select_task_for_layer(cls, layer, task_idx=None):
|
| 94 |
if isinstance(layer, LoRAParametrization):
|
| 95 |
layer.select_task(task_idx)
|
| 96 |
|
|
@@ -101,7 +179,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 101 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
| 102 |
self._register_lora(num_adaptions)
|
| 103 |
for name, param in super().named_parameters():
|
| 104 |
-
if
|
| 105 |
param.requires_grad_(False)
|
| 106 |
|
| 107 |
def from_bert(self, *args, num_adaptions=1, **kwargs):
|
|
@@ -109,10 +187,20 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 109 |
self._register_lora(num_adaptions)
|
| 110 |
|
| 111 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 112 |
-
self.apply(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
def select_task(self, task_idx):
|
| 115 |
-
self.apply(
|
|
|
|
|
|
|
| 116 |
|
| 117 |
def forward(self, *args, **kwargs):
|
| 118 |
return self.bert(*args, **kwargs)
|
|
@@ -122,11 +210,10 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 122 |
yield param
|
| 123 |
|
| 124 |
def named_parameters(
|
| 125 |
-
|
| 126 |
-
prefix: str = '',
|
| 127 |
-
recurse: bool = True,
|
| 128 |
-
remove_duplicate: bool = True
|
| 129 |
) -> Iterator[Tuple[str, Parameter]]:
|
| 130 |
-
for name, param in super().named_parameters(
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
from functools import partial
|
| 3 |
+
from typing import Iterator, Optional, Tuple, Union
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
import torch.nn.utils.parametrize as parametrize
|
| 7 |
+
from torch import nn
|
|
|
|
| 8 |
from torch.nn import Parameter
|
| 9 |
|
| 10 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 11 |
|
| 12 |
|
| 13 |
+
def initialized_weights(
|
| 14 |
+
shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
weight_data = []
|
| 17 |
for _ in range(num_adaptions):
|
| 18 |
new_adaption = torch.zeros(shape)
|
| 19 |
+
if init == "kaiming":
|
| 20 |
nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
|
| 21 |
+
elif init == "normal":
|
| 22 |
nn.init.normal_(new_adaption)
|
| 23 |
else:
|
| 24 |
raise NotImplementedError
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class LoRAParametrization(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
fan_in: int,
|
| 33 |
+
fan_out: int,
|
| 34 |
+
layer_type: str = "linear",
|
| 35 |
+
num_adaptions: int = 1,
|
| 36 |
+
rank: int = 4,
|
| 37 |
+
lora_dropout_p: float = 0.0,
|
| 38 |
+
lora_alpha: float = 1,
|
| 39 |
+
):
|
| 40 |
super().__init__()
|
| 41 |
# if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
|
| 42 |
# otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
|
| 43 |
+
fan_in_fan_out = layer_type == "embedding"
|
| 44 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
| 45 |
|
| 46 |
+
if layer_type == "linear":
|
| 47 |
+
self.lora_A = nn.Parameter(
|
| 48 |
+
initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
|
| 49 |
+
)
|
| 50 |
self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
|
| 51 |
+
elif layer_type == "embedding":
|
| 52 |
self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
|
| 53 |
+
self.lora_B = nn.Parameter(
|
| 54 |
+
initialized_weights(
|
| 55 |
+
(rank, fan_out), num_adaptions=num_adaptions, init="normal"
|
| 56 |
+
)
|
| 57 |
+
)
|
| 58 |
else:
|
| 59 |
raise NotImplementedError
|
| 60 |
|
| 61 |
self.lora_alpha, self.rank = lora_alpha, rank
|
| 62 |
self.scaling = lora_alpha / rank
|
| 63 |
+
self.lora_dropout = (
|
| 64 |
+
nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
|
| 65 |
+
)
|
| 66 |
self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
|
| 67 |
+
self.register_buffer(
|
| 68 |
+
"lora_dropout_mask",
|
| 69 |
+
torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
|
| 70 |
+
persistent=False,
|
| 71 |
+
)
|
| 72 |
self.forward_fn = lambda x: x
|
| 73 |
self.current_task = None
|
| 74 |
|
|
|
|
| 78 |
|
| 79 |
def lora_forward(self, X):
|
| 80 |
assert self.current_task is not None
|
| 81 |
+
return (
|
| 82 |
+
X
|
| 83 |
+
+ torch.matmul(
|
| 84 |
+
*self.swap(
|
| 85 |
+
(
|
| 86 |
+
self.lora_B[self.current_task],
|
| 87 |
+
self.dropout_fn(self.lora_A[self.current_task]),
|
| 88 |
+
)
|
| 89 |
+
)
|
| 90 |
+
).view(X.shape)
|
| 91 |
+
* self.scaling
|
| 92 |
+
)
|
| 93 |
|
| 94 |
def forward(self, X):
|
| 95 |
return self.forward_fn(X)
|
|
|
|
| 102 |
self.forward_fn = self.lora_forward
|
| 103 |
|
| 104 |
@classmethod
|
| 105 |
+
def from_linear(
|
| 106 |
+
cls,
|
| 107 |
+
layer: nn.Module,
|
| 108 |
+
num_adaptions: int = 1,
|
| 109 |
+
rank: int = 4,
|
| 110 |
+
lora_dropout_p: float = 0.0,
|
| 111 |
+
lora_alpha: int = 1,
|
| 112 |
+
):
|
| 113 |
+
assert isinstance(layer, nn.Linear)
|
| 114 |
fan_out, fan_in = layer.weight.shape
|
| 115 |
return cls(
|
| 116 |
+
fan_in,
|
| 117 |
+
fan_out,
|
| 118 |
+
num_adaptions=num_adaptions,
|
| 119 |
+
layer_type="linear",
|
| 120 |
+
rank=rank,
|
| 121 |
+
lora_dropout_p=lora_dropout_p,
|
| 122 |
+
lora_alpha=lora_alpha,
|
| 123 |
)
|
| 124 |
|
| 125 |
@classmethod
|
| 126 |
+
def from_embedding(
|
| 127 |
+
cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
|
| 128 |
+
):
|
| 129 |
+
assert isinstance(layer, nn.Embedding)
|
| 130 |
fan_in, fan_out = layer.weight.shape
|
| 131 |
return cls(
|
| 132 |
+
fan_in,
|
| 133 |
+
fan_out,
|
| 134 |
+
num_adaptions=num_adaptions,
|
| 135 |
+
layer_type="embedding",
|
| 136 |
+
rank=rank,
|
| 137 |
+
lora_dropout_p=lora_dropout_p,
|
| 138 |
+
lora_alpha=lora_alpha,
|
| 139 |
)
|
| 140 |
|
| 141 |
@classmethod
|
| 142 |
+
def add_to_layer(
|
| 143 |
+
cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
|
| 144 |
+
):
|
| 145 |
if isinstance(layer, nn.Linear):
|
| 146 |
+
parametrize.register_parametrization(
|
| 147 |
+
layer,
|
| 148 |
+
"weight",
|
| 149 |
+
cls.from_linear(
|
| 150 |
+
layer,
|
| 151 |
+
num_adaptions=num_adaptions,
|
| 152 |
+
rank=rank,
|
| 153 |
+
lora_dropout_p=lora_dropout_p,
|
| 154 |
+
lora_alpha=lora_alpha,
|
| 155 |
+
),
|
| 156 |
+
)
|
| 157 |
elif isinstance(layer, nn.Embedding):
|
| 158 |
+
parametrize.register_parametrization(
|
| 159 |
+
layer,
|
| 160 |
+
"weight",
|
| 161 |
+
cls.from_embedding(
|
| 162 |
+
layer,
|
| 163 |
+
num_adaptions=num_adaptions,
|
| 164 |
+
rank=rank,
|
| 165 |
+
lora_dropout_p=lora_dropout_p,
|
| 166 |
+
lora_alpha=lora_alpha,
|
| 167 |
+
),
|
| 168 |
+
)
|
| 169 |
|
| 170 |
@classmethod
|
| 171 |
+
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
| 172 |
if isinstance(layer, LoRAParametrization):
|
| 173 |
layer.select_task(task_idx)
|
| 174 |
|
|
|
|
| 179 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
| 180 |
self._register_lora(num_adaptions)
|
| 181 |
for name, param in super().named_parameters():
|
| 182 |
+
if "lora" not in name:
|
| 183 |
param.requires_grad_(False)
|
| 184 |
|
| 185 |
def from_bert(self, *args, num_adaptions=1, **kwargs):
|
|
|
|
| 187 |
self._register_lora(num_adaptions)
|
| 188 |
|
| 189 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 190 |
+
self.apply(
|
| 191 |
+
partial(
|
| 192 |
+
LoRAParametrization.add_to_layer,
|
| 193 |
+
num_adaptions=num_adaptions,
|
| 194 |
+
rank=rank,
|
| 195 |
+
lora_dropout_p=lora_dropout_p,
|
| 196 |
+
lora_alpha=lora_alpha,
|
| 197 |
+
)
|
| 198 |
+
)
|
| 199 |
|
| 200 |
+
def select_task(self, task_idx: Union[None, int]):
|
| 201 |
+
self.apply(
|
| 202 |
+
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
| 203 |
+
)
|
| 204 |
|
| 205 |
def forward(self, *args, **kwargs):
|
| 206 |
return self.bert(*args, **kwargs)
|
|
|
|
| 210 |
yield param
|
| 211 |
|
| 212 |
def named_parameters(
|
| 213 |
+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
|
|
|
|
|
|
|
|
|
|
| 214 |
) -> Iterator[Tuple[str, Parameter]]:
|
| 215 |
+
for name, param in super().named_parameters(
|
| 216 |
+
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
|
| 217 |
+
):
|
| 218 |
+
if "lora" in name:
|
| 219 |
+
yield name, param
|