Code that needs fixing
Below I paste the main file and the module that seems most relevant to the error.
If you need additional information, such as other modules this code imports, please do not hesitate to ask.
main_file
#!/usr/bin/env python
# coding: utf-8
# In[1]:
from make_embedding import preprocessing, divide_comments_by_video, get_comment_embedding, \
get_title_desc_embedding, initialize_vgg_19, get_image_embedding_vgg_19, \
cal_cos_sim_video_embedding, cal_attn_weight_embedding, ThumbFrameDataset
from get_common_thumb_frame import get_common_thumb_frame
from make_comment_bilstm import BiLSTM, create_batches
from video_dataset import VideoDataset, collate_fn
# In[2]:
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import pickle
import warnings
from transformers import BertModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from torch.multiprocessing import set_start_method
from torchvision import transforms
from torchvision import models
from torchvision.models.vgg import VGG19_Weights
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# In[3]:
try:
    set_start_method('spawn')
except RuntimeError:
    pass
# In[4]:
random_state=42
d = 768
max_length = 200
max_epochs = 200
patience = 10
num_workers = 16
n_splits = 5
video_batch_size = 1
lstm_hidden_size = 768//2
j = 0.1
batch_size = 4
comment_batch_size = 64
frame_batch_size = 256
dropout_rate = 0.5
lstm_batch_size = 128
lstm_dropout = 0.1 # 0.1-0.5; large datasets carry less overfitting risk, so starting from a small value is reasonable
lr = 1e-4
input_size = 768
hidden_dim = 128
num_layers = 2 # 1-3 is a good range
bidirectional = True
num_heads = 2
weight_decay = 1e-5
fig_save_name = 'reproduct_choi'
name='reproduct_choi'
torch.set_float32_matmul_precision('medium')
warnings.filterwarnings("ignore", category=FutureWarning)
# In[5]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
# In[6]:
df = preprocessing(random_state)
# In[7]:
common_ids_list = get_common_thumb_frame(df)
df = df[df['video_id'].isin(common_ids_list)]
# thumbnail: 288,video_frame: 213
# In[8]:
print(len(common_ids_list))
# In[9]:
num_real = len(df[df['label']==1])
num_fake = len(df[df['label']==0])
num_videos = len(df['video_id'].drop_duplicates())
print(f'videos: {num_videos}, samples: {len(df["label"])}, real: {num_real}, fake: {num_fake}')
# In[10]:
##### Data reduction (subsample for debugging) #####
df = df.sample(frac=0.05, random_state=42)
#####################
# In[11]:
df_list = divide_comments_by_video(df)
df_drop = df.drop_duplicates(subset='video_id')
# In[12]:
# df = df_list[0]
# print(len(df))
# comment_embeddings = get_comment_embedding(df, tokenizer, bert_model, max_length, batch_size=128)
# print(comment_embeddings.shape)
# 1229
# torch.Size([1229, 200, 768])
# In[13]:
# title_desc_embeddings = get_title_desc_embedding(df_drop, tokenizer, bert_model, max_length=max_length)
# ##### Save #####
# with open(f'pickle/title_desc_embeddings_maxlength={max_length}.pkl', 'wb') as f:
# pickle.dump(title_desc_embeddings, f)
# In[14]:
# ##### Load #####
# with open(f'pickle/title_desc_embeddings_maxlength={max_length}.pkl', 'rb') as f:
# title_desc_embeddings = pickle.load(f)
# In[15]:
# print(title_desc_embeddings.shape)
# torch.Size([117, 200, 768])
# In[16]:
def plot_roc_curve(fpr, tpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(fpr, tpr, label='ROC curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()

def plot_fpr_threshold(thresholds, fpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(thresholds, fpr)
    plt.xlabel('Thresholds')
    plt.ylabel('False Positive Rate')
    plt.title('Threshold vs. FPR')
    plt.gca().invert_xaxis()
    plt.grid(True)
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()

def plot_tpr_threshold(thresholds, tpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(thresholds, tpr)
    plt.xlabel('Thresholds')
    plt.ylabel('True Positive Rate')
    plt.title('Threshold vs. TPR')  # fixed: was mislabeled 'Threshold vs. FPR'
    plt.gca().invert_xaxis()
    plt.grid(True)
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()
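# ##### Usage sketch (not part of the original notebook) #####
# A minimal, self-contained check of the three plot helpers above, assuming
# only that a 'fig/' directory exists; the labels and scores are dummy data.
# The first threshold from roc_curve can be infinite, so it is sliced off.
# rng = np.random.default_rng(0)
# y_true = rng.integers(0, 2, size=100)
# y_score = np.clip(y_true * 0.6 + rng.random(100) * 0.5, 0, 1)
# fpr, tpr, thresholds = roc_curve(y_true, y_score)
# plot_roc_curve(fpr, tpr, 42, 4, 200, fig_save_name='demo_roc')
# plot_fpr_threshold(thresholds[1:], fpr[1:], 42, 4, 200, fig_save_name='demo_fpr')
# plot_tpr_threshold(thresholds[1:], tpr[1:], 42, 4, 200, fig_save_name='demo_tpr')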
# In[17]:
from video_dataset import CommentProcessor
# In[18]:
# Subtask 2: title/description processing
class TitleDescProcessor(nn.Module):
    def __init__(self, d):
        super(TitleDescProcessor, self).__init__()
        self.fc = nn.Linear(d, 2*d)

    def forward(self, x):
        # x shape: (batch_size, num_titles, max_length, embedding_dim)
        # mean pooling over tokens
        x = torch.mean(x, dim=2)
        # x shape: (batch_size, num_titles, embedding_dim)
        x = self.fc(x)
        # x shape: (batch_size, num_titles, embedding_dim*2)
        # mean pooling over titles
        x = torch.mean(x, dim=1)
        return x
# In[19]:
# ##### Usage example #####
# df = df_list[0]
# title_desc_embeddings = get_title_desc_embedding(df, tokenizer, bert_model, max_length, batch_size=32)
# title_desc_embeddings = title_desc_embeddings.unsqueeze(0).to('cuda')
# processor = TitleDescProcessor(d=768)
# processor = processor.to('cuda')
# embedding = processor(title_desc_embeddings)
# print(embedding.shape)
# # torch.Size([1, 1536])
# In[20]:
# Subclass 3: retrieve the top-j most similar frames
class GetJFrames(nn.Module):
    def __init__(self, d=768, j=0.1, batch_size=1, frame_batch_size=32):
        super(GetJFrames, self).__init__()
        self.j = j
        self.batch_size = batch_size
        self.frame_batch_size = frame_batch_size
        self.vgg_19 = initialize_vgg_19(d=d)  # initialize VGG-19

    def forward(self, common_ids_list):
        self.vgg_19 = self.vgg_19.to('cuda')
        dataset = ThumbFrameDataset(common_ids_list)
        data_loader = DataLoader(dataset, self.batch_size)
        top_j_sim_video_embeddings_list = cal_cos_sim_video_embedding(data_loader, self.vgg_19, self.j, self.frame_batch_size)
        return top_j_sim_video_embeddings_list
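# ##### Illustrative sketch (assumption, not the original module) #####
# make_embedding.py is not shown here; assuming cal_cos_sim_video_embedding
# ranks frame embeddings by cosine similarity to the thumbnail embedding and
# keeps the top fraction j, the core idea would look roughly like this
# (top_j_by_cosine is a hypothetical helper, not from the original code):
# def top_j_by_cosine(thumb_emb, frame_embs, j=0.1):
#     # thumb_emb: (d,), frame_embs: (num_frames, d)
#     sims = F.cosine_similarity(frame_embs, thumb_emb.unsqueeze(0), dim=1)
#     k = max(1, int(len(frame_embs) * j))
#     top_idx = torch.topk(sims, k).indices
#     return frame_embs[top_idx]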
# In[21]:
# with open('pickle/top_j_sim_video_embeddings_list.pkl', 'wb') as f:
# pickle.dump(top_j_sim_video_embeddings_list, f)
# In[22]:
with open('pickle/top_j_sim_video_embeddings_list.pkl', 'rb') as f:
    top_j_sim_video_embeddings_list = pickle.load(f)
# In[23]:
# Subclass 4: video processing
class VideoProcessor(nn.Module):
    def __init__(self, video_batch_size=64, d=768, num_heads=8):
        super(VideoProcessor, self).__init__()
        self.attention = MultiheadAttention(embed_dim=d*2, num_heads=num_heads, batch_first=True)
        self.video_batch_size = video_batch_size
        self.video_fc = nn.Linear(2*d, 2*d)

    def forward(self, top_j_sim_video_embeddings_list):
        self.attention = self.attention.to('cuda')
        weighted_avg_video_embeddings = cal_attn_weight_embedding(self.attention, top_j_sim_video_embeddings_list)
        video_output = self.video_fc(weighted_avg_video_embeddings)
        video_output_avg = torch.mean(video_output, dim=1)
        return video_output_avg
# In[24]:
##### Usage example #####
with open('pickle/top_j_sim_video_embeddings_list.pkl', 'rb') as f:
    top_j_sim_video_embeddings_list = pickle.load(f)
processor = VideoProcessor(d=768, num_heads=8)
processor = processor.to('cuda')
weighted_avg_video_embeddings = processor(top_j_sim_video_embeddings_list)
print(weighted_avg_video_embeddings.shape)
# In[25]:
class FakeNewsDetector(pl.LightningModule):
    def __init__(self, tokenizer, bert_model, random_state, max_length, batch_size, num_workers, lr, n_splits, dropout_rate, lstm_dropout, input_size, lstm_hidden_size, hidden_dim, num_layers, bidirectional, num_heads, max_epochs, patience, fig_save_name, name, weight_decay, d=768):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer', 'bert_model'])
        self.validation_step_outputs = []
        self.d = d
        self.video_fc = nn.Linear(2*d, 2*d)
        self.flatten = nn.Flatten()
        # NOTE: fc1 expects max_length * 2*d input features, but the shape
        # comments in forward say combined_output is (batch_size, 2*d);
        # one of the two must be wrong
        self.fc1 = nn.Linear(max_length * 2*d, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 1)
        self.comment_weight = nn.Parameter(torch.randn(1))
        self.title_desc_weight = nn.Parameter(torch.randn(1))
        self.video_weight = nn.Parameter(torch.randn(1))
        print('===== hyperparameters set up =====')
        self.bilstm_model = BiLSTM(input_size=int(input_size), hidden_size=int(lstm_hidden_size),
                                   num_layers=int(num_layers), dropout=float(lstm_dropout))
        # NOTE: Lightning moves submodules to the right device automatically;
        # this explicit .to('cuda') is redundant but kept from the original
        self.bilstm_model = self.bilstm_model.to('cuda')
        print('===== Bi-LSTM set up =====')
        # NOTE: lstm_batch_size is read from the notebook's global scope here
        self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size=768//2)
        self.title_desc_processor = TitleDescProcessor(d)
        self.get_j_frames = GetJFrames()
        self.video_processor = VideoProcessor()
        print('===== subclasses set up =====')
    def forward(self, comment_embeddings, masks_stack, hit_likes, title_desc_embedding, video_output_stack):
        # CommentProcessor.forward takes (embeddings, masks); the original
        # called it with only one argument
        comment_output_avg = self.comment_processor(comment_embeddings, masks_stack)
        # shape: (batch_size, 2*d)
        title_desc_output_avg = self.title_desc_processor(title_desc_embedding)
        # shape: (batch_size, 2*d)
        # NOTE: common_ids_list is read from the notebook's global scope here,
        # so frame selection does not depend on the current batch
        top_j_sim_video_embeddings_list = self.get_j_frames(common_ids_list)
        video_output_avg = self.video_processor(top_j_sim_video_embeddings_list)
        # shape: (1, 2*d)
        weights = F.softmax(torch.stack([self.comment_weight, self.title_desc_weight, self.video_weight]), dim=0)
        combined_output = weights[0] * comment_output_avg + weights[1] * title_desc_output_avg + weights[2] * video_output_avg
        x = self.flatten(combined_output)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.dropout(x)
        x = self.fc4(x)
        x = torch.sigmoid(x)
        x = x.squeeze()  # drop the trailing dimension so the shape is (batch_size,)
        return x
    def training_step(self, batch, batch_idx):
        comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch
        output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding)
        loss = F.binary_cross_entropy(output, label)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch
        output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding)
        loss = F.binary_cross_entropy(output, label)
        self.log('val_loss', loss)
        # sklearn's classification metrics need hard 0/1 labels, so threshold
        # the predicted probabilities at 0.5 (the original passed raw probabilities)
        pred_probs = output.cpu().numpy()
        label_predicted = (pred_probs >= 0.5).astype(int)
        label = label.cpu().numpy()
        logits = torch.logit(output).cpu().numpy()
        self.validation_step_outputs.append({'label': label, 'label_predicted': label_predicted, 'pred_probs': pred_probs, 'logits': logits})
        accuracy = accuracy_score(label, label_predicted)
        f1 = f1_score(label, label_predicted)
        precision = precision_score(label, label_predicted)
        recall = recall_score(label, label_predicted)
        self.log('val_acc', accuracy)
        self.log('val_f1', f1)
        self.log('val_precision', precision)
        self.log('val_recall', recall)
        return loss
    def on_validation_epoch_end(self):
        all_label = []
        all_preds = []
        all_pred_probs = []
        for output in self.validation_step_outputs:
            all_label.extend(output['label'])
            all_preds.extend(output['label_predicted'])  # hard 0/1 predictions for the confusion matrix
            all_pred_probs.extend(output['pred_probs'])  # probabilities for ROC/AUC
        cm = confusion_matrix(all_label, all_preds)
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                self.log(f'val_cm_{i}_{j}', float(cm[i, j]))
        auc = roc_auc_score(all_label, all_pred_probs)
        self.log('val_AUC', auc)
        fpr, tpr, thresholds = roc_curve(all_label, all_pred_probs)
        # the plot helpers take six arguments; the original calls passed an
        # extra self.fig_save_path, and batch_size/max_length live in hparams
        plot_roc_curve(fpr, tpr, self.hparams.random_state, self.hparams.batch_size, self.hparams.max_length, fig_save_name='test_roc')
        plot_fpr_threshold(thresholds, fpr, self.hparams.random_state, self.hparams.batch_size, self.hparams.max_length, fig_save_name='test_fpr')
        plot_tpr_threshold(thresholds, tpr, self.hparams.random_state, self.hparams.batch_size, self.hparams.max_length, fig_save_name='test_tpr')
        self.validation_step_outputs.clear()
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer
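# ##### Side note: softmax-weighted fusion (illustration only) #####
# forward above fuses the three modality embeddings with learnable scalar
# weights normalized by softmax, so the fused vector keeps the shared shape.
# A tiny standalone demo of just that mechanism, with dummy tensors:
# w = F.softmax(torch.stack([torch.randn(1), torch.randn(1), torch.randn(1)]), dim=0)
# a, b, c = torch.randn(4, 1536), torch.randn(4, 1536), torch.randn(4, 1536)
# fused = w[0] * a + w[1] * b + w[2] * c
# print(w.sum(), fused.shape)  # tensor(1.) torch.Size([4, 1536])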
# In[26]:
def make_trainer(max_epochs, logger, name, patience):
    early_stop_callback = EarlyStopping(
        monitor='val_loss',   # metric to watch
        min_delta=0.00,       # minimum change that counts as improvement
        patience=patience,    # epochs without improvement before stopping
        verbose=False,
        mode='min'            # 'min' watches for a decrease
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        dirpath='model_checkpoints/',
        filename=name + '-{epoch:02d}-{val_loss:.2f}'
    )
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        devices=1,
        accelerator='gpu',
        logger=logger,
        callbacks=[early_stop_callback, checkpoint_callback],
        enable_progress_bar=True,
        precision='16-mixed'
    )
    return trainer
# In[27]:
def make_kfold_dataloaders(df_drop, dataset, n_splits, batch_size, num_workers, pin_memory=True):
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True)
    labels = df_drop['label'].to_list()
    dataloaders_list = []
    for train_indices, val_indices in kfold.split(X=np.zeros(len(dataset)), y=labels):
        # build the train and validation sets with Subset
        train_subset = Subset(dataset, train_indices)
        val_subset = Subset(dataset, val_indices)
        # build the DataLoaders
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, persistent_workers=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate_fn)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, persistent_workers=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate_fn)
        # append this fold's loaders to the list
        dataloaders_list.append((train_loader, val_loader))
    return dataloaders_list
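# ##### Sanity check (illustration only, dummy labels) #####
# StratifiedKFold keeps the label ratio roughly constant across folds, which
# is why it is used above. A quick self-contained check:
# y = np.array([0] * 30 + [1] * 70)
# for tr, va in StratifiedKFold(n_splits=5, shuffle=True, random_state=42).split(np.zeros(len(y)), y):
#     print(y[va].mean())  # each fold's positive ratio stays near 0.7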
# In[28]:
val_losses = []
accuracies, f1_scores, precisions, recalls = [], [], [], []
cms_0_0, cms_0_1, cms_1_0, cms_1_1 = [], [], [], []
aucs = []
# In[29]:
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
# NOTE: this skf instance is never used below; make_kfold_dataloaders builds
# its own StratifiedKFold (without a random_state)
# In[30]:
# common_ids_list is passed through so VideoDataset.__getitem__ can call GetJFrames
dataset = VideoDataset(df_list, df, tokenizer, bert_model, max_length, comment_batch_size, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size, j, frame_batch_size, num_heads, video_batch_size, common_ids_list=common_ids_list, d=768)
# In[31]:
dataloaders_list = make_kfold_dataloaders(df_drop, dataset, n_splits, batch_size, num_workers)
# In[32]:
for fold, (train_loader, val_loader) in enumerate(dataloaders_list):
    print('========')
    print(f"Fold {fold + 1}")
    print('========')
    # NOTE: argument order fixed to match the constructor signature
    # (the original swapped hidden_dim and lstm_hidden_size positionally)
    model = FakeNewsDetector(tokenizer, bert_model, random_state, max_length, batch_size, num_workers, lr, n_splits, dropout_rate, lstm_dropout, input_size, lstm_hidden_size, hidden_dim, num_layers, bidirectional, num_heads, max_epochs, patience, fig_save_name, name, weight_decay, d=768)
    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name=name,
        version=f"Fold_{fold+1}"
    )
    trainer = make_trainer(max_epochs, logger, name, patience)
    trainer.fit(model, train_loader, val_loader)
    val_results = trainer.validate(model, val_loader)
    val_losses.append(val_results[0]['val_loss'])
    accuracies.append(val_results[0]['val_acc'])
    f1_scores.append(val_results[0]['val_f1'])
    precisions.append(val_results[0]['val_precision'])
    recalls.append(val_results[0]['val_recall'])
    cms_0_0.append(val_results[0]['val_cm_0_0'])
    cms_0_1.append(val_results[0]['val_cm_0_1'])
    cms_1_0.append(val_results[0]['val_cm_1_0'])
    cms_1_1.append(val_results[0]['val_cm_1_1'])
    aucs.append(val_results[0]['val_AUC'])
video_dataset.py
from make_comment_bilstm import BiLSTM, create_batches
from make_embedding import get_comment_embedding, get_title_desc_embedding, cal_cos_sim_video_embedding, initialize_vgg_19, cal_attn_weight_embedding, ThumbFrameDataset
from torch.utils.data import Dataset
import torch
import torch.nn as nn
from torch.nn import MultiheadAttention
from torch.utils.data import DataLoader
import pytorch_lightning as pl
class CommentProcessor(pl.LightningModule):
    def __init__(self, d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size=768//2):
        super(CommentProcessor, self).__init__()
        self.d = d
        # BiLSTM definition; being bidirectional, its output size is twice the hidden size
        self.bilstm = BiLSTM(input_size=768, hidden_size=lstm_hidden_size, num_layers=num_layers, dropout=lstm_dropout)
        self.lstm_batch_size = lstm_batch_size
        # fully connected layer mapping the BiLSTM output (2 * hidden size) to 2*d dimensions
        self.comment_fc = nn.Linear(2*lstm_hidden_size, 2*d)

    def forward(self, comment_embeddings, masks):
        # move the inputs and masks to the CUDA device explicitly
        comment_embeddings = comment_embeddings.to('cuda')
        masks = masks.to('cuda')
        # split the inputs into batches
        comment_batches, mask_batches = create_batches(comment_embeddings, masks, self.lstm_batch_size)
        # list collecting the outputs
        output = []
        # process each batch in turn
        for comment_batch, mask_batch in zip(comment_batches, mask_batches):
            # use the mask to compute each comment's true length
            lengths = mask_batch.sum(dim=1).long()
            # feed the batch through the BiLSTM
            lstm_out = self.bilstm(comment_batch, lengths)
            output.append(lstm_out)
        # concatenate the outputs of all batches
        comment_output = torch.cat(output, dim=0)
        # project to the target dimensionality
        comment_output = self.comment_fc(comment_output)
        # comment_output shape: (batch_size*num_comments*max_length, 2*d)
        # average over dim 0 to get a single feature vector
        comment_output_avg = torch.mean(comment_output, dim=0)
        return comment_output_avg
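# ##### Side note: mask-to-lengths (illustration only) #####
# forward above derives each comment's true length from its padding mask via
# mask.sum(dim=1). A tiny standalone demo with dummy data:
# mask = torch.tensor([[1, 1, 1, 0, 0],
#                      [1, 1, 1, 1, 1]], dtype=torch.bool)
# print(mask.sum(dim=1).long())  # tensor([3, 5]) -> lengths usable for packing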
class TitleDescProcessor(nn.Module):
    def __init__(self, d, df, tokenizer, bert_model, max_length, batch_size):
        super(TitleDescProcessor, self).__init__()
        self.fc = nn.Linear(d, 2*d)
        self.df = df
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_length = max_length
        self.batch_size = batch_size

    # NOTE: forward takes no input; the original declared an unused x parameter,
    # which broke the no-argument call in VideoDataset.__getitem__
    def forward(self):
        # the embedding is recomputed from the stored DataFrame on every call
        x = get_title_desc_embedding(self.df, self.tokenizer, self.bert_model, self.max_length, self.batch_size)
        # x shape: (batch_size, num_titles, max_length, embedding_dim)
        # mean pooling over tokens
        x = torch.mean(x, dim=2)
        # x shape: (batch_size, num_titles, embedding_dim)
        x = self.fc(x)
        # x shape: (batch_size, num_titles, embedding_dim*2)
        # mean pooling over titles
        x = torch.mean(x, dim=1)
        return x
class GetJFrames(nn.Module):
    def __init__(self, frame_batch_size, j, video_batch_size, d):
        super(GetJFrames, self).__init__()
        self.j = j
        self.video_batch_size = video_batch_size
        self.frame_batch_size = frame_batch_size
        self.vgg_19 = initialize_vgg_19(d=d)  # initialize VGG-19

    def forward(self, common_ids_list):
        self.vgg_19 = self.vgg_19.to('cuda')
        dataset = ThumbFrameDataset(common_ids_list)
        data_loader = DataLoader(dataset, self.video_batch_size)
        top_j_sim_video_embeddings_list = cal_cos_sim_video_embedding(data_loader, self.vgg_19, self.j, self.frame_batch_size)
        return top_j_sim_video_embeddings_list
class VideoProcessor(nn.Module):
    def __init__(self, video_batch_size, num_heads, d):
        super(VideoProcessor, self).__init__()
        self.attention = MultiheadAttention(embed_dim=d*2, num_heads=num_heads, batch_first=True)
        self.video_batch_size = video_batch_size
        self.video_fc = nn.Linear(2*d, 2*d)

    def forward(self, top_j_sim_video_embeddings_list):
        self.attention = self.attention.to('cuda')
        weighted_avg_video_embeddings = cal_attn_weight_embedding(self.attention, top_j_sim_video_embeddings_list)
        video_output = self.video_fc(weighted_avg_video_embeddings)
        video_output_avg = torch.mean(video_output, dim=1)
        return video_output_avg
class VideoDataset(Dataset):
    def __init__(self, df_list, df, tokenizer, bert_model, max_length, comment_batch_size, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size, j, frame_batch_size, num_heads, video_batch_size, common_ids_list=None, d=768):
        self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size)
        self.title_desc_processor = TitleDescProcessor(d, df, tokenizer, bert_model, max_length, batch_size=32)
        # NOTE: keyword arguments here; the original passed (j, frame_batch_size, ...)
        # and (d, num_heads, video_batch_size) positionally, silently swapping them
        self.get_j_frames = GetJFrames(frame_batch_size=frame_batch_size, j=j, video_batch_size=video_batch_size, d=d)
        self.video_processor = VideoProcessor(video_batch_size=video_batch_size, num_heads=num_heads, d=d)
        self.df_list = df_list
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_length = max_length
        self.comment_batch_size = comment_batch_size
        # NOTE: common_ids_list is a new optional argument; __getitem__ needs it
        # for GetJFrames.forward, which the original called with no arguments
        self.common_ids_list = common_ids_list

    def __len__(self):
        return len(self.df_list)

    def __getitem__(self, idx):
        # stack and return each item
        df = self.df_list[idx]
        # comment counts differ per video -> batch sizes differ -> cannot stack -> padding (see collate_fn)
        comment_embeddings = get_comment_embedding(df, self.tokenizer, self.bert_model, self.max_length, self.comment_batch_size)
        # CommentProcessor.forward requires a mask; before padding, every token
        # position is valid, so an all-ones mask is used here
        masks = torch.ones(comment_embeddings.shape[:2])
        comment_output_avg = self.comment_processor(comment_embeddings, masks)
        hit_likes = torch.tensor(df['like_count'].values, dtype=torch.float16)
        title_desc_output_avg = self.title_desc_processor()
        top_j_sim_video_embeddings_list = self.get_j_frames(self.common_ids_list)
        video_output_avg = self.video_processor(top_j_sim_video_embeddings_list)
        label = df['label'].values
        label = torch.tensor(label, dtype=torch.float16)
        # the label is passed on with shape (batch_size,) automatically
        return comment_output_avg, hit_likes, title_desc_output_avg, video_output_avg, label
def collate_fn(batch):
    # NOTE: this collate expects each item's first element to be a 3-D tensor of
    # per-comment embeddings (num_comments, max_length, d); __getitem__ above
    # currently returns an already-pooled comment vector, so one of the two
    # needs to change for the shapes to line up
    # find the largest comment count in the batch
    max_comments = max([comments.size(0) for comments, _, _, _, _ in batch])
    padded_comments = []
    masks = []
    # pad each item in the batch and build its mask
    for comments, hit_likes, title_desc_embeddings, video_output, label in batch:
        pad_size = max_comments - comments.size(0)
        mask = torch.ones(comments.size(0), dtype=torch.bool)
        if pad_size > 0:
            pad_tensor = torch.zeros(pad_size, comments.size(1), comments.size(2), dtype=comments.dtype)
            comments = torch.cat([comments, pad_tensor], dim=0)
            pad_mask = torch.zeros(pad_size, dtype=torch.bool)
            mask = torch.cat([mask, pad_mask], dim=0)
        padded_comments.append(comments)
        masks.append(mask)
    # convert the lists to tensors
    padded_comments_stack = torch.stack(padded_comments, dim=0)
    masks_stack = torch.stack(masks, dim=0)
    # NOTE: hit_likes and label have one entry per comment, so these stacks
    # assume equal comment counts across the batch and are not padded
    hit_likes = torch.stack([hit_likes for _, hit_likes, _, _, _ in batch], dim=0)
    title_desc_embeddings = torch.stack([title_desc_embeddings for _, _, title_desc_embeddings, _, _ in batch], dim=0)
    video_output_stack = torch.stack([video_output for _, _, _, video_output, _ in batch], dim=0)
    labels = torch.stack([label for _, _, _, _, label in batch], dim=0)
    return padded_comments_stack, masks_stack, hit_likes, title_desc_embeddings, video_output_stack, labels
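# ##### Padding demo (illustration only, dummy shapes) #####
# A minimal check of the padding/masking scheme above: two "videos" with 2 and
# 3 comments (max_length=4, d=8); scalar likes/labels are used so the
# unpadded stacks do not fail on unequal comment counts.
# dummy_batch = [
#     (torch.randn(2, 4, 8), torch.tensor([5.0]), torch.randn(8), torch.randn(8), torch.tensor(1.0)),
#     (torch.randn(3, 4, 8), torch.tensor([9.0]), torch.randn(8), torch.randn(8), torch.tensor(0.0)),
# ]
# padded, masks, hit_likes, _, _, labels = collate_fn(dummy_batch)
# print(padded.shape)  # torch.Size([2, 3, 4, 8]) -> first item padded from 2 to 3 comments
# print(masks)         # tensor([[ True,  True, False], [ True,  True,  True]])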