# NOTE: removed a Gitee page banner ("code fetch complete, page will refresh") that was pasted in with the source.
"""
This example shows how to interact with the Determined PyTorch interface to
build a basic MNIST network.
In the `__init__` method, the model and optimizer are wrapped with `wrap_model`
and `wrap_optimizer`. This model is single-input and single-output.
The methods `train_batch` and `evaluate_batch` define the forward pass
for training and evaluation respectively.
"""
import sys
sys.path.append("/home/LAB/gaoch/science/LFDM")
import boto3
from botocore.client import Config
import torch
import numpy as np
import torch.backends.cudnn as cudnn
import os
import sys
import random
from torch.optim.lr_scheduler import MultiStepLR
from typing import Any, Dict, Sequence
from LFAE.dataset import S3Dataset, FluidDataset, DiffDataset, DoubleShockDataset
import torch
import yaml
from LFAE.modules.AE import AutoEncoder
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext, LRScheduler
from typing import Any, Dict, Sequence, Tuple, Union, cast
from LFAE.modules.model import ImagePyramide
# Absolute path to the trial's YAML config. The original string was missing
# the leading '/' ('home/LAB/...'), yielding a relative path that does not
# match the '/home/LAB/gaoch/science/...' prefix used elsewhere in this file.
CONFIG_PATH = '/home/LAB/gaoch/science/AE_config.yaml'
# Batch types the Determined harness may hand to train/evaluate methods.
TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]
class AETrial(PyTorchTrial):
    """Determined PyTorch trial that trains an AutoEncoder on PDE/fluid data.

    When a Determined ``context`` is given, the model/optimizer/scheduler are
    wrapped through it. With ``context=None`` (standalone construction from
    the ``__main__`` block) plain, unwrapped objects are used instead — the
    original code dereferenced the ``None`` context and crashed in that path.
    """

    def __init__(self, context: PyTorchTrialContext = None, hparams=None) -> None:
        self.context = context
        # Hyperparameters come from the cluster context when available,
        # otherwise from the explicit dict (standalone mode).
        self.hparams = context.get_hparams() if context is not None else hparams
        self.model_params = self.hparams['model_params']
        self.data_params = self.hparams['data_params']
        self.train_params = self.hparams['train_params']
        self.data_downloaded = self.hparams['data_downloaded']

        pyramid = ImagePyramide(self.train_params['scales'],
                                self.data_params['num_channels'])
        model = AutoEncoder(model_params=self.model_params,
                            data_params=self.data_params,
                            train_params=self.train_params).float()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.train_params['lr'],
                                     betas=self.train_params['adam_betas'])
        scheduler = MultiStepLR(optimizer, self.train_params['epoch_milestones'],
                                gamma=0.1, last_epoch=-1)
        if context is not None:
            self.pyramid = context.to_device(pyramid)
            self.model = context.wrap_model(context.to_device(model))
            self.optimizer = context.wrap_optimizer(optimizer)
            self.scheduler = context.wrap_lr_scheduler(
                scheduler, step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
        else:
            # Standalone path: no Determined wrappers, no device placement.
            self.pyramid = pyramid
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler

    def build_training_data_loader(self) -> DataLoader:
        """Build the training DataLoader, downloading the data if needed.

        Side effects: stores the S3 resource on ``self.s3`` (reused by
        ``train_batch`` for checkpoint upload) and the dataset's normalizer on
        ``self.normalizer`` (used to decode predictions).
        """
        # NOTE(review): hard-coded endpoint and credentials — these belong in
        # a secrets store / environment variables, not in source control.
        self.s3 = boto3.resource(
            's3',
            endpoint_url='http://192.168.5.174:9000/',
            aws_access_key_id='K1DH3djl9BWDEMEv28Ar',
            aws_secret_access_key='W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L',
            config=Config(signature_version='s3v4'),
            region_name='cn')
        data_path = self.data_params['data_path']
        if data_path == 'PDEbench/2D/Incom/2D_diff-react_NA_NA.h5':
            self.s3.Bucket('datasets').download_file(data_path, "fluid.h5")
            dataset = DiffDataset('train')
        elif data_path == 'FluidDataset/doubleshock2npy':
            dataset = DoubleShockDataset('train')
        else:
            # Previously an unrecognized path fell through to a NameError on
            # `dataset`; fail with a clear message instead.
            raise ValueError(f"Unsupported data_path: {data_path!r}")
        self.normalizer = dataset.normalizer
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> DataLoader:
        """Build the validation DataLoader (data assumed already downloaded)."""
        data_path = self.data_params['data_path']
        if data_path == 'PDEbench/2D/Incom/2D_diff-react_NA_NA.h5':
            dataset = DiffDataset('val')
        elif data_path == 'FluidDataset/doubleshock2npy':
            dataset = DoubleShockDataset('val')
        else:
            raise ValueError(f"Unsupported data_path: {data_path!r}")
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())

    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        """Run one optimization step.

        The batch is expected to have shape [b, c, nf, H, W] (per the original
        comment — TODO confirm against the dataset). Every 10 epochs, on the
        first batch, a checkpoint is saved locally and uploaded to S3.
        """
        batch = self.context.to_device(batch)
        pred = self.model(batch)
        # The model predicts in normalized space; decode back before comparing
        # with the raw batch at multiple pyramid scales.
        real_pred = self.normalizer.decode(pred)
        loss = self.mse_loss(self.pyramid(batch), self.pyramid(real_pred))
        self.context.backward(loss)
        self.context.step_optimizer(self.optimizer)
        if (epoch_idx + 1) % 10 == 0 and batch_idx == 0:
            save_path = f'checkpoint_AE/{self.context.get_trial_id()}/{epoch_idx}.pth'
            # torch.save does not create parent directories itself.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(self.model.state_dict(), save_path)
            # self.s3 is created in build_training_data_loader, which the
            # harness calls before training starts.
            self.s3.meta.client.upload_file(save_path, 'gaoch', save_path)
            print(f'Checkpoint {save_path} saved')
        return {"loss": loss}

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        """Compute validation MSE between the decoded output and the batch."""
        batch = self.context.to_device(batch)
        self.model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            output = self.model(batch)
            # BUG FIX: the original passed the *bound method* `output.unsqueeze`
            # (never called) into decode and then accessed `.squeeze` without
            # calling it. Decode the tensor directly, mirroring train_batch.
            pred = self.normalizer.decode(output)
            # `batch[:, :, ...]` in the original is just `batch`.
            validation_loss = torch.nn.functional.mse_loss(pred, batch).item()
        self.model.train()
        return {"validation_loss": validation_loss}

    def mse_loss(self, pyramide_real, pyramide_generated):
        """Average per-scale MSE over the image-pyramid outputs.

        Iterates the scales the pyramid was actually built with
        (``train_params['scales']``, see __init__); the original read
        ``model_params["scales"]``, which raises KeyError whenever the two
        lists differ.
        """
        scales = self.train_params['scales']
        loss = 0
        for scale in scales:
            real = pyramide_real['prediction_' + str(scale)]
            generated = pyramide_generated['prediction_' + str(scale)]
            loss += torch.mean((real - generated) ** 2)
        return loss / len(scales)
def setup_seed(seed):
    """Seed every RNG in use (torch CPU/CUDA, NumPy, stdlib random) and force
    deterministic cuDNN kernels so runs are reproducible."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # Standalone entry point: build the trial straight from the YAML config,
    # bypassing the Determined harness.
    cudnn.enabled = True
    cudnn.benchmark = True
    setup_seed(42)
    with open(CONFIG_PATH) as cfg_file:
        config = yaml.safe_load(cfg_file)
    hparams = config['hyperparameters']
    trial = AETrial(hparams=hparams)
# NOTE: removed a Gitee content-moderation notice that the hosting page appended to the file; it was not part of the source.