
mmagic.models.archs.lora

Module Contents

Classes

LoRALinear

Linear layer for LoRA.

LoRAWrapper

Wrapper for LoRA layer.

Functions

replace_module(parent_module, child_name, new_module)

Replace module in parent module.

get_submodule(module, key)

Get submodule by key.

set_lora(module, config, verbose) → torch.nn.Module

Set LoRA for module.

set_only_lora_trainable(module) → torch.nn.Module

Set only LoRA modules trainable.

set_lora_enable(module) → torch.nn.Module

Enable LoRA modules.

set_lora_disable(module) → torch.nn.Module

Disable LoRA modules.

class mmagic.models.archs.lora.LoRALinear(in_feat: int, out_feat: int, rank: int = 4)[source]

Bases: torch.nn.Module

Linear layer for LoRA.

Parameters
  • in_feat (int) – Number of input features.

  • out_feat (int) – Number of output features.

  • rank (int) – The rank of LoRA.

forward(x: torch.Tensor) → torch.Tensor[source]
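
A minimal usage sketch (shapes follow the documented constructor; the initialization of the low-rank factors is internal to mmagic):

>>> import torch
>>> from mmagic.models.archs.lora import LoRALinear
>>> lora = LoRALinear(in_feat=16, out_feat=32, rank=4)
>>> x = torch.randn(2, 16)
>>> out = lora(x)  # expected shape: (2, 32)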
class mmagic.models.archs.lora.LoRAWrapper(module: torch.nn.Module, in_feat: int, out_feat: int, rank: int, scale: float = 1, names: Optional[Union[str, List[str]]] = None)[source]

Bases: torch.nn.Module

Wrapper for LoRA layer.

Parameters
  • module (nn.Module) – The module to be wrapped.

  • in_feat (int) – Number of input features.

  • out_feat (int) – Number of output features.

  • rank (int) – The rank of LoRA.

  • scale (float) – The scale of LoRA feature.

  • names (Union[str, List[str]], optional) – The names of LoRA layers. If you want to add multiple LoRA mappings to one module, a name must be defined for each LoRA mapping.
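
A hedged construction sketch, assuming the wrapped module is a plain nn.Linear whose dimensions match in_feat/out_feat:

>>> import torch
>>> from torch import nn
>>> from mmagic.models.archs.lora import LoRAWrapper
>>> linear = nn.Linear(16, 32)
>>> wrapper = LoRAWrapper(linear, in_feat=16, out_feat=32, rank=4, scale=1)
>>> out = wrapper(torch.randn(2, 16))  # base output plus scaled LoRA mapping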

add_lora(name: str, rank: int, scale: float = 1, state_dict: Optional[dict] = None)[source]

Add LoRA mapping.

Parameters
  • name (str) – The name of added LoRA.

  • rank (int) – The rank of added LoRA.

  • scale (float, optional) – The scale of added LoRA. Defaults to 1.

  • state_dict (dict, optional) – The state dict of added LoRA. Defaults to None.
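
An illustrative sketch of stacking two named mappings on one wrapper (the names 'style' and 'subject' are hypothetical):

>>> from torch import nn
>>> from mmagic.models.archs.lora import LoRAWrapper
>>> wrapper = LoRAWrapper(nn.Linear(16, 32), 16, 32, rank=4, names='style')
>>> wrapper.add_lora('subject', rank=8, scale=0.5)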

_set_value(attr_name: str, value: Any, name: Optional[str] = None)[source]

Set value of attribute.

Parameters
  • attr_name (str) – The name of attribute to be set value.

  • value (Any) – The value to be set.

  • name (str, optional) – The name of the field in attr_name. If passed, the value will be set to attr_name[name]. Defaults to None.

set_scale(scale: float, name: Optional[str] = None)[source]

Set LoRA scale.

Parameters
  • scale (float) – The scale to be set.

  • name (str, optional) – The name of LoRA to be set. Defaults to None.
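
For example, re-weighting mappings at inference time (a sketch; the wrapper is assumed built as above, and the name 'style' refers to a mapping defined via names or add_lora):

>>> wrapper.set_scale(0.8)                # unnamed, single-mapping wrapper
>>> wrapper.set_scale(0.2, name='style')  # a specific named mapping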

set_enable(name: Optional[str] = None)[source]

Enable LoRA for the current layer.

Parameters

name (str, optional) – The name of LoRA to be set. Defaults to None.

set_disable(name: Optional[str] = None)[source]

Disable LoRA for the current layer.

Parameters

name (str, optional) – The name of LoRA to be set. Defaults to None.
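
Toggling a single wrapped layer (a sketch, continuing the wrapper constructed above):

>>> x = torch.randn(2, 16)
>>> wrapper.set_disable()  # forward returns the base module output only
>>> out_base = wrapper(x)
>>> wrapper.set_enable()   # forward adds the LoRA mapping again
>>> out_lora = wrapper(x)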

forward_lora_mapping(x: torch.Tensor) → torch.Tensor[source]

Forward LoRA mapping.

Parameters

x (Tensor) – The input tensor.

Returns

The output tensor.

Return type

Tensor

forward(x: torch.Tensor, *args, **kwargs) → torch.Tensor[source]

Forward and add LoRA mapping.

Parameters

x (Tensor) – The input tensor.

Returns

The output tensor.

Return type

Tensor

classmethod wrap_lora(module, rank=4, scale=1, names=None, state_dict=None)[source]

Wrap LoRA.

Use case:

>>> from torch import nn
>>> linear = nn.Linear(2, 4)
>>> lora_linear = LoRAWrapper.wrap_lora(linear, 4, 1)

Parameters
  • module (nn.Module) – The module to add LoRA.

  • rank (int) – The rank for LoRA.

  • scale (float) – The scale of the LoRA feature. Defaults to 1.

  • names (Union[str, List[str]], optional) – The name(s) of the LoRA mapping(s). Defaults to None.

  • state_dict (dict, optional) – The state dict to initialize the LoRA mapping with. Defaults to None.

Return type

LoRAWrapper

mmagic.models.archs.lora.replace_module(parent_module: torch.nn.Module, child_name: str, new_module: torch.nn.Module)[source]

Replace module in parent module.

mmagic.models.archs.lora.get_submodule(module: torch.nn.Module, key: str)[source]

Get submodule by key.
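
A hedged sketch combining the two helpers to swap one projection layer for its LoRA-wrapped version (the Attn module and the 'to_q' name are illustrative):

>>> from torch import nn
>>> from mmagic.models.archs.lora import (LoRAWrapper, get_submodule,
...                                       replace_module)
>>> class Attn(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.to_q = nn.Linear(16, 16)
...
>>> attn = Attn()
>>> target = get_submodule(attn, 'to_q')
>>> replace_module(attn, 'to_q', LoRAWrapper.wrap_lora(target, rank=4, scale=1))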

mmagic.models.archs.lora.set_lora(module: torch.nn.Module, config: dict, verbose: bool = True) → torch.nn.Module[source]

Set LoRA for module.

Use case:

>>> # 1. set all LoRA mappings with the same parameters
>>> lora_config = dict(
>>>     rank=4,
>>>     scale=1,
>>>     target_modules=['to_q', 'to_k', 'to_v'])

>>> # 2. set LoRA mappings with different parameters
>>> lora_config = dict(
>>>     rank=4,
>>>     scale=1,
>>>     target_modules=[
>>>         # set `to_q` the default parameters
>>>         'to_q',
>>>         # set `to_k` the defined parameters
>>>         dict(target_module='to_k', rank=8, scale=1),
>>>         # set `to_v` the defined `rank` and default `scale`
>>>         dict(target_module='to_v', rank=16)
>>>     ])

Parameters
  • module (nn.Module) – The module to set LoRA.

  • config (dict) – The config dict.

  • verbose (bool) – Whether to print log. Defaults to True.
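
Applying the first config form to a model (a sketch; `model` stands for any nn.Module containing submodules whose names match target_modules):

>>> from mmagic.models.archs.lora import set_lora
>>> lora_config = dict(rank=4, scale=1, target_modules=['to_q', 'to_k', 'to_v'])
>>> model = set_lora(model, lora_config, verbose=False)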

mmagic.models.archs.lora.set_only_lora_trainable(module: torch.nn.Module) → torch.nn.Module[source]

Set only LoRA modules trainable.
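
A typical fine-tuning setup, freezing everything except the LoRA parameters (the optimizer choice is illustrative):

>>> import torch
>>> model = set_only_lora_trainable(model)
>>> params = [p for p in model.parameters() if p.requires_grad]
>>> optimizer = torch.optim.AdamW(params, lr=1e-4)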

mmagic.models.archs.lora.set_lora_enable(module: torch.nn.Module) → torch.nn.Module[source]

Enable LoRA modules.

mmagic.models.archs.lora.set_lora_disable(module: torch.nn.Module) → torch.nn.Module[source]

Disable LoRA modules.
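
Toggling LoRA model-wide, e.g. to compare outputs with and without the adapted weights (a sketch; `model` and `inputs` are placeholders for a LoRA-equipped module and a matching input):

>>> model = set_lora_disable(model)  # base model behaviour
>>> base_out = model(inputs)
>>> model = set_lora_enable(model)   # LoRA mappings active again
>>> lora_out = model(inputs)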
