initial
- bvh2ts.py +69 -0
- checkpoint/UniMTS.pth +3 -0
- contrastive.py +62 -0
- data.py +321 -0
- evaluate.py +99 -0
- evaluate_custom.py +101 -0
- finetune.py +169 -0
- finetune_custom.py +172 -0
- model.py +350 -0
- pos2bvh.py +41 -0
- pretrain.py +111 -0
- run_evaluation.sh +4 -0
- run_evaluation_custom.sh +8 -0
- run_finetune.sh +19 -0
- run_finetune_custom.sh +33 -0
- run_pretrain.sh +4 -0
- text_aug.py +66 -0
- utils.py +215 -0
bvh2ts.py
ADDED
@@ -0,0 +1,69 @@
from imusim.all import *
import imusim
import numpy as np
from tqdm import tqdm
import multiprocessing
import os

with open('./bvh/000000.bvh', 'r') as file:
    lines = file.readlines()
line_109 = lines[108]
frame_time = line_109.split(': ')[1].strip()
frame_time_value = float(frame_time)
print(frame_time_value)

def process_file(f):

    imu_file_path = './output/%s.npy' % f
    if not os.path.exists(imu_file_path):

        samplingPeriod = frame_time_value
        imu = Orient3IMU()
        env = Environment()

        samples = 1000
        rotationalVelocity = 20
        calibrator = ScaleAndOffsetCalibrator(env, samples, samplingPeriod, rotationalVelocity)
        calibration = calibrator.calibrate(imu)

        try:
            model = loadBVHFile('./bvh/%s.bvh' % f)
            splinedModel = SplinedBodyModel(model)

            imu_list = []
            for i in range(22):
                sim = Simulation(environment=env)
                imu.simulation = sim

                if i not in [4,8,13,17,21]:
                    imu.trajectory = splinedModel.getJoint('joint_%s' % str(i))
                else:
                    imu.trajectory = splinedModel.getPoint('joint_%s_end' % str(i-1))

                sim.time = splinedModel.startTime
                BasicIMUBehaviour(imu, samplingPeriod, calibration, initialTime=sim.time)
                sim.run(splinedModel.endTime, printProgress=False)

                acc = imu.accelerometer.calibratedMeasurements.values
                gyro = imu.gyroscope.calibratedMeasurements.values

                imu_npy = np.concatenate((acc, gyro), axis=0)
                imu_list.append(imu_npy)

            imu_npy = np.stack(imu_list, axis=1).transpose(2,1,0)
            np.save('./output/%s' % f, imu_npy)

        except (imusim.maths.splines.Spline.InsufficientPointsError, AttributeError, IndexError) as e:
            print(f"Error processing file {f}: {e}. Skipping.")
            with open('log.txt', 'a') as log_file:
                log_file.write(f + '\n')

source_dir = './bvh'
npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.bvh')]

# Process files in parallel
pool = multiprocessing.Pool(processes=8)
for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
    pass
pool.close()
pool.join()
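The script places one virtual IMU on each of the 22 skeleton joints (indices 4, 8, 13, 17 and 21 are end effectors and are read from the preceding joint's end point) and writes one array per BVH clip. As a minimal sketch of how to consume the output, assuming `./output/000000.npy` was produced above, the `transpose(2,1,0)` call yields a (T, 22, 6) layout of time steps x joints x (3-axis accelerometer + 3-axis gyroscope):

import numpy as np

imu = np.load('./output/000000.npy')   # shape (T, 22, 6)
print(imu.shape)
acc = imu[..., :3]    # per-joint accelerometer channels
gyro = imu[..., 3:]   # per-joint gyroscope channels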
checkpoint/UniMTS.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9858c0084d936655240407e30ff9db9adeded6a67dc5650e3f667578e93b220
size 274583082
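This file is a Git LFS pointer rather than the weights themselves; the actual ~275 MB UniMTS.pth checkpoint is fetched by Git LFS (for example via `git lfs pull` after cloning) and is what the `--checkpoint` flag in the scripts below should point at.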
contrastive.py
ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
import clip
from model import ST_GCN_18

class ContrastiveModule(nn.Module):

    def __init__(self, args):
        super(ContrastiveModule, self).__init__()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        del model.visual
        self.model = model

        base_channel = 3
        base_channel = base_channel * 2 if args.gyro else base_channel
        base_channel = base_channel * 2 if args.stft else base_channel
        self.model.acc = ST_GCN_18(in_channels=base_channel)

        self.model = self.model.float()

        if args.stage == 'finetune':
            self.fc = nn.Linear(512, args.num_class)

    def encode_image(self, image):
        return self.model.acc(image.float()).squeeze(-1).squeeze(-1)

    def encode_text(self, text):
        x = self.model.token_embedding(text).float() # b,t,512
        x = x + self.model.positional_embedding.float()
        x = x.permute(1, 0, 2) # b,t,512 -> t,b,512
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2) # t,b,512 -> b,t,512
        x = self.model.ln_final(x).float() # b,t,512

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection # b,512

        return x

    def classifier(self, image):
        # for fine-tuning
        imu_features = self.model.acc(image.float()).squeeze(-1).squeeze(-1)
        out = self.fc(imu_features)
        return out

    def forward(self, inputs_imu, inputs_text):

        imu_features = self.encode_image(inputs_imu)
        text_features = self.encode_text(inputs_text)

        # normalized features
        imu_features = imu_features / imu_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # logits
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * imu_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
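A minimal usage sketch of ContrastiveModule, assuming a CUDA device and the OpenAI `clip` package; the input shape (batch, channels, time, 22 joints, 1) mirrors the reshape done in the evaluation scripts below, with 3 channels for the accelerometer-only default:

from types import SimpleNamespace
import torch
import clip
from contrastive import ContrastiveModule

args = SimpleNamespace(gyro=0, stft=0, stage='evaluation')
model = ContrastiveModule(args).cuda()

imu = torch.randn(2, 3, 200, 22, 1).cuda()            # dummy accelerometer-only input
text = clip.tokenize(['walking', 'running']).cuda()   # candidate activity labels
logits_per_imu, logits_per_text = model(imu, text)    # (2, 2) similarity logits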
data.py
ADDED
@@ -0,0 +1,321 @@
import torch
import numpy as np
import random
import os
import json
from scipy.signal import resample
import clip
from torch.utils.data import Dataset

class CLIPDataset(Dataset):

    def __init__(self, args):

        imu_dirs = [
            f'{args.data_path}/sim/',
        ]
        text_dirs = [
            f'{args.data_path}/aug_texts/',
        ]
        self.paths = []
        for imu_dir, text_dir in zip(imu_dirs, text_dirs):
            imu_files = [f.split('.')[0] for f in os.listdir(imu_dir) if os.path.isfile(os.path.join(imu_dir, f))]
            text_files = [f.split('.')[0] for f in os.listdir(text_dir) if os.path.isfile(os.path.join(text_dir, f))]
            common_files = [f for f in imu_files if f in text_files]
            for f in common_files:
                self.paths.append((os.path.join(imu_dir, f + '.npy'), os.path.join(text_dir, f + '.txt')))

        self.args = args
        if args.sample < 1:
            self.paths = random.sample(self.paths, int(len(self.paths) * args.sample))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):

        # load imu
        imu_path, text_path = self.paths[idx]
        imu = np.load(imu_path)
        imu[np.isnan(imu)] = 0

        # padding
        if len(imu) < self.args.padding_size:
            imu = np.pad(imu, ((0, self.args.padding_size - len(imu)), (0, 0), (0, 0)), mode='wrap')
        imu = imu[:self.args.padding_size]

        # random masking
        mask = np.zeros_like(imu)
        k = np.random.randint(1, 6) # randomly select k joints
        selected_joints = np.random.choice(22, k, replace=False)
        mask[:,selected_joints] = 1
        imu = imu.reshape(len(imu), -1)
        mask = mask.reshape(len(mask), -1)

        # load text
        with open(text_path, 'r') as file:
            lines = file.readlines()

        text = random.choice(lines).split('#')[0].strip() # remove the comment starting from "#"

        batch = {}
        batch['imu'] = imu
        batch['text'] = text
        batch['mask'] = mask

        return batch

def select_samples(data, masks, labels, k, name, data_path):
    unique_labels = torch.unique(labels)
    selected_data = []
    selected_masks = []
    selected_labels = []
    all_indices = torch.load(f'{data_path}/few_shot_data_2/{name}_k={k}.pth')

    for i, label in enumerate(unique_labels):
        selected_indices = all_indices[i]
        selected_data.append(data[selected_indices])
        selected_masks.append(masks[selected_indices])
        selected_labels.append(labels[selected_indices])

    selected_data = torch.cat(selected_data, dim=0)
    selected_masks = torch.cat(selected_masks, dim=0)
    selected_labels = torch.cat(selected_labels, dim=0)

    return selected_data, selected_masks, selected_labels

def load(dataset, padding_size, data_path, split='test', k=None):

    print(dataset)

    X = np.load(f'{data_path}/{dataset}/X_{split}.npy')
    real_labels = torch.from_numpy(np.load(f'{data_path}/{dataset}/y_{split}.npy'))
    with open(f'{data_path}/{dataset}/{dataset}.json', 'r') as file:
        data = json.load(file)
    all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))

    if dataset == 'PAMAP':
        all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,11] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,7] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
        original_sampling_rate = 100
        num_classes = 12

    elif dataset == 'USCHAD':
        all_X[:,:,5] = np.concatenate((X[:,:,0:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 100
        num_classes = 12

    elif dataset == 'UCIHAR':
        all_X[:,:,9] = np.concatenate((X[:,:,6:9] * 9.80665, X[:,:,3:6]), axis=-1) # linear accel, gyro, total accel
        original_sampling_rate = 50
        num_classes = 6

    elif dataset == 'Opp_g':
        all_X[:,:,10] = np.concatenate((X[:,:,0:3] / 1000 * 9.8, X[:,:,3:6] / 1000), axis=-1) # convert unit from milli g to m/s^2
        all_X[:,:,19] = np.concatenate((X[:,:,9:12] / 1000 * 9.8, X[:,:,12:15] / 1000), axis=-1)
        all_X[:,:,20] = np.concatenate((X[:,:,18:21] / 1000 * 9.8, X[:,:,21:24] / 1000), axis=-1)
        all_X[:,:,15] = np.concatenate((X[:,:,27:30] / 1000 * 9.8, X[:,:,30:33] / 1000), axis=-1)
        all_X[:,:,16] = np.concatenate((X[:,:,36:39] / 1000 * 9.8, X[:,:,39:42] / 1000), axis=-1)
        original_sampling_rate = 30
        num_classes = 4 # locomotion

    elif dataset == 'WISDM':
        all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        original_sampling_rate = 20
        num_classes = 18

    elif dataset == 'DSADS':
        all_X[:,:,11] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
        all_X[:,:,17] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,6] = np.concatenate((X[:,:,27:30], X[:,:,30:33]), axis=-1)
        all_X[:,:,2] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
        original_sampling_rate = 25
        num_classes = 19

    elif dataset == 'Harth':
        all_X[:,:,9,:3] = X[:,:,:3] * 9.80665
        all_X[:,:,6,:3] = X[:,:,3:6] * 9.80665
        original_sampling_rate = 50
        num_classes = 12

    elif dataset == 'Wharf':
        X = -14.709 + X / 63 * (2 * 14.709)
        all_X[:,:,21,:3] = X
        original_sampling_rate = 32
        num_classes = 14

    elif dataset == 'Mhealth':
        all_X[:,:,11,:3] = X[:,:,0:3]
        all_X[:,:,3] = np.concatenate((X[:,:,6:9], X[:,:,9:12] / 180 * np.pi), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,15:18], X[:,:,18:21] / 180 * np.pi), axis=-1)
        original_sampling_rate = 50
        num_classes = 12

    elif dataset == 'UTD-MHAD':
        all_X[real_labels < 21,:,21,:] = np.concatenate((X[real_labels < 21,:,0:3] * 9.80665, X[real_labels < 21,:,3:6] / 180 * np.pi), axis=-1)
        all_X[real_labels >= 21,:,5,:] = np.concatenate((X[real_labels >= 21,:,0:3] * 9.80665, X[real_labels >= 21,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 50
        num_classes = 27

    elif dataset == 'MotionSense':
        all_X[:,:,5] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
        all_X[:,:,1] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
        original_sampling_rate = 50
        num_classes = 6

    elif dataset == 'w-HAR':
        all_X[:,:,7] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 250
        num_classes = 7

    elif dataset == 'Shoaib':
        all_X[:,:,1] = X[:,:,:6]
        all_X[:,:,5] = X[:,:,6:12]
        all_X[:,:,21] = X[:,:,12:18]
        all_X[:,:,20] = X[:,:,18:24]
        all_X[:,:,0] = X[:,:,24:30]
        original_sampling_rate = 50
        num_classes = 7

    elif dataset == 'har70plus':
        all_X[:,:,0,:3] = X[:,:,:3] * 9.80665
        all_X[:,:,5,:3] = X[:,:,3:6] * 9.80665
        original_sampling_rate = 50
        num_classes = 7

    elif dataset == 'MMAct':
        all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21,:3] = X[:,:,6:9]
        original_sampling_rate = 50
        num_classes = 35

    elif dataset == 'realworld':
        all_X[:,:,14] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,16] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        all_X[:,:,13] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
        all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,1] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
        all_X[:,:,15] = np.concatenate((X[:,:,30:33], X[:,:,33:36]), axis=-1)
        all_X[:,:,9] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
        original_sampling_rate = 50
        num_classes = 8

    elif dataset == 'TNDA-HAR':
        all_X[:,:,20] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,2] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
        all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,11] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
        original_sampling_rate = 50
        num_classes = 8

    elif dataset == 'ut-complex':
        all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        original_sampling_rate = 50
        num_classes = 13

    all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)

    # resample real data to 20 Hz
    new_sampling_rate = 20
    new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
    resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])

    # pad real data to args.padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap') # N, 200, 6
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant') # N, 200, 6
    real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
    real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()

    if split == 'train' and k and k < len(real_inputs):
        real_inputs, real_masks, real_labels = select_samples(real_inputs, real_masks, real_labels, k, dataset, data_path)
    print(real_inputs.shape, real_labels.shape)

    # load text
    label_dictionary = data['label_dictionary']
    label_list = [' '.join(labels) for labels in label_dictionary.values()]
    all_text = clip.tokenize(label_list).cuda()

    return real_inputs, real_masks, real_labels, label_list, all_text, num_classes

def load_multiple(dataset_list, padding_size, data_path, split='test', k=None):

    real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list = [], [], [], [], [], []
    for dataset in dataset_list:
        real_inputs, real_masks, real_labels, label_list, all_text, num_classes = load(dataset, padding_size, data_path, split, k)
        real_inputs_list.append(real_inputs)
        real_masks_list.append(real_masks)
        real_labels_list.append(real_labels)
        label_list_list.append(label_list)
        all_text_list.append(all_text)
        num_classes_list.append(num_classes)

    return real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list

def load_custom_data(X_path, y_path, config_path, joint_list, original_sampling_rate, padding_size=200, split='test', k=None, few_shot_path=None):

    X = np.load(X_path)
    real_labels = torch.from_numpy(np.load(y_path))
    with open(config_path, 'r') as file:
        data = json.load(file)
    all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))

    for i, joint in enumerate(joint_list):
        all_X[:,:,joint] = np.concatenate((X[:,:,6*i:6*i+3], X[:,:,6*i+3:6*i+6]), axis=-1)

    all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)

    # resample real data to 20 Hz
    new_sampling_rate = 20
    new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
    resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])

    # pad real data to padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap') # N, 200, 6
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant') # N, 200, 6
    real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
    real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()

    if split == 'train' and k and k < len(real_inputs):

        unique_labels = torch.unique(real_labels)

        if few_shot_path is None:
            print('Generating few shot indices ...')
            all_indices = []
            for i, label in enumerate(unique_labels):
                indices = torch.where(real_labels == label)[0]
                selected_indices = indices[torch.randperm(len(indices))[:k]]
                all_indices.append(selected_indices)
        else:
            print('Loading existing few shot indices ...')
            all_indices = torch.load(few_shot_path)

        selected_data = []
        selected_masks = []
        selected_labels = []
        for i, label in enumerate(unique_labels):
            selected_indices = all_indices[i]
            selected_data.append(real_inputs[selected_indices])
            selected_masks.append(real_masks[selected_indices])
            selected_labels.append(real_labels[selected_indices])
        selected_data = torch.cat(selected_data, dim=0)
        selected_masks = torch.cat(selected_masks, dim=0)
        selected_labels = torch.cat(selected_labels, dim=0)
        real_inputs, real_masks, real_labels = selected_data, selected_masks, selected_labels

    print(real_inputs.shape, real_labels.shape)

    # load text
    label_dictionary = data['label_dictionary']
    label_list = [' '.join(labels) for labels in label_dictionary.values()]
    all_text = clip.tokenize(label_list).cuda()

    return real_inputs, real_masks, real_labels, label_list, all_text
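For reference, a sketch of calling `load` directly for one benchmark, assuming `./data/UCIHAR/` contains `X_test.npy`, `y_test.npy` and `UCIHAR.json` in the layout expected above (note that `load` tokenizes the label text onto the GPU, so a CUDA device is required):

from data import load

inputs, masks, labels, label_list, all_text, num_classes = load(
    'UCIHAR', padding_size=200, data_path='./data', split='test')
print(inputs.shape)   # (N, 200, 132): 22 joints x 6 channels, resampled to 20 Hz
print(num_classes)    # 6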
evaluate.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import torch
import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset

from data import load, load_multiple
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):
    # load real data
    dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',\
                    'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
    real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path)
    test_real_dataloader_list = []
    for real_inputs, real_masks, real_labels in zip(real_inputs_list, real_masks_list, real_labels_list):
        real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        test_real_dataloader_list.append(DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False))

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{date}"
    )

    model = ContrastiveModule(args).cuda()

    model.model.load_state_dict(torch.load(f'{args.checkpoint}'))

    model.eval()
    with torch.no_grad():
        for ds, real_labels, test_real_dataloader, label_list, all_text in zip(dataset_list, real_labels_list, test_real_dataloader_list, label_list_list, all_text_list):
            pred_whole, logits_whole = [], []
            for input, mask, label in test_real_dataloader:

                input = input.cuda()
                mask = mask.cuda()
                label = label.cuda()

                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                    input = input[:,:,indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0,2,1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                    input = torch.cat((input, input_stft), dim=-1)

                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                logits_per_imu, logits_per_text = model(input, all_text)
                logits_whole.append(logits_per_imu)

                pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
                pred_whole.append(pred)

            pred = np.concatenate(pred_whole)
            acc = accuracy_score(real_labels, pred)
            prec = precision_score(real_labels, pred, average='macro')
            rec = recall_score(real_labels, pred, average='macro')
            f1 = f1_score(real_labels, pred, average='macro')

            print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
            wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})

            logits_whole = torch.cat(logits_whole)
            r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())

            print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
            wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
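run_evaluation.sh (listed above but not shown here) presumably wraps an invocation like the following; every flag used is defined in the argparse block above, and the checkpoint path assumes the LFS file has been pulled:

python evaluate.py --checkpoint ./checkpoint/UniMTS.pth --batch_size 64 --run_tag exp0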
evaluate_custom.py
ADDED
@@ -0,0 +1,101 @@
import numpy as np
import torch
import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset

from data import load, load_multiple, load_custom_data
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):
    # load real data
    real_inputs, real_masks, real_labels, label_list, all_text = load_custom_data(
        args.X_path, args.y_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
    )
    real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
    test_real_dataloader = DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{date}"
    )

    model = ContrastiveModule(args).cuda()

    model.model.load_state_dict(torch.load(f'{args.checkpoint}'))

    model.eval()
    with torch.no_grad():
        pred_whole, logits_whole = [], []
        for input, mask, label in test_real_dataloader:

            input = input.cuda()
            mask = mask.cuda()
            label = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                input = input[:,:,indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0,2,1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            logits_per_imu, logits_per_text = model(input, all_text)
            logits_whole.append(logits_per_imu)

            pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
            pred_whole.append(pred)

        pred = np.concatenate(pred_whole)
        acc = accuracy_score(real_labels, pred)
        prec = precision_score(real_labels, pred, average='macro')
        rec = recall_score(real_labels, pred, average='macro')
        f1 = f1_score(real_labels, pred, average='macro')

        print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
        wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})

        logits_whole = torch.cat(logits_whole)
        r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())

        print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
        wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--X_path', type=str, required=True, help='/path/to/data/')
    parser.add_argument('--y_path', type=str, required=True, help='/path/to/label/')
    parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
    parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
    parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
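For a custom dataset, all data-description flags are required; `--joint_list` names the skeleton joint (0-21) each consecutive 6-channel sensor block in X maps to. The paths below are placeholders:

python evaluate_custom.py --X_path ./my_data/X_test.npy --y_path ./my_data/y_test.npy --config_path ./my_data/config.json --joint_list 21 5 --original_sampling_rate 50 --checkpoint ./checkpoint/UniMTS.pth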
finetune.py
ADDED
@@ -0,0 +1,169 @@
import numpy as np
import torch
import torch.nn.functional as F

import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from data import load_multiple
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):

    # load real data
    dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',\
                    'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
    train_inputs_list, train_masks_list, train_labels_list, label_list_list, all_text_list, num_classes_list = load_multiple(dataset_list, args.padding_size, args.data_path, split='train', k=args.k)
    test_inputs_list, test_masks_list, test_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path, split='test')
    train_dataloader_list, test_dataloader_list = [], []
    for real_inputs, real_masks, real_labels in zip(train_inputs_list, train_masks_list, train_labels_list):
        train_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        train_dataloader_list.append(DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True))
    for real_inputs, real_masks, real_labels in zip(test_inputs_list, test_masks_list, test_labels_list):
        test_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        test_dataloader_list.append(DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False))

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_{date}"
    )

    save_path = './checkpoint/%s/' % args.run_tag
    os.makedirs(save_path, exist_ok=True)  # ensure the checkpoint directory exists

    for ds, train_dataloader, test_dataloader, test_labels, label_list, all_text, num_class in \
        zip(dataset_list, train_dataloader_list, test_dataloader_list, test_labels_list, label_list_list, all_text_list, num_classes_list):

        args.num_class = num_class
        model = ContrastiveModule(args).cuda()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        if args.mode == 'full' or args.mode == 'probe':
            model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
        if args.mode == 'probe':
            for name, param in model.model.named_parameters():
                param.requires_grad = False

        best_loss = None
        for epoch in range(args.num_epochs):

            tol_loss = 0

            model.train()
            for i, (input, mask, label) in enumerate(train_dataloader):

                input = input.cuda()
                labels = label.cuda()

                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                    input = input[:,:,indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0,2,1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                    input = torch.cat((input, input_stft), dim=-1)

                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                output = model.classifier(input)

                loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tol_loss += len(input) * loss.item()

                # print(epoch, i, loss.item())

            print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataloader.dataset):.4f}')
            wandb.log({f'{ds} loss': tol_loss / len(train_dataloader.dataset)})

            if best_loss is None or tol_loss < best_loss:
                best_loss = tol_loss
                torch.save(model.state_dict(), os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth'))

        # evaluation
        model.load_state_dict(torch.load(os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth')))
        model.eval()
        with torch.no_grad():

            pred_whole, logits_whole = [], []
            for input, mask, label in test_dataloader:

                input = input.cuda()
                label = label.cuda()

                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                    input = input[:,:,indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0,2,1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                    input = torch.cat((input, input_stft), dim=-1)

                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                logits_per_imu = model.classifier(input)
                logits_whole.append(logits_per_imu)

                pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
                pred_whole.append(pred)

            pred = np.concatenate(pred_whole)
            acc = accuracy_score(test_labels, pred)
            prec = precision_score(test_labels, pred, average='macro')
            rec = recall_score(test_labels, pred, average='macro')
            f1 = f1_score(test_labels, pred, average='macro')

            print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
            wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})

            logits_whole = torch.cat(logits_whole)
            r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())

            print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
            wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # model
    parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--stage', type=str, default='finetune', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
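The three --mode settings map directly onto the logic above: 'full' loads the pre-trained encoder and updates all parameters, 'probe' loads it but freezes everything under model.model so only the linear head trains, and 'random' skips the checkpoint for a randomly initialized baseline. For example:

python finetune.py --mode probe --k 5 --checkpoint ./checkpoint/UniMTS.pth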
finetune_custom.py
ADDED
@@ -0,0 +1,172 @@
import numpy as np
import torch
import torch.nn.functional as F

import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from data import load_multiple, load_custom_data
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):

    train_inputs, train_masks, train_labels, _, _ = load_custom_data(
        args.X_train_path, args.y_train_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='train', k=args.k, few_shot_path=args.few_shot_path
    )
    train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    test_inputs, test_masks, test_labels, _, _ = load_custom_data(
        args.X_test_path, args.y_test_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
    )
    test_dataset = TensorDataset(test_inputs, test_masks, test_labels)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_{date}"
    )

    save_path = './checkpoint/%s/' % args.run_tag
    os.makedirs(save_path, exist_ok=True)  # ensure the checkpoint directory exists

    model = ContrastiveModule(args).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    if args.mode == 'full' or args.mode == 'probe':
        model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
    if args.mode == 'probe':
        for name, param in model.model.named_parameters():
            param.requires_grad = False

    best_loss = None
    for epoch in range(args.num_epochs):

        tol_loss = 0

        model.train()
        for i, (input, mask, label) in enumerate(train_dataloader):

            input = input.cuda()
            labels = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                input = input[:,:,indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0,2,1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            output = model.classifier(input)

            loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tol_loss += len(input) * loss.item()

            # print(epoch, i, loss.item())

        print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataset):.4f}')
        wandb.log({'loss': tol_loss / len(train_dataset)})

        if best_loss is None or tol_loss < best_loss:
            best_loss = tol_loss
            torch.save(model.state_dict(), os.path.join(save_path, f'k={args.k}_best_loss.pth'))

    # evaluation
    model.load_state_dict(torch.load(os.path.join(save_path, f'k={args.k}_best_loss.pth')))
    model.eval()
    with torch.no_grad():

        pred_whole, logits_whole = [], []
        for input, mask, label in test_dataloader:

            input = input.cuda()
            label = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                input = input[:,:,indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0,2,1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            logits_per_imu = model.classifier(input)
            logits_whole.append(logits_per_imu)

            pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
            pred_whole.append(pred)

        pred = np.concatenate(pred_whole)
        acc = accuracy_score(test_labels, pred)
        prec = precision_score(test_labels, pred, average='macro')
        rec = recall_score(test_labels, pred, average='macro')
        f1 = f1_score(test_labels, pred, average='macro')

        print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
        wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})

        logits_whole = torch.cat(logits_whole)
        r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())

        print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
        wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # model
    parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
    parser.add_argument('--X_train_path', type=str, required=True, help='/path/to/train/data/')
    parser.add_argument('--X_test_path', type=str, required=True, help='/path/to/test/data/')
    parser.add_argument('--y_train_path', type=str, required=True, help='/path/to/train/label/')
    parser.add_argument('--y_test_path', type=str, required=True, help='/path/to/test/label/')
    parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
    parser.add_argument('--few_shot_path', type=str, help='/path/to/few/shot/indices/')
    parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
    parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')
    parser.add_argument('--num_class', type=int, required=True, help='number of classes')

    # training
    parser.add_argument('--stage', type=str, default='finetune', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
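Fine-tuning on a custom dataset mirrors the custom evaluation call, with train/test splits and the class count supplied explicitly; all paths below are placeholders:

python finetune_custom.py --mode full --k 5 --X_train_path ./my_data/X_train.npy --y_train_path ./my_data/y_train.npy --X_test_path ./my_data/X_test.npy --y_test_path ./my_data/y_test.npy --config_path ./my_data/config.json --joint_list 21 5 --original_sampling_rate 50 --num_class 13 --checkpoint ./checkpoint/UniMTS.pth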
model.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
class Graph():
|
| 7 |
+
""" The Graph to model the skeletons
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
strategy (string): must be one of the follow candidates
|
| 11 |
+
- uniform: Uniform Labeling
|
| 12 |
+
- distance: Distance Partitioning
|
| 13 |
+
- spatial: Spatial Configuration
|
| 14 |
+
max_hop (int): the maximal distance between two connected nodes
|
| 15 |
+
dilation (int): controls the spacing between the kernel points
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self,
|
| 19 |
+
strategy='spatial',
|
| 20 |
+
max_hop=1,
|
| 21 |
+
dilation=1):
|
| 22 |
+
self.max_hop = max_hop
|
| 23 |
+
self.dilation = dilation
|
| 24 |
+
|
| 25 |
+
self.get_edge()
|
| 26 |
+
self.hop_dis = get_hop_distance(self.num_node,
|
| 27 |
+
self.edge,
|
| 28 |
+
max_hop=max_hop)
|
| 29 |
+
self.get_adjacency(strategy)
|
| 30 |
+
|
| 31 |
+
def __str__(self):
|
| 32 |
+
return self.A
|
| 33 |
+
|
| 34 |
+
def get_edge(self):
|
| 35 |
+
# edge is a list of [child, parent] paris
|
| 36 |
+
self.num_node = 22
|
| 37 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
| 38 |
+
neighbor_link = [(1,0), (2,1), (3,2), (4,3), (5,0), (6,5), (7,6), (8,7), (9,0), (10,9), (11,10), (12,11), \
|
| 39 |
+
(13,12), (14,11), (15,14), (16,15), (17,16), (18,11), (19,18), (20,19), (21,20)]
|
| 40 |
+
self.edge = self_link + neighbor_link
|
| 41 |
+
self.center = 0
|
| 42 |
+
|
| 43 |
+
def get_adjacency(self, strategy):
|
| 44 |
+
valid_hop = range(0, self.max_hop + 1, self.dilation)
|
| 45 |
+
adjacency = np.zeros((self.num_node, self.num_node))
|
| 46 |
+
for hop in valid_hop:
|
| 47 |
+
adjacency[self.hop_dis == hop] = 1
|
| 48 |
+
normalize_adjacency = normalize_digraph(adjacency)
|
| 49 |
+
|
| 50 |
+
if strategy == 'uniform':
|
| 51 |
+
A = np.zeros((1, self.num_node, self.num_node))
|
| 52 |
+
A[0] = normalize_adjacency
|
| 53 |
+
self.A = A
|
| 54 |
+
elif strategy == 'distance':
|
| 55 |
+
A = np.zeros((len(valid_hop), self.num_node, self.num_node))
|
| 56 |
+
for i, hop in enumerate(valid_hop):
|
| 57 |
+
A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
|
| 58 |
+
hop]
|
| 59 |
+
self.A = A
|
| 60 |
+
elif strategy == 'spatial':
|
| 61 |
+
A = []
|
| 62 |
+
for hop in valid_hop:
|
| 63 |
+
a_root = np.zeros((self.num_node, self.num_node))
|
| 64 |
+
a_close = np.zeros((self.num_node, self.num_node))
|
| 65 |
+
a_further = np.zeros((self.num_node, self.num_node))
|
| 66 |
+
for i in range(self.num_node):
|
| 67 |
+
for j in range(self.num_node):
|
| 68 |
+
if self.hop_dis[j, i] == hop:
|
| 69 |
+
if self.hop_dis[j, self.center] == self.hop_dis[
|
| 70 |
+
i, self.center]:
|
| 71 |
+
a_root[j, i] = normalize_adjacency[j, i]
|
| 72 |
+
elif self.hop_dis[j, self.center] > self.hop_dis[
|
| 73 |
+
i, self.center]:
|
| 74 |
+
a_close[j, i] = normalize_adjacency[j, i]
|
| 75 |
+
else:
|
| 76 |
+
a_further[j, i] = normalize_adjacency[j, i]
|
| 77 |
+
if hop == 0:
|
| 78 |
+
A.append(a_root)
|
| 79 |
+
else:
|
| 80 |
+
A.append(a_root + a_close)
|
| 81 |
+
A.append(a_further)
|
| 82 |
+
A = np.stack(A)
|
| 83 |
+
self.A = A
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError("Do Not Exist This Strategy")
|
| 86 |
+
|
| 87 |
+
def get_hop_distance(num_node, edge, max_hop=1):
|
| 88 |
+
A = np.zeros((num_node, num_node))
|
| 89 |
+
for i, j in edge:
|
| 90 |
+
A[j, i] = 1
|
| 91 |
+
A[i, j] = 1
|
| 92 |
+
|
| 93 |
+
# compute hop steps
|
| 94 |
+
hop_dis = np.zeros((num_node, num_node)) + np.inf
|
| 95 |
+
transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
|
| 96 |
+
arrive_mat = (np.stack(transfer_mat) > 0)
|
| 97 |
+
for d in range(max_hop, -1, -1):
|
| 98 |
+
hop_dis[arrive_mat[d]] = d
|
| 99 |
+
return hop_dis
|
| 100 |
+
|
| 101 |
+
def normalize_digraph(A):
|
| 102 |
+
Dl = np.sum(A, 0)
|
| 103 |
+
num_node = A.shape[0]
|
| 104 |
+
Dn = np.zeros((num_node, num_node))
|
| 105 |
+
for i in range(num_node):
|
| 106 |
+
if Dl[i] > 0:
|
| 107 |
+
Dn[i, i] = Dl[i]**(-1)
|
| 108 |
+
AD = np.dot(A, Dn)
|
| 109 |
+
return AD
|
| 110 |
+
|
| 111 |
+
def normalize_undigraph(A):
|
| 112 |
+
Dl = np.sum(A, 0)
|
| 113 |
+
num_node = A.shape[0]
|
| 114 |
+
Dn = np.zeros((num_node, num_node))
|
| 115 |
+
for i in range(num_node):
|
| 116 |
+
if Dl[i] > 0:
|
| 117 |
+
Dn[i, i] = Dl[i]**(-0.5)
|
| 118 |
+
DAD = np.dot(np.dot(Dn, A), Dn)
|
| 119 |
+
return DAD
|
| 120 |
+
|
| 121 |
+
def zero(x):
|
| 122 |
+
return 0
|
| 123 |
+
|
| 124 |
+
def iden(x):
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
class ConvTemporalGraphical(nn.Module):
|
| 128 |
+
r"""The basic module for applying a graph convolution.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
in_channels (int): Number of channels in the input sequence data
|
| 132 |
+
out_channels (int): Number of channels produced by the convolution
|
| 133 |
+
kernel_size (int): Size of the graph convolving kernel
|
| 134 |
+
t_kernel_size (int): Size of the temporal convolving kernel
|
| 135 |
+
t_stride (int, optional): Stride of the temporal convolution. Default: 1
|
| 136 |
+
t_padding (int, optional): Temporal zero-padding added to both sides of
|
| 137 |
+
the input. Default: 0
|
| 138 |
+
t_dilation (int, optional): Spacing between temporal kernel elements.
|
| 139 |
+
Default: 1
|
| 140 |
+
bias (bool, optional): If ``True``, adds a learnable bias to the output.
|
| 141 |
+
Default: ``True``
|
| 142 |
+
|
| 143 |
+
Shape:
|
| 144 |
+
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
|
| 145 |
+
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
|
| 146 |
+
- Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
|
| 147 |
+
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
|
| 148 |
+
|
| 149 |
+
where
|
| 150 |
+
:math:`N` is a batch size,
|
| 151 |
+
:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
|
| 152 |
+
:math:`T_{in}/T_{out}` is a length of input/output sequence,
|
| 153 |
+
:math:`V` is the number of graph nodes.
|
| 154 |
+
"""
|
| 155 |
+
def __init__(self,
|
| 156 |
+
in_channels,
|
| 157 |
+
out_channels,
|
| 158 |
+
kernel_size,
|
| 159 |
+
t_kernel_size=1,
|
| 160 |
+
t_stride=1,
|
| 161 |
+
t_padding=0,
|
| 162 |
+
t_dilation=1,
|
| 163 |
+
bias=True):
|
| 164 |
+
super().__init__()
|
| 165 |
+
|
| 166 |
+
self.kernel_size = kernel_size
|
| 167 |
+
self.conv = nn.Conv2d(in_channels,
|
| 168 |
+
out_channels * kernel_size,
|
| 169 |
+
kernel_size=(t_kernel_size, 1),
|
| 170 |
+
padding=(t_padding, 0),
|
| 171 |
+
stride=(t_stride, 1),
|
| 172 |
+
dilation=(t_dilation, 1),
|
| 173 |
+
bias=bias)
|
| 174 |
+
|
| 175 |
+
def forward(self, x, A):
|
| 176 |
+
assert A.size(0) == self.kernel_size
|
| 177 |
+
|
| 178 |
+
x = self.conv(x)
|
| 179 |
+
|
| 180 |
+
n, kc, t, v = x.size()
|
| 181 |
+
x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
|
| 182 |
+
x = torch.einsum('nkctv,kvw->nctw', (x, A))
|
| 183 |
+
|
| 184 |
+
return x.contiguous(), A
|
| 185 |
+
|
| 186 |
+
class st_gcn_block(nn.Module):
|
| 187 |
+
r"""Applies a spatial temporal graph convolution over an input graph sequence.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
in_channels (int): Number of channels in the input sequence data
|
| 191 |
+
out_channels (int): Number of channels produced by the convolution
|
| 192 |
+
kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
|
| 193 |
+
stride (int, optional): Stride of the temporal convolution. Default: 1
|
| 194 |
+
dropout (int, optional): Dropout rate of the final output. Default: 0
|
| 195 |
+
residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
|
| 196 |
+
|
| 197 |
+
Shape:
|
| 198 |
+
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
|
| 199 |
+
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
|
| 200 |
+
- Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format
|
| 201 |
+
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
|
| 202 |
+
|
| 203 |
+
where
|
| 204 |
+
:math:`N` is a batch size,
|
| 205 |
+
:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
|
| 206 |
+
:math:`T_{in}/T_{out}` is a length of input/output sequence,
|
| 207 |
+
:math:`V` is the number of graph nodes.
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
def __init__(self,
|
| 211 |
+
in_channels,
|
| 212 |
+
out_channels,
|
| 213 |
+
kernel_size,
|
| 214 |
+
stride=1,
|
| 215 |
+
dropout=0,
|
| 216 |
+
residual=True):
|
| 217 |
+
super().__init__()
|
| 218 |
+
|
| 219 |
+
assert len(kernel_size) == 2
|
| 220 |
+
assert kernel_size[0] % 2 == 1
|
| 221 |
+
padding = ((kernel_size[0] - 1) // 2, 0)
|
| 222 |
+
|
| 223 |
+
self.gcn = ConvTemporalGraphical(in_channels, out_channels,
|
| 224 |
+
kernel_size[1])
|
| 225 |
+
|
| 226 |
+
self.tcn = nn.Sequential(
|
| 227 |
+
nn.BatchNorm2d(out_channels),
|
| 228 |
+
nn.ReLU(inplace=True),
|
| 229 |
+
nn.Conv2d(
|
| 230 |
+
out_channels,
|
| 231 |
+
out_channels,
|
| 232 |
+
(kernel_size[0], 1),
|
| 233 |
+
(stride, 1),
|
| 234 |
+
padding,
|
| 235 |
+
),
|
| 236 |
+
nn.BatchNorm2d(out_channels),
|
| 237 |
+
nn.Dropout(dropout, inplace=True),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if not residual:
|
| 241 |
+
self.residual = zero
|
| 242 |
+
|
| 243 |
+
elif (in_channels == out_channels) and (stride == 1):
|
| 244 |
+
self.residual = iden
|
| 245 |
+
|
| 246 |
+
else:
|
| 247 |
+
self.residual = nn.Sequential(
|
| 248 |
+
nn.Conv2d(in_channels,
|
| 249 |
+
out_channels,
|
| 250 |
+
kernel_size=1,
|
| 251 |
+
stride=(stride, 1)),
|
| 252 |
+
nn.BatchNorm2d(out_channels),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.relu = nn.ReLU(inplace=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, A):
|
| 258 |
+
|
| 259 |
+
res = self.residual(x)
|
| 260 |
+
x, A = self.gcn(x, A)
|
| 261 |
+
x = self.tcn(x) + res
|
| 262 |
+
|
| 263 |
+
return self.relu(x), A
|
| 264 |
+
|
| 265 |
+
class ST_GCN_18(nn.Module):
|
| 266 |
+
r"""Spatial temporal graph convolutional networks.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
in_channels (int): Number of channels in the input data
|
| 270 |
+
num_class (int): Number of classes for the classification task
|
| 271 |
+
graph_cfg (dict): The arguments for building the graph
|
| 272 |
+
edge_importance_weighting (bool): If ``True``, adds a learnable
|
| 273 |
+
importance weighting to the edges of the graph
|
| 274 |
+
**kwargs (optional): Other parameters for graph convolution units
|
| 275 |
+
|
| 276 |
+
Shape:
|
| 277 |
+
- Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
|
| 278 |
+
- Output: :math:`(N, num_class)` where
|
| 279 |
+
:math:`N` is a batch size,
|
| 280 |
+
:math:`T_{in}` is a length of input sequence,
|
| 281 |
+
:math:`V_{in}` is the number of graph nodes,
|
| 282 |
+
:math:`M_{in}` is the number of instance in a frame.
|
| 283 |
+
"""
|
| 284 |
+
def __init__(self,
|
| 285 |
+
in_channels,
|
| 286 |
+
edge_importance_weighting=True,
|
| 287 |
+
data_bn=True,
|
| 288 |
+
**kwargs):
|
| 289 |
+
super().__init__()
|
| 290 |
+
|
| 291 |
+
# load graph
|
| 292 |
+
self.graph = Graph()
|
| 293 |
+
A = torch.tensor(self.graph.A,
|
| 294 |
+
dtype=torch.float32,
|
| 295 |
+
requires_grad=False)
|
| 296 |
+
self.register_buffer('A', A)
|
| 297 |
+
|
| 298 |
+
# build networks
|
| 299 |
+
spatial_kernel_size = A.size(0)
|
| 300 |
+
temporal_kernel_size = 9
|
| 301 |
+
kernel_size = (temporal_kernel_size, spatial_kernel_size)
|
| 302 |
+
self.data_bn = nn.BatchNorm1d(in_channels *
|
| 303 |
+
A.size(1)) if data_bn else iden
|
| 304 |
+
kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
|
| 305 |
+
self.st_gcn_networks = nn.ModuleList((
|
| 306 |
+
st_gcn_block(in_channels,
|
| 307 |
+
64,
|
| 308 |
+
kernel_size,
|
| 309 |
+
1,
|
| 310 |
+
residual=False,
|
| 311 |
+
**kwargs0),
|
| 312 |
+
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
|
| 313 |
+
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
|
| 314 |
+
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
|
| 315 |
+
st_gcn_block(64, 128, kernel_size, 2, **kwargs),
|
| 316 |
+
st_gcn_block(128, 128, kernel_size, 1, **kwargs),
|
| 317 |
+
st_gcn_block(128, 128, kernel_size, 1, **kwargs),
|
| 318 |
+
st_gcn_block(128, 256, kernel_size, 2, **kwargs),
|
| 319 |
+
st_gcn_block(256, 256, kernel_size, 1, **kwargs),
|
| 320 |
+
st_gcn_block(256, 512, kernel_size, 1, **kwargs),
|
| 321 |
+
))
|
| 322 |
+
|
| 323 |
+
# initialize parameters for edge importance weighting
|
| 324 |
+
if edge_importance_weighting:
|
| 325 |
+
self.edge_importance = nn.ParameterList([
|
| 326 |
+
nn.Parameter(torch.ones(self.A.size()))
|
| 327 |
+
for i in self.st_gcn_networks
|
| 328 |
+
])
|
| 329 |
+
else:
|
| 330 |
+
self.edge_importance = [1] * len(self.st_gcn_networks)
|
| 331 |
+
|
| 332 |
+
def forward(self, x):
|
| 333 |
+
# data normalization
|
| 334 |
+
N, C, T, V, M = x.size()
|
| 335 |
+
x = x.permute(0, 4, 3, 1, 2).contiguous()
|
| 336 |
+
x = x.view(N * M, V * C, T)
|
| 337 |
+
x = self.data_bn(x)
|
| 338 |
+
x = x.view(N, M, V, C, T)
|
| 339 |
+
x = x.permute(0, 1, 3, 4, 2).contiguous()
|
| 340 |
+
x = x.view(N * M, C, T, V)
|
| 341 |
+
|
| 342 |
+
# forward
|
| 343 |
+
for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
|
| 344 |
+
x, _ = gcn(x, self.A * importance)
|
| 345 |
+
|
| 346 |
+
# global pooling
|
| 347 |
+
x = F.avg_pool2d(x, x.size()[2:]) # (b, 512, t, joint)
|
| 348 |
+
x = x.view(N, M, -1, 1, 1).mean(dim=1)
|
| 349 |
+
|
| 350 |
+
return x
|
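Note: ST_GCN_18 here appears to serve as a feature encoder over the 22-joint skeleton graph rather than a classifier. A minimal shape-check sketch (illustrative only, not repository code; the batch size, channel count, and sequence length below are assumptions):

# Hypothetical usage sketch for ST_GCN_18 (not part of the repository).
import torch
from model import ST_GCN_18

model = ST_GCN_18(in_channels=3)   # e.g. 3 accelerometer channels per joint
x = torch.randn(8, 3, 200, 22, 1)  # (N, C, T, V=22 joints, M=1 instance)
features = model(x)
print(features.shape)              # torch.Size([8, 512, 1, 1]) pooled features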
pos2bvh.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np
from Quaternions import Quaternions
from scipy_motion import myBVH
import BVH
from scipy_motion import myAnimation
import Animation
from scipy_motion import myInverseKinematics as myIK
import InverseKinematics as IK
from tqdm import tqdm
import multiprocessing
import os
import os.path as osp
from scipy.spatial.transform import Rotation as R


parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
# names = ['root','leftleg1','leftleg2','leftleg3','leftleg4','rightleg1','rightleg2','rightleg3','rightleg4',
#          'spline1','spline2','spline3','spline4','spline5','rightarm1','rightarm2','rightarm3','rightarm4',
#          'leftarm1','leftarm2','leftarm3','leftarm4']

def process_file(f):

    fk_positions = np.load('/path/to/joint/pos/%s.npy' % (f))

    frametime = 1 / 20

    anim_ik, _, _, save_file = IK.animation_from_positions(fk_positions, parents=parents)

    if save_file:
        BVH.save('bvh/%s.bvh' % f, anim_ik, frametime=frametime)

source_dir = '/path/to/joint/pos'
error_file = ['M005836.npy', 'M000990.npy', '000990.npy', '005836.npy']
npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.npy') and file not in error_file]

# Process files in parallel
pool = multiprocessing.Pool(processes=8)
for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
    pass
pool.close()
pool.join()
pretrain.py
ADDED
@@ -0,0 +1,111 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import argparse
import os
import clip
import wandb
import datetime
import torch.optim as optim

from data import CLIPDataset
from utils import augment_data
from contrastive import ContrastiveModule

def main(args):

    train_dataset = CLIPDataset(args)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_" + f"{date}"
    )

    model = ContrastiveModule(args).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    save_path = './checkpoint/%s/' % args.run_tag
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    for epoch in range(args.num_epochs):

        tol_loss = 0

        model.train()
        for i, batch in enumerate(train_loader):

            inputs_imu = batch['imu'].float().cuda()
            inputs_text = clip.tokenize(batch['text'], truncate=True).cuda()
            mask = batch['mask'].float().cuda()

            input = inputs_imu * mask

            # rotation-invariant augmentation
            if args.aug:
                input = augment_data(input)

            # keep only the accelerometer channels if the gyroscope is not used
            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                input = input[:,:,indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0,2,1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            # IMU and text representations
            logits_per_imu, logits_per_text = model(input, inputs_text)

            # positive keys are the entries on the diagonal
            labels = torch.arange(len(batch['imu'])).cuda()

            loss = F.cross_entropy(logits_per_imu / args.temperature, labels, reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tol_loss += len(inputs_imu) * loss.item()

            # print(epoch, i, loss.item())

        print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataset):.4f}')
        wandb.log({'loss': tol_loss / len(train_dataset)})

        if epoch > 0 and epoch % args.log == 0:
            torch.save(model.model.state_dict(), os.path.join(save_path, f'epoch_{epoch}.pth'))

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--sample', type=float, default=1.0, help='pre-training down-sample ratio (default: 1)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='pretrain', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=100, help='number of pre-training epochs')
    parser.add_argument('--gyro', type=int, default=0, help='whether to use gyroscope channels (0 or 1)')
    parser.add_argument('--stft', type=int, default=0, help='whether to append STFT features (0 or 1)')
    parser.add_argument('--aug', type=int, default=1, help='whether to apply rotation augmentation (0 or 1)')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--temperature', type=float, default=0.1, help='temperature')
    parser.add_argument('--log', type=int, default=10, help='checkpoint saving interval (epochs)')

    args = parser.parse_args()

    main(args)
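Side note: the `indices` selector above keeps the first three of every six per-joint channels, i.e. the accelerometer triplet of each (accelerometer, gyroscope) pair. A quick illustrative check (not part of the repository):

# Illustrative only: with 2 joints (12 channels ordered acc x/y/z then gyro x/y/z),
# the selector keeps channels 0-2 and 6-8, the accelerometer triplets.
import numpy as np

c = 12
indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
print(indices)  # [0 1 2 6 7 8]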
run_evaluation.sh
ADDED
@@ -0,0 +1,4 @@
python evaluate.py \
--batch_size 64 \
--checkpoint './checkpoint/UniMTS.pth' \
--data_path 'UniMTS_data'
run_evaluation_custom.sh
ADDED
@@ -0,0 +1,8 @@
python evaluate_custom.py \
--batch_size 64 \
--checkpoint './checkpoint/UniMTS.pth' \
--X_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
--y_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
--config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
--joint_list 20 2 21 3 11 \
--original_sampling_rate 50
run_finetune.sh
ADDED
@@ -0,0 +1,19 @@
for k in 1 2 3 5 10
do

python finetune.py \
--mode full \
--k $k \
--batch_size 64 \
--num_epochs 200 \
--checkpoint './checkpoint/UniMTS.pth' \
--data_path 'UniMTS_data'

done

python finetune.py \
--mode full \
--batch_size 64 \
--num_epochs 200 \
--checkpoint './checkpoint/UniMTS.pth' \
--data_path 'UniMTS_data'
run_finetune_custom.sh
ADDED
@@ -0,0 +1,33 @@
for k in 1 2 3 5 10
do

python finetune_custom.py \
--mode full \
--k $k \
--batch_size 64 \
--num_epochs 200 \
--checkpoint './checkpoint/UniMTS.pth' \
--X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
--y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
--X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
--y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
--config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
--joint_list 20 2 21 3 11 \
--original_sampling_rate 50 \
--num_class 8

done

python finetune_custom.py \
--mode full \
--batch_size 64 \
--num_epochs 200 \
--checkpoint './checkpoint/UniMTS.pth' \
--X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
--y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
--X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
--y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
--config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
--joint_list 20 2 21 3 11 \
--original_sampling_rate 50 \
--num_class 8
run_pretrain.sh
ADDED
@@ -0,0 +1,4 @@
python pretrain.py \
--aug 1 \
--batch_size 64 \
--data_path 'UniMTS_data'
text_aug.py
ADDED
@@ -0,0 +1,66 @@
import openai
import glob
import os
import shutil
from tqdm import tqdm

def load_api_key(file_path='api_key.txt'):
    with open(file_path, 'r') as f:
        for line in f:
            if line.startswith('api_key='):
                return line.strip().split('=', 1)[1]
    return None

openai.api_key = load_api_key()

if openai.api_key is None:
    print("Error: API key not found.")
    exit()

files = glob.glob('/path/to/txt')
aug_dir = '/path/to/output'

for f in tqdm(files):

    file_id = f.split('/')[-1]
    if not os.path.exists(aug_dir + file_id):

        with open(f, 'r') as file:
            lines = file.readlines()

        text = []
        for i, l in enumerate(lines):
            text.append(str(i) + ': ')
            text.append((l).split('#')[0].strip())
            if text[-1][-1] != '.':
                text.append('. ')
            else:
                text.append(' ')
        text = ''.join(text)

        prompt = 'The following one or multiple descriptions are describing the same human activities: '
        prompt += text
        prompt += 'Generate 3 paraphrases to describe the same activities. One in a line in a plain text format ending with \n, without numbering or - at the beginning. Do not generate any other analysis except from the paraphrased descriptions.'

        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        pred = response.choices[0]['message']['content']
        # res = pred.split('\n')

        shutil.copy(f, aug_dir)
        with open(aug_dir + file_id, 'a') as log_file:
            log_file.write(pred)

files = glob.glob('/path/to/output')
for f in tqdm(files):
    with open(f, 'r') as file:
        lines = file.readlines()

    lines = [line.lstrip("- ") for line in lines if line.strip()]

    with open(f, 'w') as file:
        file.writelines(lines)
utils.py
ADDED
@@ -0,0 +1,215 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import imageio
import io

def random_rotation_matrix():
    # Random unit quaternion
    q = torch.randn(4)
    q = q / torch.norm(q)

    # Quaternion to rotation matrix
    R = torch.tensor([
        [1 - 2*q[2]**2 - 2*q[3]**2, 2*q[1]*q[2] - 2*q[3]*q[0], 2*q[1]*q[3] + 2*q[2]*q[0]],
        [2*q[1]*q[2] + 2*q[3]*q[0], 1 - 2*q[1]**2 - 2*q[3]**2, 2*q[2]*q[3] - 2*q[1]*q[0]],
        [2*q[1]*q[3] - 2*q[2]*q[0], 2*q[2]*q[3] + 2*q[1]*q[0], 1 - 2*q[1]**2 - 2*q[2]**2]
    ])
    return R

def augment_data(data):
    # data: (B, T, M) on GPU; each joint's (acc, gyro) pair shares one random rotation
    B, T, M = data.shape
    augmented_data = torch.zeros_like(data)

    for i in range(B):
        for c in range(0, M, 6):
            R = random_rotation_matrix().cuda()
            acc = data[i, :, c:c+3].transpose(0, 1)     # Shape (3, T)
            gyro = data[i, :, c+3:c+6].transpose(0, 1)  # Shape (3, T)

            # Apply rotation
            rotated_acc = torch.matmul(R, acc)
            rotated_gyro = torch.matmul(R, gyro)

            # Concatenate and assign to augmented_data
            augmented_data[i, :, c:c+3] = rotated_acc.transpose(0, 1)
            augmented_data[i, :, c+3:c+6] = rotated_gyro.transpose(0, 1)

    return augmented_data

def update_limits(data):
    # Get global min and max for each axis
    min_x, max_x = np.min(data[:, :, 0]), np.max(data[:, :, 0])
    min_y, max_y = np.min(data[:, :, 2]), np.max(data[:, :, 2])
    min_z, max_z = np.min(data[:, :, 1]), np.max(data[:, :, 1])

    # Add some padding to ensure the skeleton doesn't touch the plot edges
    padding = 0.1
    x_range = max_x - min_x
    y_range = max_y - min_y
    z_range = max_z - min_z

    return (min_x - padding * x_range, max_x + padding * x_range), \
           (min_y - padding * y_range, max_y + padding * y_range), \
           (min_z - padding * z_range, max_z + padding * z_range)

def plot_skeleton(frame_data, xlims, ylims, zlims, dataset):
    """
    Plot a single frame of skeleton data.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])

    # Connect the joints according to the skeleton structure of each dataset
    if dataset == 't2m':
        connections = [
            [0, 2, 5, 8, 11],
            [0, 1, 4, 7, 10],
            [0, 3, 6, 9, 12, 15],
            [9, 14, 17, 19, 21],
            [9, 13, 16, 18, 20]
        ]

    if dataset == 'kit':
        connections = [
            [0, 11, 12, 13, 14, 15],
            [0, 16, 17, 18, 19, 20],
            [0, 1, 2, 3, 4],
            [3, 5, 6, 7],
            [3, 8, 9, 10]
        ]

    if dataset == 'ntu':
        connections = [
            [0, 12, 13, 14, 15],
            [0, 16, 17, 18, 19],
            [0, 1, 20, 2, 3],
            [20, 4, 5, 6, 7, 21],
            [7, 22],
            [20, 8, 9, 10, 11, 23],
            [11, 24],
        ]

    # Plot the lines for each sequence
    for connection in connections:
        for i in range(len(connection)-1):
            start_joint = connection[i]
            end_joint = connection[i+1]
            ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
                    [frame_data[start_joint, 2], frame_data[end_joint, 2]],
                    [frame_data[start_joint, 1], frame_data[end_joint, 1]])

    ax.view_init(elev=10, azim=90)
    ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_zlim(zlims)
    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.set_zlabel('Y')

    # Save the plot to a buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    img = imageio.imread(buf)
    buf.close()

    plt.close(fig)  # Close the figure to prevent display
    return img

def plot_skeleton_gif(data, dataset):
    xlims, ylims, zlims = update_limits(data)
    images = [plot_skeleton(frame, xlims, ylims, zlims, dataset) for frame in data]
    imageio.mimsave('./skeleton_animation.gif', images, fps=20)
    return

def plot_single_skeleton(data, dataset, frame=0):

    xlims, ylims, zlims = update_limits(data)
    frame_data = data[frame]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])

    # Connect the joints according to the skeleton structure of each dataset
    if dataset == 't2m':
        connections = [
            [0, 2, 5, 8, 11],
            [0, 1, 4, 7, 10],
            [0, 3, 6, 9, 12, 15],
            [9, 14, 17, 19, 21],
            [9, 13, 16, 18, 20]
        ]

    if dataset == 'kit':
        connections = [
            [0, 11, 12, 13, 14, 15],
            [0, 16, 17, 18, 19, 20],
            [0, 1, 2, 3, 4],
            [3, 5, 6, 7],
            [3, 8, 9, 10]
        ]

    if dataset == 'ntu':
        connections = [
            [0, 12, 13, 14, 15],
            [0, 16, 17, 18, 19],
            [0, 1, 20, 2, 3],
            [20, 4, 5, 6, 7, 21],
            [7, 22],
            [20, 8, 9, 10, 11, 23],
            [11, 24],
        ]

    # Plot the lines for each sequence
    for connection in connections:
        for i in range(len(connection)-1):
            start_joint = connection[i]
            end_joint = connection[i+1]
            ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
                    [frame_data[start_joint, 2], frame_data[end_joint, 2]],
                    [frame_data[start_joint, 1], frame_data[end_joint, 1]])

    #ax.view_init(elev=10, azim=90)
    ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_zlim(zlims)

    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.set_zlabel('Y')

    plt.savefig('skeleton.pdf', bbox_inches='tight')

def compute_height(joints, head_index, l_foot_index, r_foot_index):
    joints = torch.from_numpy(joints)
    left = (joints[:, head_index, 1] - joints[:, l_foot_index, 1])[0]
    right = (joints[:, head_index, 1] - joints[:, r_foot_index, 1])[0]
    height = (left + right) / 2
    return height

def compute_metrics_np(similarity_matrix, correct_labels):

    B, _ = similarity_matrix.shape

    ranked_indices = np.argsort(-similarity_matrix, axis=1)

    correct_label_ranks = np.array([np.where(ranked_indices[i] == correct_labels[i])[0][0] for i in range(B)]) + 1

    # Compute R@K
    R_at_1 = np.mean(correct_label_ranks <= 1)
    R_at_2 = np.mean(correct_label_ranks <= 2)
    R_at_3 = np.mean(correct_label_ranks <= 3)
    R_at_4 = np.mean(correct_label_ranks <= 4)
    R_at_5 = np.mean(correct_label_ranks <= 5)

    # Compute MRR
    MRR = np.mean(1.0 / correct_label_ranks)

    return R_at_1, R_at_2, R_at_3, R_at_4, R_at_5, MRR
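For reference, a minimal sketch of what compute_metrics_np returns on a toy retrieval problem (illustrative numbers, not from the repository):

# Illustrative only: 3 queries x 4 candidate labels, ground truth 0, 2, 1.
import numpy as np
from utils import compute_metrics_np

sim = np.array([
    [0.9, 0.1, 0.3, 0.2],  # correct label 0 is ranked 1st
    [0.2, 0.8, 0.7, 0.1],  # correct label 2 is ranked 2nd
    [0.6, 0.1, 0.5, 0.3],  # correct label 1 is ranked 4th
])
labels = np.array([0, 2, 1])

R1, R2, R3, R4, R5, MRR = compute_metrics_np(sim, labels)
print(R1, R3, MRR)  # 0.333..., 0.666..., (1 + 1/2 + 1/4) / 3 = 0.5833...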