Spaces:
Runtime error
Runtime error
import json | |
import os | |
import os.path | |
from abc import ABCMeta | |
from collections import OrderedDict | |
from typing import Any, List, Optional, Union | |
import mmcv | |
import copy | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
from mmcv.runner import get_dist_info | |
from .base_dataset import BaseMotionDataset | |
from .builder import DATASETS | |
class TextMotionDataset(BaseMotionDataset):
    """TextMotion dataset.

    Loads per-sample text annotations (and optionally tokenized text and
    precomputed CLIP text features) alongside the motion data handled by
    the :class:`BaseMotionDataset` parent.

    Args:
        data_prefix (str): Root path containing the ``datasets`` directory.
        pipeline (list): Processing pipeline applied to each sample.
        dataset_name (str, optional): Name of the dataset subdirectory.
        fixed_length (int, optional): Fixed dataset length, if any.
        ann_file (str, optional): Annotation file name.
        motion_dir (str, optional): Directory containing motion files.
        text_dir (str, optional): Directory containing the text files.
        token_dir (str, optional): Directory containing tokenized text files.
        clip_feat_dir (str, optional): Directory containing precomputed CLIP
            text features stored as ``.npy`` files.
        eval_cfg (dict, optional): Evaluation configuration.
        fine_mode (bool): If True, each text line is JSON-decoded
            (fine-grained annotation format). Defaults to False.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 dataset_name: Optional[str] = None,
                 fixed_length: Optional[int] = None,
                 ann_file: Optional[str] = None,
                 motion_dir: Optional[str] = None,
                 text_dir: Optional[str] = None,
                 token_dir: Optional[str] = None,
                 clip_feat_dir: Optional[str] = None,
                 eval_cfg: Optional[dict] = None,
                 fine_mode: bool = False,
                 test_mode: bool = False):
        # All annotation directories live under <data_prefix>/datasets/<name>.
        data_root = os.path.join(data_prefix, 'datasets', dataset_name)
        self.text_dir = os.path.join(data_root, text_dir)
        # token_dir and clip_feat_dir are optional auxiliary annotations.
        self.token_dir = (os.path.join(data_root, token_dir)
                          if token_dir is not None else None)
        self.clip_feat_dir = (os.path.join(data_root, clip_feat_dir)
                              if clip_feat_dir is not None else None)
        self.fine_mode = fine_mode
        super(TextMotionDataset, self).__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            dataset_name=dataset_name,
            fixed_length=fixed_length,
            ann_file=ann_file,
            motion_dir=motion_dir,
            eval_cfg=eval_cfg,
            test_mode=test_mode)

    def load_anno(self, name):
        """Load the annotation for sample ``name``.

        Extends the base annotation with the text lines and, when the
        corresponding directories are configured, the token lines and the
        CLIP feature tensor.
        """
        results = super().load_anno(name)
        text_path = os.path.join(self.text_dir, name + '.txt')
        # Use context managers so file handles are closed promptly
        # (the original left them open).
        with open(text_path, 'r') as f:
            results['text'] = [line.strip() for line in f]
        if self.token_dir is not None:
            token_path = os.path.join(self.token_dir, name + '.txt')
            with open(token_path, 'r') as f:
                results['token'] = [line.strip() for line in f]
        if self.clip_feat_dir is not None:
            clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy')
            results['clip_feat'] = torch.from_numpy(np.load(clip_feat_path))
        return results

    def prepare_data(self, idx: int):
        """Prepare raw data for the ``idx``-th sample.

        Randomly selects one of the candidate text descriptions for the
        sample (and the matching token / CLIP-feature entry, when present),
        then runs the result through the pipeline.
        """
        results = copy.deepcopy(self.data_infos[idx])
        text_list = results['text']
        # Keep the caption index in its own variable: the original code
        # overwrote `idx`, so `sample_idx` recorded the random caption
        # index instead of the dataset sample index — a bug.
        text_idx = np.random.randint(0, len(text_list))
        if self.fine_mode:
            # Fine-grained annotations are stored as JSON per line.
            results['text'] = json.loads(text_list[text_idx])
        else:
            results['text'] = text_list[text_idx]
        if 'clip_feat' in results:
            results['clip_feat'] = results['clip_feat'][text_idx]
        if 'token' in results:
            results['token'] = results['token'][text_idx]
        results['dataset_name'] = self.dataset_name
        results['sample_idx'] = idx
        return self.pipeline(results)