mmagic.models.editors.dreambooth
Package Contents
Classes
DreamBooth – Implementation of DreamBooth with Stable Diffusion.
- class mmagic.models.editors.dreambooth.DreamBooth(vae: ModelType, text_encoder: ModelType, tokenizer: str, unet: ModelType, scheduler: ModelType, test_scheduler: Optional[ModelType] = None, lora_config: Optional[dict] = None, val_prompts: Union[str, List[str]] = None, class_prior_prompt: Optional[str] = None, num_class_images: Optional[int] = 3, prior_loss_weight: float = 0, finetune_text_encoder: bool = False, dtype: str = 'fp16', enable_xformers: bool = True, noise_offset_weight: float = 0, tomesd_cfg: Optional[dict] = None, data_preprocessor: Optional[ModelType] = dict(type='DataPreprocessor'), init_cfg: Optional[dict] = None)
Bases: mmagic.models.editors.stable_diffusion.stable_diffusion.StableDiffusion
Implementation of DreamBooth with Stable Diffusion (https://arxiv.org/abs/2208.12242).
- Parameters
vae (Union[dict, nn.Module]) – The config or module for VAE model.
text_encoder (Union[dict, nn.Module]) – The config or module for text encoder.
tokenizer (str) – The name for CLIP tokenizer.
unet (Union[dict, nn.Module]) – The config or module for Unet model.
scheduler (Union[dict, nn.Module]) – The config or module for the diffusion scheduler.
test_scheduler (Union[dict, nn.Module], optional) – The config or module for the diffusion scheduler used in the test stage (self.infer). If not passed, the same scheduler as scheduler is used. Defaults to None.
lora_config (dict, optional) – The config for LoRA finetuning. Defaults to None.
val_prompts (Union[str, List[str]], optional) – The prompts for validation. Defaults to None.
class_prior_prompt (str, optional) – The prompt for the class prior loss. Defaults to None.
num_class_images (int, optional) – The number of images for class prior. Defaults to 3.
prior_loss_weight (float, optional) – The weight for class prior loss. Defaults to 0.
finetune_text_encoder (bool, optional) – Whether to fine-tune text encoder. Defaults to False.
dtype (str, optional) – The dtype for the model. Defaults to ‘fp16’.
enable_xformers (bool, optional) – Whether to use xformers. Defaults to True.
noise_offset_weight (float, optional) – The weight of the noise offset introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise. Defaults to 0.
tomesd_cfg (dict, optional) – The config for ToMe for Stable Diffusion (tomesd). Please refer to https://github.com/dbolya/tomesd and https://github.com/open-mmlab/mmagic/blob/main/mmagic/models/utils/tome_utils.py for details. Defaults to None.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Defaults to dict(type='DataPreprocessor').
init_cfg (dict, optional) – The weight initialization config for BaseModule. Defaults to None.
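MMagic configs are plain Python files, so a DreamBooth model entry can be sketched as a dict whose keys follow the signature above. Only those keyword names come from this page; the component `type` values (`AutoencoderKL`, `UNet2DConditionModel`, `ClipWrapper`, `DDPMScheduler`, `DDIMScheduler`), the `from_pretrained`/`subfolder` keys, the model identifier and the prompts are assumptions for illustration and should be checked against a shipped config.

```python
# Sketch of an MMagic-style model config for DreamBooth.
# Keyword names follow the DreamBooth signature; component types,
# from_pretrained/subfolder keys and the model id are ASSUMPTIONS.
sd_model = 'runwayml/stable-diffusion-v1-5'  # assumed pretrained model id

model = dict(
    type='DreamBooth',
    vae=dict(type='AutoencoderKL', from_pretrained=sd_model, subfolder='vae'),
    unet=dict(
        type='UNet2DConditionModel', from_pretrained=sd_model,
        subfolder='unet'),
    text_encoder=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=sd_model,
        subfolder='text_encoder'),
    tokenizer=sd_model,  # name of the CLIP tokenizer
    # Training scheduler; test_scheduler replaces it in self.infer.
    scheduler=dict(
        type='DDPMScheduler', from_pretrained=sd_model,
        subfolder='scheduler'),
    test_scheduler=dict(
        type='DDIMScheduler', from_pretrained=sd_model,
        subfolder='scheduler'),
    val_prompts=['a sks dog in basket'],  # hypothetical validation prompt
    # Class prior settings: prompt, number of class images, loss weight.
    class_prior_prompt='a dog',
    num_class_images=3,
    prior_loss_weight=1.0,
    finetune_text_encoder=False,
    dtype='fp16',
    data_preprocessor=dict(type='DataPreprocessor'))
```

A non-zero prior_loss_weight is set here on purpose: with the default of 0, the class prior term is multiplied by zero and contributes nothing to the loss.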
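The noise_offset_weight option refers to the offset-noise idea from the linked Crosslabs post: in addition to the usual per-pixel Gaussian noise, one extra Gaussian value is drawn per sample and channel, scaled by the weight, and added to every spatial position of that channel. The pure-Python sketch below uses nested lists in place of a B×C×H×W tensor; the helper `offset_noise` is hypothetical and is not MMagic's implementation.

```python
import random
from typing import List, Optional


def offset_noise(batch: int, channels: int, height: int, width: int,
                 weight: float,
                 rng: Optional[random.Random] = None) -> List:
    """Return B x C x H x W nested lists of Gaussian noise plus an offset.

    Each (sample, channel) plane gets standard per-pixel noise plus one
    shared Gaussian value scaled by ``weight``.
    """
    rng = rng or random.Random()
    noise = []
    for _ in range(batch):
        sample = []
        for _ in range(channels):
            # One offset per plane, shared by all of its pixels.
            offset = weight * rng.gauss(0.0, 1.0)
            plane = [[rng.gauss(0.0, 1.0) + offset for _ in range(width)]
                     for _ in range(height)]
            sample.append(plane)
        noise.append(sample)
    return noise
```

For example, `offset_noise(2, 4, 8, 8, weight=0.1, rng=random.Random(0))` yields noise for two 4-channel 8×8 latents. With `weight=0` it reduces to ordinary per-pixel noise, which matches the documented default.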
- generate_class_prior_images(num_batches=None)
Generate images for class prior loss.
- Parameters
num_batches (int, optional) – Number of batches in which to generate the images. If not passed, all images are generated in a single forward pass. Defaults to None.
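One way to read num_batches is as a split of the num_class_images generation prompts into roughly equal chunks, one forward pass per chunk. The helper `split_prompts` below is hypothetical and only illustrates that reading; MMagic's actual batching rule may differ.

```python
from typing import List, Optional


def split_prompts(prompt: str, num_images: int,
                  num_batches: Optional[int] = None) -> List[List[str]]:
    """Split ``num_images`` copies of ``prompt`` into generation batches.

    With ``num_batches=None`` everything goes into one batch, i.e. a
    single forward pass. Otherwise batch sizes differ by at most one.
    """
    prompts = [prompt] * num_images
    if num_batches is None:
        return [prompts]
    if num_batches < 1:
        raise ValueError('num_batches must be a positive integer')
    base, extra = divmod(num_images, num_batches)
    batches, start = [], 0
    for i in range(num_batches):
        # The first `extra` batches take one additional image each.
        size = base + (1 if i < extra else 0)
        if size:  # skip empty batches when num_batches > num_images
            batches.append(prompts[start:start + size])
        start += size
    return batches
```

With the default num_class_images=3, `split_prompts('a dog', 3, 2)` gives batches of 2 and 1 prompts.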
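The class prior images serve DreamBooth's prior-preservation term: the subject images give the usual reconstruction loss, and a second reconstruction loss on the generated class images, scaled by prior_loss_weight, is added to limit drift away from the class. In diffusion training the prediction and target are typically the UNet's noise estimate and the noise that was actually added. The sketch below shows only the weighted combination, using mean-squared error over flat lists; the function names are hypothetical and the reduction MMagic uses may differ.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equal-length, non-empty sequences."""
    if len(pred) != len(target) or not pred:
        raise ValueError('pred and target must be non-empty, same length')
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(inst_pred: Sequence[float], inst_target: Sequence[float],
                    prior_pred: Sequence[float],
                    prior_target: Sequence[float],
                    prior_loss_weight: float) -> float:
    """Instance loss plus the weighted class prior loss."""
    instance_loss = mse(inst_pred, inst_target)
    prior_loss = mse(prior_pred, prior_target)
    return instance_loss + prior_loss_weight * prior_loss
```

For instance, an instance MSE of 0.5 and a prior MSE of 1.0 with prior_loss_weight=0.5 give a total of 1.0. With the default weight of 0, the total is just the instance loss.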
- prepare_model()
Prepare model for training.
Move model to target dtype and disable gradient for some models.
- val_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated image of the given data. Calls self.data_preprocessor and self.infer in order, and returns the generated results, which are passed to the evaluator or visualizer.
- Parameters
data (dict or tuple or list) – Data sampled from dataset.
- Returns
Generated image or image dict.
- Return type
SampleList
- test_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated image of the given data. Calls self.data_preprocessor and self.infer in order, and returns the generated results, which are passed to the evaluator or visualizer.
- Parameters
data (dict or tuple or list) – Data sampled from dataset.
- Returns
Generated image or image dict.
- Return type
SampleList
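val_step and test_step share the same two-stage flow. The sketch below captures only the call order, with the preprocessor and inference function passed in as plain callables; `val_or_test_step` is a hypothetical name, and the real methods operate on MMagic data samples rather than arbitrary objects.

```python
from typing import Any, Callable, List


def val_or_test_step(data: dict,
                     data_preprocessor: Callable[[dict], Any],
                     infer: Callable[[Any], List[Any]]) -> List[Any]:
    """Run the documented stages in order: preprocess, then infer."""
    # Stage 1: normalize and collate the raw batch.
    processed = data_preprocessor(data)
    # Stage 2: generate samples; the result goes to the evaluator/visualizer.
    return infer(processed)
```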
- train_step(data, optim_wrapper)
Implements the default model training process, including preprocessing, model forward propagation, loss calculation, optimization, and back-propagation.
During non-distributed training, if subclasses do not override train_step(), EpochBasedTrainLoop or IterBasedTrainLoop will call this method to update model parameters. The default parameter update process is as follows:
1. Calls self.data_preprocessor(data, training=True) to collect batch_inputs and the corresponding data_samples (labels).
2. Calls self(batch_inputs, data_samples, mode='loss') to get the raw loss.
3. Calls self.parse_losses to get parsed_losses, the tensor used for the backward pass, and a dict of loss tensors used to log messages.
4. Calls optim_wrapper.update_params(parsed_losses) to update the model.
- Parameters
data (dict or tuple or list) – Data sampled from dataset.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
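The four default steps can be sketched framework-free with duck-typed objects: any `model` exposing `data_preprocessor`, `__call__` and `parse_losses`, and any `optim_wrapper` exposing `update_params`. This mirrors the generic process described above, not DreamBooth's own override; `default_train_step` is a hypothetical name, and the assumption that the preprocessor returns a dict with `inputs` and `data_samples` keys follows the common MMEngine convention.

```python
from typing import Any, Dict


def default_train_step(model: Any, data: dict,
                       optim_wrapper: Any) -> Dict[str, Any]:
    """Mirror the documented default training step with duck-typed objects."""
    # 1. Preprocess the batch in training mode; assumed to return a dict
    #    with 'inputs' and 'data_samples' keys.
    processed = model.data_preprocessor(data, training=True)
    # 2. Forward in 'loss' mode to obtain the raw loss dict.
    raw_losses = model(**processed, mode='loss')
    # 3. Reduce to one loss tensor for backward plus values for logging.
    parsed_losses, log_vars = model.parse_losses(raw_losses)
    # 4. Backward pass and optimizer step are delegated to the wrapper.
    optim_wrapper.update_params(parsed_losses)
    return log_vars
```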