mmagic.models.editors.dreambooth

Package Contents

Classes

DreamBooth

Implementation of DreamBooth with Stable Diffusion.

class mmagic.models.editors.dreambooth.DreamBooth(vae: ModelType, text_encoder: ModelType, tokenizer: str, unet: ModelType, scheduler: ModelType, test_scheduler: Optional[ModelType] = None, lora_config: Optional[dict] = None, val_prompts: Union[str, List[str]] = None, class_prior_prompt: Optional[str] = None, num_class_images: Optional[int] = 3, prior_loss_weight: float = 0, finetune_text_encoder: bool = False, dtype: str = 'fp16', enable_xformers: bool = True, noise_offset_weight: float = 0, tomesd_cfg: Optional[dict] = None, data_preprocessor: Optional[ModelType] = dict(type='DataPreprocessor'), init_cfg: Optional[dict] = None)[source]

Bases: mmagic.models.editors.stable_diffusion.stable_diffusion.StableDiffusion

Implementation of DreamBooth with Stable Diffusion (https://arxiv.org/abs/2208.12242).

Parameters
  • vae (Union[dict, nn.Module]) – The config or module for VAE model.

  • text_encoder (Union[dict, nn.Module]) – The config or module for text encoder.

  • tokenizer (str) – The name for CLIP tokenizer.

  • unet (Union[dict, nn.Module]) – The config or module for Unet model.

  • scheduler (Union[dict, nn.Module]) – The config or module for the diffusion scheduler.

  • test_scheduler (Union[dict, nn.Module], optional) – The config or module for the diffusion scheduler used in the test stage (self.infer). If not passed, scheduler is used in the test stage as well. Defaults to None.

  • lora_config (dict, optional) – The config for LoRA finetuning. Defaults to None.

  • val_prompts (Union[str, List[str]], optional) – The prompts for validation. Defaults to None.

  • class_prior_prompt (str, optional) – The prompt for the class prior loss. Defaults to None.

  • num_class_images (int, optional) – The number of images for class prior. Defaults to 3.

  • prior_loss_weight (float, optional) – The weight for class prior loss. Defaults to 0.

  • finetune_text_encoder (bool, optional) – Whether to fine-tune text encoder. Defaults to False.

  • dtype (str, optional) – The dtype for the model. Defaults to ‘fp16’.

  • enable_xformers (bool, optional) – Whether to use xformers. Defaults to True.

  • noise_offset_weight (float, optional) – The weight of the noise offset introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise. Defaults to 0.

  • tomesd_cfg (dict, optional) – The config for ToMeSD. Please refer to https://github.com/dbolya/tomesd and https://github.com/open-mmlab/mmagic/blob/main/mmagic/models/utils/tome_utils.py for details. Defaults to None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Defaults to dict(type='DataPreprocessor').

  • init_cfg (dict, optional) – The weight initialization config for BaseModule. Defaults to None.
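
For orientation, here is a minimal config sketch in mmagic's registry style. The pretrained model id and the sub-config contents (the type, from_pretrained, and subfolder keys, and the scheduler classes) are illustrative assumptions based on common mmagic/diffusers conventions, not values taken from this page:

    # A hypothetical DreamBooth model config; sub-config contents are assumptions.
    pretrained_model = 'runwayml/stable-diffusion-v1-5'  # assumed model id

    model = dict(
        type='DreamBooth',
        vae=dict(type='AutoencoderKL', from_pretrained=pretrained_model, subfolder='vae'),
        unet=dict(type='UNet2DConditionModel', from_pretrained=pretrained_model, subfolder='unet'),
        text_encoder=dict(type='ClipWrapper', clip_type='huggingface',
                          pretrained_model_name_or_path=pretrained_model,
                          subfolder='text_encoder'),
        tokenizer=pretrained_model,  # the tokenizer is passed as a name string
        scheduler=dict(type='DDPMScheduler', from_pretrained=pretrained_model, subfolder='scheduler'),
        test_scheduler=dict(type='DDIMScheduler', from_pretrained=pretrained_model, subfolder='scheduler'),
        val_prompts=['a photo of sks dog'],   # prompt(s) rendered during validation
        class_prior_prompt='a photo of dog',  # prior-preservation prompt
        prior_loss_weight=1.0,                # > 0 enables the class prior loss
        num_class_images=3,
        dtype='fp16',
        data_preprocessor=dict(type='DataPreprocessor'),
    )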

generate_class_prior_images(num_batches=None)[source]

Generate images for class prior loss.

Parameters

num_batches (int) – The number of batches used to generate the images. If not passed, all images will be generated in one forward pass. Defaults to None.
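
A hypothetical call, assuming model is a built DreamBooth instance with class_prior_prompt set:

    # Generate the class prior images across two forward passes to bound
    # peak memory; omitting num_batches generates them in one forward.
    model.generate_class_prior_images(num_batches=2)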

prepare_model()[source]

Prepare model for training.

Move the model to the target dtype and disable gradients for some submodules.
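
The sketch below shows, under assumptions, what this typically means for a DreamBooth setup; this page does not enumerate which submodules prepare_model() actually touches, and model is an assumed built instance:

    import torch

    # Assumed freezing rules: the VAE is frozen and cast to the target
    # dtype; the text encoder is frozen unless finetune_text_encoder=True;
    # the UNet stays trainable.
    dtype = torch.float16  # corresponds to dtype='fp16'
    model.vae.requires_grad_(False)
    model.vae.to(dtype)
    if not model.finetune_text_encoder:
        model.text_encoder.requires_grad_(False)
        model.text_encoder.to(dtype)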

set_lora()[source]

Set LoRA for the model.
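
This page does not document the schema that set_lora() expects from lora_config, so the keys below are hypothetical, based on common LoRA conventions:

    # A hypothetical lora_config passed to the DreamBooth constructor;
    # both keys are assumptions, not documented values.
    lora_config = dict(
        rank=8,                                   # low-rank dimension of the update matrices
        target_modules=['to_q', 'to_k', 'to_v'],  # attention projections to adapt
    )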

val_step(data: dict) → mmagic.utils.typing.SampleList[source]

Gets the generated images for the given data. Calls self.data_preprocessor and self.infer in order, and returns the generated results, which will be passed to the evaluator or visualizer.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

Generated image or image dict.

Return type

SampleList
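
A hypothetical validation loop; model, val_dataloader, and evaluator are assumed to be already-built mmengine components:

    # Run validation batch by batch and feed the results to the metrics.
    for data in val_dataloader:
        outputs = model.val_step(data)    # SampleList of generated images
        evaluator.process(outputs, data)  # hand results to the evaluator
    metrics = evaluator.evaluate(len(val_dataloader.dataset))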

test_step(data: dict) → mmagic.utils.typing.SampleList[source]

Gets the generated images for the given data. Calls self.data_preprocessor and self.infer in order, and returns the generated results, which will be passed to the evaluator or visualizer.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

Generated image or image dict.

Return type

SampleList

train_step(data, optim_wrapper)[source]

Implements the default model training process, including data preprocessing, model forward propagation, loss calculation, back-propagation, and parameter update.

During non-distributed training, if subclasses do not override train_step(), EpochBasedTrainLoop or IterBasedTrainLoop will call this method to update model parameters. The default parameter update process is as follows (see the sketch after this entry):

  1. Calls self.data_preprocessor(data, training=True) to collect batch_inputs and the corresponding data_samples (labels).

  2. Calls self(batch_inputs, data_samples, mode='loss') to get the raw loss dict.

  3. Calls self.parse_losses to get a parsed_losses tensor used for the backward pass and a dict of loss tensors used for logging.

  4. Calls optim_wrapper.update_params(loss) to update the model.

Parameters
  • data (dict or tuple or list) – Data sampled from dataset.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]
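
The four steps above correspond roughly to the mmengine BaseModel convention sketched below; model, data, and optim_wrapper are assumed to already exist, and this is an approximation, not the verbatim implementation:

    # Default parameter update flow, step by step.
    with optim_wrapper.optim_context(model):
        data = model.data_preprocessor(data, training=True)  # 1. preprocess
        losses = model(**data, mode='loss')                  # 2. raw loss dict
    parsed_losses, log_vars = model.parse_losses(losses)     # 3. parse losses
    optim_wrapper.update_params(parsed_losses)               # 4. optimizer step
    # log_vars is the Dict[str, torch.Tensor] returned for logging.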

abstract forward(inputs: torch.Tensor, data_samples: Optional[list] = None, mode: str = 'tensor') → Union[Dict[str, torch.Tensor], list][source]

forward is not implemented yet.
