Shortcuts

Source code for mmagic.datasets.textual_inversion_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from random import choice
from typing import Callable, List, Union

from mmengine import FileClient
from mmengine.dataset import BaseDataset

from mmagic.registry import DATASETS


@DATASETS.register_module()
[docs]class TextualInversionDataset(BaseDataset): """Dataset for Textual Inversion and ViCo. Args: data_root (str): Path to the data root. concept_dir (str): Path to the concept images. placeholder (str): A string to denote the concept. template (list[str]): A list of strings like 'A photo of {}'. with_image_reference (bool): Is used for vico training. pipeline (list[dict | callable]): A sequence of data transforms. """ def __init__( self, data_root: str, concept_dir: str, placeholder: str, template: str, # used for vico training with_image_reference: bool = False, pipeline: List[Union[dict, Callable]] = []): data_prefix = dict(img_path=concept_dir) self.placeholder = placeholder if osp.exists(osp.join(data_root, concept_dir)): self.num_images = len(os.listdir(osp.join(data_root, concept_dir))) if osp.exists(template): with open(template, 'r') as file: self.template = file.readlines() self.with_image_reference = with_image_reference super().__init__( data_root=data_root, data_prefix=data_prefix, pipeline=pipeline)
[docs] def load_data_list(self) -> list: """Load data list from concept_dir and class_dir.""" data_list = [] img_dir = self.data_prefix['img_path'] file_client = FileClient.infer_client(uri=img_dir) img_dir = osp.abspath(img_dir) for data_name in file_client.list_dir_or_file(img_dir, list_dir=False): data_info = dict( img_path=file_client.join_path(img_dir, data_name)) data_list.append(data_info) return data_list
[docs] def prepare_data(self, idx): """Get data processed by ``self.pipeline``. Args: idx (int): The index of ``data_info``. Returns: Any: Depends on ``self.pipeline``. """ data_info = self.get_data_info(idx) if self.with_image_reference: numbers = list(range(self.num_images)) if len(numbers) > 1: numbers.remove(idx % self.num_images) img_dir = self.data_prefix['img_path'] file_client = FileClient.infer_client(uri=img_dir) img_dir = osp.abspath(img_dir) data_names = list( file_client.list_dir_or_file(img_dir, list_dir=False)) image_ref_path = file_client.join_path(img_dir, data_names[choice(numbers)]) data_info['img_ref_path'] = image_ref_path # load random template selected_template = choice(self.template) prompt = selected_template.format(self.placeholder) data_info['prompt'] = prompt return self.pipeline(data_info)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.