# trainer_AE.py
"""
This example shows how to interact with the Determined PyTorch interface to
build a basic MNIST network.
In the `__init__` method, the model and optimizer are wrapped with `wrap_model`
and `wrap_optimizer`. This model is single-input and single-output.
The methods `train_batch` and `evaluate_batch` define the forward pass
for training and evaluation respectively.
"""
import sys
sys.path.append("/home/LAB/gaoch/science/LFDM")
import os
import random
from typing import Any, Dict, Sequence, Tuple, Union, cast

import boto3
from botocore.client import Config
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.optim.lr_scheduler import MultiStepLR
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext, LRScheduler

from LFAE.dataset import S3Dataset, FluidDataset, DiffDataset, DoubleShockDataset
from LFAE.modules.AE import AutoEncoder
from LFAE.modules.model import ImagePyramide

CONFIG_PATH = '/home/LAB/gaoch/science/AE_config.yaml'

TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]
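# A minimal sketch of the hyperparameter layout this trial expects, inferred
# from the keys read in `AETrial.__init__` below. The field values are
# illustrative assumptions, not the contents of the real AE_config.yaml.
_EXAMPLE_HPARAMS = {
    "model_params": {"scales": [1, 0.5, 0.25]},   # consumed by AutoEncoder / mse_loss
    "data_params": {
        "num_channels": 1,                        # channel count fed to ImagePyramide
        "data_path": "PDEbench/2D/Incom/2D_diff-react_NA_NA.h5",
    },
    "train_params": {
        "scales": [1, 0.5, 0.25],                 # pyramid scales for the loss
        "lr": 2e-4,
        "adam_betas": (0.5, 0.999),
        "epoch_milestones": [100, 150],           # MultiStepLR decay points
    },
    "data_downloaded": False,
}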
class AETrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext = None, hparams=None) -> None:
        self.context = context
        self.hparams = context.get_hparams() if context is not None else hparams
        self.model_params = self.hparams['model_params']
        self.data_params = self.hparams['data_params']
        self.train_params = self.hparams['train_params']
        self.data_downloaded = self.hparams['data_downloaded']
        # Multi-scale pyramid used by the reconstruction loss.
        self.pyramid = self.context.to_device(
            ImagePyramide(self.train_params['scales'], self.data_params['num_channels'])
        )
        self.model = self.context.wrap_model(
            self.context.to_device(
                AutoEncoder(model_params=self.model_params,
                            data_params=self.data_params,
                            train_params=self.train_params).float()
            )
        )
        self.optimizer = self.context.wrap_optimizer(
            torch.optim.Adam(self.model.parameters(), lr=self.train_params['lr'],
                             betas=self.train_params['adam_betas'])
        )
        self.scheduler = self.context.wrap_lr_scheduler(
            MultiStepLR(self.optimizer, self.train_params['epoch_milestones'],
                        gamma=0.1, last_epoch=-1),
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )
    def build_training_data_loader(self) -> DataLoader:
        # dataset = S3Dataset('train', self.data_params['data_path'])
        self.s3 = boto3.resource(
            's3',
            endpoint_url='http://192.168.5.174:9000/',
            aws_access_key_id='K1DH3djl9BWDEMEv28Ar',
            aws_secret_access_key='W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L',
            config=Config(signature_version='s3v4'),
            region_name='cn')
        # self.s3.Bucket('datasets').download_file(f'PDEbench/2D/2DCFD/{self.data_params["data_path"]}.hdf5', "fluid.hdf5")
        if self.data_params['data_path'] == 'PDEbench/2D/Incom/2D_diff-react_NA_NA.h5':
            self.s3.Bucket('datasets').download_file(self.data_params['data_path'], "fluid.h5")
            dataset = DiffDataset('train')
        elif self.data_params['data_path'] == 'FluidDataset/doubleshock2npy':
            dataset = DoubleShockDataset('train')
        else:
            raise ValueError(f"Unsupported data_path: {self.data_params['data_path']}")
        self.normalizer = dataset.normalizer
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())
    def build_validation_data_loader(self) -> DataLoader:
        # dataset = S3Dataset('val', self.data_params['data_path'])
        if self.data_params['data_path'] == 'PDEbench/2D/Incom/2D_diff-react_NA_NA.h5':
            dataset = DiffDataset('val')
        elif self.data_params['data_path'] == 'FluidDataset/doubleshock2npy':
            dataset = DoubleShockDataset('val')
        else:
            raise ValueError(f"Unsupported data_path: {self.data_params['data_path']}")
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())
    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        batch = self.context.to_device(batch)
        # batch shape: [b, c, nf, H, W]
        pred = self.model(batch)
        real_pred = self.normalizer.decode(pred)
        # Multi-scale reconstruction loss over the image pyramid.
        loss = self.mse_loss(self.pyramid(batch), self.pyramid(real_pred))
        self.context.backward(loss)
        self.context.step_optimizer(self.optimizer)
        # Checkpoint every 10 epochs (on the first batch) and mirror it to S3.
        if (epoch_idx + 1) % 10 == 0 and batch_idx == 0:
            save_path = f'checkpoint_AE/{self.context.get_trial_id()}/{epoch_idx}.pth'
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(self.model.state_dict(), save_path)
            self.s3.meta.client.upload_file(save_path, 'gaoch', save_path)
            print(f'Checkpoint {save_path} saved')
        return {"loss": loss}
    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        batch = self.context.to_device(batch)
        self.model.eval()
        output = self.model(batch)
        # Decode back to physical units before scoring, as in train_batch.
        pred = self.normalizer.decode(output)
        validation_loss = torch.nn.functional.mse_loss(pred, batch).item()
        self.model.train()
        return {"validation_loss": validation_loss}
    def mse_loss(self, pyramide_real, pyramide_generated):
        # Average the MSE across pyramid scales; ImagePyramide outputs are
        # keyed 'prediction_<scale>'.
        loss = 0
        for scale in self.model_params["scales"]:
            real = pyramide_real['prediction_' + str(scale)]
            generated = pyramide_generated['prediction_' + str(scale)]
            loss += torch.mean((real - generated) ** 2)
        return loss / len(self.model_params["scales"])
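# The trial relies on `dataset.normalizer` exposing decode() (and, presumably,
# a matching encode()). The real implementation lives in LFAE.dataset; the
# class below is only an assumed minimal interface (per-channel z-score),
# shown for reference and not used by the trial above.
class _ExampleNormalizer:
    """Hypothetical stand-in for the dataset's normalizer."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean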
def setup_seed(seed):
    # Seed all RNGs for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
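# For reference: `mse_loss` expects pyramid outputs keyed 'prediction_<scale>'.
# A toy pyramid honoring that contract might look like the sketch below,
# written for 4-D [b, c, H, W] tensors for simplicity; the real module is
# LFAE.modules.model.ImagePyramide, which also handles this trial's batches
# and (presumably) uses anti-aliased downsampling rather than plain
# interpolation. Illustrative only; not used by the trial above.
def _toy_pyramid(images: torch.Tensor, scales) -> Dict[str, torch.Tensor]:
    return {
        'prediction_' + str(scale): torch.nn.functional.interpolate(
            images, scale_factor=scale, mode='bilinear', align_corners=False)
        for scale in scales
    }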
if __name__ == '__main__':
    cudnn.enabled = True
    cudnn.benchmark = True
    setup_seed(42)
    with open(CONFIG_PATH) as f:
        config = yaml.safe_load(f)
    hparams = config['hyperparameters']
    trial = AETrial(hparams=hparams)
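# Under Determined, this trial is normally launched from an experiment
# configuration whose entrypoint references trainer_AE:AETrial, e.g.
# (assuming a config file named AE_experiment.yaml in this directory):
#
#     det experiment create AE_experiment.yaml .
#
# The __main__ block above instead constructs the trial directly from the
# YAML hyperparameters, bypassing the Determined master.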