mmagic.models.archs.lora
Module Contents
Classes
- LoRALinear – Linear layer for LoRA.
- LoRAWrapper – Wrapper for LoRA layer.
Functions
- replace_module – Replace module in parent module.
- get_submodule – Get submodule by key.
- set_lora – Set LoRA for module.
- set_only_lora_trainable – Set only LoRA modules trainable.
- Enable LoRA modules.
- Disable LoRA modules.
- class mmagic.models.archs.lora.LoRALinear(in_feat: int, out_feat: int, rank: int = 4)
Bases: torch.nn.Module
Linear layer for LoRA.
- Parameters
in_feat (int) – Number of input features.
out_feat (int) – Number of output features.
rank (int) – The rank of LoRA. Defaults to 4.
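A minimal sketch of the computation a rank-`rank` LoRA linear layer performs, assuming the usual LoRA factorization into a `down` projection (`in_feat` to `rank`) and an `up` projection (`rank` to `out_feat`), with `up` zero-initialized. The class name `LowRankLinear`, its attribute names and the initialization are assumptions for illustration, not read from this page.

```python
import torch
from torch import nn


class LowRankLinear(nn.Module):
    """Hypothetical stand-in for a LoRA linear layer: x -> up(down(x))."""

    def __init__(self, in_feat: int, out_feat: int, rank: int = 4):
        super().__init__()
        # Assumed structure: project down to `rank` dims, then back up.
        self.down = nn.Linear(in_feat, rank, bias=False)
        self.up = nn.Linear(rank, out_feat, bias=False)
        # Common LoRA choice (assumption): random `down`, zero `up`,
        # so the layer outputs zeros until it is trained.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))


layer = LowRankLinear(in_feat=64, out_feat=32, rank=4)
out = layer(torch.randn(8, 64))
n_params = sum(p.numel() for p in layer.parameters())
print(out.shape, n_params)  # torch.Size([8, 32]) 384
```

A dense `nn.Linear(64, 32, bias=False)` would hold 2,048 weights; the rank-4 factorization holds 384, i.e. `rank * (in_feat + out_feat)`.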
- class mmagic.models.archs.lora.LoRAWrapper(module: torch.nn.Module, in_feat: int, out_feat: int, rank: int, scale: float = 1, names: Optional[Union[str, List[str]]] = None)
Bases: torch.nn.Module
Wrapper for LoRA layer.
- Parameters
module (nn.Module) – The module to be wrapped.
in_feat (int) – Number of input features.
out_feat (int) – Number of output features.
rank (int) – The rank of LoRA.
scale (float) – The scale applied to the LoRA feature. Defaults to 1.
names (Union[str, List[str]], optional) – The names of the LoRA layers. To add multiple LoRA mappings to one module, a name must be defined for each mapping. Defaults to None.
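One way to picture the wrapper, assuming each named LoRA mapping has its own scale and that the scaled mapping outputs are added to the wrapped module's output. `TinyLoRAWrapper`, its `'default'` key for a single unnamed mapping, and the `wrapped`/`lora`/`scale` attributes are hypothetical.

```python
import torch
from torch import nn


class TinyLoRAWrapper(nn.Module):
    """Hypothetical wrapper: y = module(x) + sum over n of scale[n] * lora[n](x)."""

    def __init__(self, module, in_feat, out_feat, rank, scale=1.0, names=None):
        super().__init__()
        self.wrapped = module
        # Normalize `names` to a list; a single unnamed mapping gets a
        # placeholder key (assumption made for this sketch).
        if names is None:
            names = ["default"]
        elif isinstance(names, str):
            names = [names]
        self.lora = nn.ModuleDict({
            n: nn.Sequential(nn.Linear(in_feat, rank, bias=False),
                             nn.Linear(rank, out_feat, bias=False))
            for n in names
        })
        self.scale = {n: scale for n in names}  # one scale per named mapping

    def forward(self, x):
        out = self.wrapped(x)
        for n, mapping in self.lora.items():
            out = out + self.scale[n] * mapping(x)
        return out


base = nn.Linear(16, 8)
wrapped = TinyLoRAWrapper(base, 16, 8, rank=2, names=["style", "subject"])
x = torch.randn(3, 16)
print(wrapped(x).shape)             # torch.Size([3, 8])
print(sorted(wrapped.lora.keys()))  # ['style', 'subject']
```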
- add_lora(name: str, rank: int, scale: float = 1, state_dict: Optional[dict] = None)
Add LoRA mapping.
- Parameters
name (str) – The name of the LoRA mapping to add.
rank (int) – The rank of the added LoRA mapping.
scale (float, optional) – The scale of the added LoRA mapping. Defaults to 1.
state_dict (dict, optional) – The state dict of the added LoRA mapping. Defaults to None.
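Adding a mapping can be sketched as building a new low-rank pair, optionally loading weights into it, and registering it under `name`. This assumes the mappings are kept in an `nn.ModuleDict` and that `state_dict` matches a freshly built mapping; `make_mapping` and `add_named_lora` are hypothetical helpers.

```python
from typing import Optional

import torch
from torch import nn


def make_mapping(in_feat: int, out_feat: int, rank: int) -> nn.Sequential:
    # Assumed layout of one LoRA mapping: in_feat -> rank -> out_feat.
    return nn.Sequential(nn.Linear(in_feat, rank, bias=False),
                         nn.Linear(rank, out_feat, bias=False))


def add_named_lora(mappings: nn.ModuleDict, scales: dict, name: str,
                   in_feat: int, out_feat: int, rank: int,
                   scale: float = 1.0,
                   state_dict: Optional[dict] = None) -> None:
    """Hypothetical helper: register a new LoRA mapping under `name`."""
    mapping = make_mapping(in_feat, out_feat, rank)
    if state_dict is not None:
        mapping.load_state_dict(state_dict)  # reuse existing LoRA weights
    mappings[name] = mapping
    scales[name] = scale


mappings, scales = nn.ModuleDict(), {}
add_named_lora(mappings, scales, "a", 16, 8, rank=4)
# Second mapping "b" starts from a copy of the weights of "a".
add_named_lora(mappings, scales, "b", 16, 8, rank=4, scale=0.5,
               state_dict=mappings["a"].state_dict())
```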
- _set_value(attr_name: str, value: Any, name: Optional[str] = None)
Set the value of an attribute.
- Parameters
attr_name (str) – The name of the attribute to set.
value (Any) – The value to be set.
name (str, optional) – The key within attr_name. If given, the value is assigned to attr_name[name]. Defaults to None.
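The per-name bookkeeping behind `_set_value` can be sketched in plain Python, assuming attributes such as the scale and the enable flag are dicts keyed by LoRA name. This page does not say what happens when `name` is None; the sketch assumes the value is then applied to every name and marks that in a comment. `PerNameSettings` is hypothetical.

```python
from typing import Any, Optional


class PerNameSettings:
    """Hypothetical holder of per-LoRA-name attributes (scale, enable)."""

    def __init__(self, names):
        self.scale = {n: 1.0 for n in names}
        self.enable = {n: True for n in names}

    def _set_value(self, attr_name: str, value: Any, name: Optional[str] = None):
        attr = getattr(self, attr_name)  # e.g. self.scale
        if name is not None:
            attr[name] = value  # set attr_name[name] only
        else:
            # Assumption: no name means "apply to all LoRA mappings".
            for key in attr:
                attr[key] = value


s = PerNameSettings(["style", "subject"])
s._set_value("scale", 0.5, name="style")
print(s.scale)   # {'style': 0.5, 'subject': 1.0}
s._set_value("enable", False)
print(s.enable)  # {'style': False, 'subject': False}
```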
- set_scale(scale: float, name: Optional[str] = None)
Set LoRA scale.
- Parameters
scale (float) – The scale to be set.
name (str, optional) – The name of the LoRA mapping whose scale is set. Defaults to None.
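Assuming the wrapper's output is `base(x) + scale * lora(x)`, the scale leaves the wrapped module's output untouched and changes the LoRA contribution linearly; a scale of 0 reproduces the base output. A quick numerical check with hypothetical `base` and `lora` layers:

```python
import torch
from torch import nn

torch.manual_seed(0)
base = nn.Linear(16, 8)
lora = nn.Sequential(nn.Linear(16, 4, bias=False), nn.Linear(4, 8, bias=False))
x = torch.randn(5, 16)


def output(scale: float) -> torch.Tensor:
    # Assumed forward rule: base output plus the scaled low-rank update.
    with torch.no_grad():
        return base(x) + scale * lora(x)


delta_1 = output(1.0) - output(0.0)
delta_2 = output(2.0) - output(0.0)
print(torch.allclose(delta_2, 2 * delta_1, atol=1e-5))  # True
print(torch.allclose(output(0.0), base(x)))             # True
```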
- set_enable(name: Optional[str] = None)
Enable LoRA for the current layer.
- Parameters
name (str, optional) – The name of the LoRA mapping to enable. Defaults to None.
- set_disable(name: Optional[str] = None)
Disable LoRA for the current layer.
- Parameters
name (str, optional) – The name of the LoRA mapping to disable. Defaults to None.
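Enabling and disabling can be pictured as a per-name flag deciding whether a mapping's output is added at all, so a fully disabled wrapper reproduces the wrapped module. The sketch assumes that behavior; `enabled` and `lora_forward` are hypothetical names, and the `set_enable`/`set_disable` comments are only analogies.

```python
import torch
from torch import nn

torch.manual_seed(0)
base = nn.Linear(16, 8)
mappings = nn.ModuleDict({
    "style": nn.Sequential(nn.Linear(16, 4, bias=False),
                           nn.Linear(4, 8, bias=False)),
})
enabled = {"style": True}  # hypothetical per-name enable flags


def lora_forward(x: torch.Tensor) -> torch.Tensor:
    out = base(x)
    for name, mapping in mappings.items():
        if enabled[name]:  # disabled mappings contribute nothing
            out = out + mapping(x)
    return out


x = torch.randn(2, 16)
enabled["style"] = False  # analogous to set_disable(name="style")
print(torch.equal(lora_forward(x), base(x)))  # True
enabled["style"] = True   # analogous to set_enable(name="style")
```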
- forward_lora_mapping(x: torch.Tensor) → torch.Tensor
Apply the LoRA mapping to the input.
- Parameters
x (Tensor) – The input tensor.
- Returns
The output tensor.
- Return type
Tensor
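If each mapping is a pair of bias-free linear layers (an assumption), the LoRA mapping output equals `x` multiplied by a single weight update ΔW = scale · W_up · W_down, whose rank is at most `rank`. A numerical check with hypothetical layers:

```python
import torch
from torch import nn

torch.manual_seed(0)
in_feat, out_feat, rank, scale = 32, 16, 4, 0.5
down = nn.Linear(in_feat, rank, bias=False)
up = nn.Linear(rank, out_feat, bias=False)
x = torch.randn(10, in_feat)

with torch.no_grad():
    mapping_out = scale * up(down(x))          # LoRA mapping applied to x
    delta_w = scale * up.weight @ down.weight  # (out_feat, in_feat) update
    print(torch.allclose(mapping_out, x @ delta_w.T, atol=1e-5))  # True
    print(torch.linalg.matrix_rank(delta_w).item())  # at most `rank`
```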
- forward(x: torch.Tensor, *args, **kwargs) → torch.Tensor
Run the wrapped module and add the output of the LoRA mapping.
- Parameters
x (Tensor) – The input tensor.
- Returns
The output tensor.
- Return type
Tensor
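`forward` accepts `*args` and `**kwargs` besides `x`. A plausible reading, sketched here as an assumption, is that the extra arguments are passed to the wrapped module while the LoRA mapping sees only `x`. `GatedLinear` and `wrapper_forward` are hypothetical.

```python
import torch
from torch import nn


class GatedLinear(nn.Module):
    """Hypothetical wrapped module whose forward takes an extra argument."""

    def __init__(self, in_feat: int, out_feat: int):
        super().__init__()
        self.linear = nn.Linear(in_feat, out_feat)

    def forward(self, x: torch.Tensor, gate: float = 1.0) -> torch.Tensor:
        return gate * self.linear(x)


def wrapper_forward(module, lora_mapping, x, *args, **kwargs):
    # Assumed rule: extra arguments go to the wrapped module only,
    # and the LoRA mapping output is added to its result.
    return module(x, *args, **kwargs) + lora_mapping(x)


torch.manual_seed(0)
module = GatedLinear(8, 4)
lora_mapping = nn.Sequential(nn.Linear(8, 2, bias=False), nn.Linear(2, 4, bias=False))
x = torch.randn(3, 8)
y = wrapper_forward(module, lora_mapping, x, gate=0.0)
print(torch.allclose(y, lora_mapping(x)))  # True: the gated base contributes zero
```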
- classmethod wrap_lora(module, rank=4, scale=1, names=None, state_dict=None)
Wrap a module with LoRA.
Use case:
>>> linear = nn.Linear(2, 4)
>>> lora_linear = LoRAWrapper.wrap_lora(linear, 4, 1)
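- Parameters
module (nn.Module) – The module to add LoRA to.
rank (int) – The rank for LoRA. Defaults to 4.
scale (float) – The scale of the LoRA feature. Defaults to 1.
names (Union[str, List[str]], optional) – The names of the LoRA mappings. Defaults to None.
state_dict (dict, optional) – The state dict of the LoRA mapping. Defaults to None.
- Return type
LoRAWrapper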
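For an `nn.Linear`, the input and output sizes the wrapper needs can be read from the layer itself, which would explain why the use case passes only the layer, the rank and the scale. A hypothetical sketch under that assumption (`LinearWithLoRA` and `wrap_linear_with_lora` are not the library's code, and the sketch handles only `nn.Linear`):

```python
import torch
from torch import nn


class LinearWithLoRA(nn.Module):
    """Hypothetical wrapper: a linear layer plus a scaled low-rank branch."""

    def __init__(self, linear: nn.Linear, rank: int, scale: float):
        super().__init__()
        self.wrapped = linear
        self.down = nn.Linear(linear.in_features, rank, bias=False)
        self.up = nn.Linear(rank, linear.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as an exact copy of `linear`
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wrapped(x) + self.scale * self.up(self.down(x))


def wrap_linear_with_lora(module: nn.Module, rank: int = 4, scale: float = 1.0):
    if not isinstance(module, nn.Linear):
        raise TypeError(f"this sketch only handles nn.Linear, got {type(module).__name__}")
    # in/out feature sizes are taken from the layer being wrapped
    return LinearWithLoRA(module, rank, scale)


linear = nn.Linear(2, 4)
lora_linear = wrap_linear_with_lora(linear, 4, 1)
x = torch.randn(3, 2)
print(torch.allclose(lora_linear(x), linear(x)))  # True while `up` is zero
```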
- mmagic.models.archs.lora.replace_module(parent_module: torch.nn.Module, child_name: str, new_module: torch.nn.Module)
Replace the child module named child_name in parent_module with new_module.
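Replacing a direct child of a module can be done by reassigning the attribute, since `nn.Module.__setattr__` re-registers a module assigned to an existing name. A sketch assuming `child_name` is a direct child name rather than a dotted path; `replace_child` is hypothetical.

```python
from torch import nn


def replace_child(parent_module: nn.Module, child_name: str, new_module: nn.Module) -> None:
    # setattr goes through nn.Module.__setattr__, which updates the
    # registered submodules under the same name.
    setattr(parent_module, child_name, new_module)


block = nn.Sequential()
block.add_module("to_q", nn.Linear(8, 8))
replace_child(block, "to_q", nn.Identity())
print(block.to_q)  # Identity()
```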
- mmagic.models.archs.lora.get_submodule(module: torch.nn.Module, key: str)
Get submodule by key.
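A dotted key such as `'0.1'` or `'attn.to_q'` can be resolved by following one attribute per segment. This assumes the dot-separated keys produced by `named_modules()`; `find_submodule` is hypothetical, and recent PyTorch versions also provide the built-in `nn.Module.get_submodule` for the same lookup.

```python
from functools import reduce

from torch import nn


def find_submodule(module: nn.Module, key: str) -> nn.Module:
    # Walk "a.b.c" as module.a.b.c; Sequential children are named "0", "1", ...
    # An empty key refers to the module itself.
    return reduce(getattr, key.split("."), module) if key else module


model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
print(find_submodule(model, "0.1"))  # ReLU()
```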
- mmagic.models.archs.lora.set_lora(module: torch.nn.Module, config: dict, verbose: bool = True) → torch.nn.Module
Set LoRA for module.
Use case:
>>> # 1. set all LoRA layers with the same parameters
>>> lora_config = dict(
>>>     rank=4,
>>>     scale=1,
>>>     target_modules=['to_q', 'to_k', 'to_v'])
>>> # 2. set LoRA layers with different parameters
>>> lora_config = dict(
>>>     rank=4,
>>>     scale=1,
>>>     target_modules=[
>>>         # `to_q` uses the default parameters
>>>         'to_q',
>>>         # `to_k` uses the parameters defined here
>>>         dict(target_module='to_k', rank=8, scale=1),
>>>         # `to_v` uses the defined `rank` and the default `scale`
>>>         dict(target_module='to_v', rank=16)
>>>     ])
- Parameters
module (nn.Module) – The module to set LoRA for.
config (dict) – The LoRA config, with rank, scale and target_modules as in the use case above.
verbose (bool) – Whether to print log. Defaults to True.
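The config in the use case can be read as: each entry of `target_modules` selects submodules by name, and a dict entry may override the top-level `rank` and `scale`. The sketch only resolves that plan and does not wrap anything; it assumes a submodule matches when the last component of its dotted name equals the target, which this page does not specify. `resolve_targets` is hypothetical.

```python
from torch import nn


def resolve_targets(module: nn.Module, config: dict) -> dict:
    """Hypothetical: map each matched submodule name to its (rank, scale)."""
    default_rank, default_scale = config["rank"], config["scale"]
    plan = {}
    for target in config["target_modules"]:
        # A bare string uses the defaults; a dict may override rank/scale.
        if isinstance(target, str):
            target = dict(target_module=target)
        rank = target.get("rank", default_rank)
        scale = target.get("scale", default_scale)
        for name, _ in module.named_modules():
            # Assumption: match on the last component of the dotted name.
            if name.split(".")[-1] == target["target_module"]:
                plan[name] = (rank, scale)
    return plan


attn = nn.Module()
attn.to_q, attn.to_k, attn.to_v = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)
model = nn.Module()
model.attn = attn

lora_config = dict(
    rank=4,
    scale=1,
    target_modules=[
        'to_q',
        dict(target_module='to_k', rank=8, scale=1),
        dict(target_module='to_v', rank=16),
    ])
print(resolve_targets(model, lora_config))
# {'attn.to_q': (4, 1), 'attn.to_k': (8, 1), 'attn.to_v': (16, 1)}
```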
- mmagic.models.archs.lora.set_only_lora_trainable(module: torch.nn.Module) → torch.nn.Module
Set only LoRA modules trainable.
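Making only LoRA parameters trainable comes down to setting `requires_grad` per parameter. This sketch assumes LoRA parameters can be recognized by a name fragment (the hypothetical `marker='lora'`); how the library actually identifies them is not stated on this page.

```python
from torch import nn


def freeze_all_but_lora(module: nn.Module, marker: str = "lora") -> nn.Module:
    """Hypothetical: train only parameters whose name contains `marker`."""
    for name, param in module.named_parameters():
        param.requires_grad_(marker in name)
    return module


model = nn.Module()
model.base = nn.Linear(8, 8)
model.lora_down = nn.Linear(8, 2, bias=False)
model.lora_up = nn.Linear(2, 8, bias=False)
freeze_all_but_lora(model)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['lora_down.weight', 'lora_up.weight']
```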