1 Star 2 Fork 1

wxz / isp

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ISP.py 7.23 KB
一键复制 编辑 原始数据 按行查看 历史
wxz 提交于 2022-01-18 01:40 . 调整
"""
Title
author: wxz
date: 2021-12-28
github: https://github.com/xinzwang
** Software ISP Hyperparameters
-- version: 1.0.0
type: numpy.ndarray()
shape: (20,)
black level correction params: 2 [black_level, white_level]
demosaik params: 0
white balance params: 3 [gain_1, gain_2, gain_3]
color correction params: 9 [c_00, c_01, c_02, c_10, c_11, c_12, c_20, c_21, c_22]
gamma correction params: 2 [mul, gamma]
tone curve params: 4 [mul, kernel_w, kernel_h, sigma]
"""
import os
import cv2
import glob
import re
import json
import numpy as np
import rawpy
from threading import Lock
from isp.blocks.black_level_correction import black_level_correction
from isp.blocks.color_correction import color_correction
from isp.blocks.demosaik import demosaik
from isp.blocks.gamma_correction import gamma_correction
from isp.blocks.tone_curve import tone_curve
from isp.blocks.white_balance import channel_gain_white_balance
class ISP_base(object):
    """Common interface for the ISP pipelines defined in this module.

    Subclasses override ``set_params`` and ``forward``; the base versions
    are deliberate no-ops that return ``None``.
    """

    def set_params(self, p):
        """Install the parameter vector ``p``. No-op in the base class."""
        pass

    def forward(self, raw):
        """Process ``raw`` into an image. No-op in the base class."""
        pass

    def __len__(self):
        """Return the number of tunable parameters of the pipeline.

        The subclasses in this module store that count in ``dim_params``.
        Instances that never set it report 0 instead of returning ``None``,
        which made ``len()`` and truth-testing raise ``TypeError``.
        """
        return getattr(self, 'dim_params', 0)
class ISP_rawpy(ISP_base):
    """ISP pipeline that delegates the processing to ``raw.postprocess``.

    The pipeline is configured by a 6-element vector:
    ``[use_auto_wb, gamma_0, gamma_1, exp_shift, exp_preserve_highlights,
    user_sat]``.
    """

    def __init__(self):
        self.dim_params = 6
        self.save_count = 0
        self.lock = Lock()
        # Raw parameter vector; same values as the decoded defaults below.
        self.p = np.array([0, 2.222, 4.5, 1.0, 0.5, 0.0])

        self.use_auto_wb = False  # (0, 1) automatic white balance
        self.gamma = (2.222, 4.5)  # (R, R) gamma correction
        self.exp_shift = 1  # 0.25-8.0 exposure shift on a linear scale
        # 0.0-1.0 keep highlights when brightening the image with exp_shift
        self.exp_preserve_highlights = 0.5
        self.user_sat = 0.0  # R saturation adjustment

    @staticmethod
    def _clamp(value, low, high):
        """Limit ``value`` to ``[low, high]``.

        Anything that is not strictly below ``high`` (including NaN) maps to
        ``high``; anything not strictly above ``low`` maps to ``low``.
        """
        if value < high:
            return value if value > low else low
        return high

    def init_params(self):
        """Return a fresh copy of the default parameter vector."""
        return np.array([0, 2.222, 4.5, 1.0, 0.5, 0.0])

    def set_params(self, p):
        """Store ``p`` and decode it into the individual postprocess options.

        Out-of-range entries are clamped rather than rejected.
        """
        clamp = self._clamp
        self.p = p
        # p[0] is squeezed into [0, 1] and then thresholded at 0.5.
        self.use_auto_wb = clamp(p[0], 0, 1) > 0.5
        self.gamma = (p[1], p[2])
        self.exp_shift = clamp(p[3], 0.25, 8)
        self.exp_preserve_highlights = clamp(p[4], 0, 1)
        self.user_sat = p[5]

    def get_params(self):
        """Return a multi-line, human-readable dump of the current settings."""
        parts = [
            f'params:{self.p}\n\n',
            f'use_auto_wb:{self.use_auto_wb}\n',
            f'gamma:{self.gamma}\n',
            f'exp_shift:{self.exp_shift}\n',
            f'exp_preserve_highlights:{self.exp_preserve_highlights}\n',
            f'user_sat:{self.user_sat}\n',
        ]
        return ''.join(parts)

    def forward(self, raw):
        """Run ``raw.postprocess`` with the current settings and return its result."""
        options = {
            'use_auto_wb': self.use_auto_wb,
            'gamma': self.gamma,
            'exp_shift': self.exp_shift,
            'exp_preserve_highlights': self.exp_preserve_highlights,
            'user_sat': self.user_sat,
        }
        return raw.postprocess(**options)
class ISP(ISP_base):
    """Software ISP pipeline driven by a flat 20-element parameter vector.

    Parameter layout (version 1.0.0), as consumed by ``forward``:

    * ``p[0], p[1]``  black level correction ``[black_level, white_level]``
    * ``p[2:5]``      white balance channel gains
    * ``p[5:14]``     color correction matrix, reshaped row-major to 3x3
    * ``p[14], p[15]`` gamma correction ``[mul, gamma]``
    * ``p[16]``       tone curve ``mul``
    * ``p[17:19]``    tone curve kernel
    * ``p[19]``       tone curve ``sigma``
    """

    # Parameter presets carried over from the original code. Only the LOD
    # preset is returned by init_params(); the SID one is kept for reference.
    SID_RAW_PARAMS = (
        512, 16383, 1.9296875, 1.0, 2.26171875,
        .9020, -.2890, -.0715, -.4535, 1.2436, .2348, -.0934, .1919, .7086,
        80.0, 2.2, 0.1, 3.0, 3.0, 1.0,
    )
    LOD_RAW_PARAMS = (
        2047, 14448, 2.074697256088257, 0.9324925541877747, 1.1760492324829102,
        .9020, -.2890, -.0715, -.4535, 1.2436, .2348, -.0934, .1919, .7086,
        80.0, 2.2, 0.1, 3.0, 3.0, 1.0,
    )

    def __init__(self, save_fig=False):
        """Create the pipeline.

        Args:
            save_fig: when true, ``forward`` plots every intermediate stage
                and writes the figure to ``isp.jpg``.
        """
        self.dim_params = 20
        self.save_fig = save_fig
        # Start from the defaults so forward() is usable before set_params().
        self.p = self.init_params()

    @staticmethod
    def _plot_stage(plt, index, title, img):
        """Draw ``img`` into cell ``index`` of a 2x3 grid; no-op if ``plt`` is None."""
        if plt is None:
            return
        plt.subplot(2, 3, index)
        plt.title(title)
        plt.imshow(img)

    def forward(self, raw):
        """Run all pipeline stages on ``raw`` and return the result times 255.

        Args:
            raw: a ``numpy.ndarray``, or any object exposing the raw data as
                ``raw_image_visible`` (e.g. what ``imread`` returns).

        Raises:
            AssertionError: if the raw data is None or the parameter vector
                does not have shape ``(20,)``. The exception type matches the
                original ``assert`` statements, but the checks now also run
                under ``python -O`` and carry a real message.
        """
        # Check for None before touching attributes; otherwise a None input
        # fails with an unrelated AttributeError.
        if raw is not None and not isinstance(raw, np.ndarray):
            raw = raw.raw_image_visible
        if raw is None:
            raise AssertionError('[ERROR] input raw data is None')

        # version 1.0.0 p->params
        p = np.asarray(self.p)
        if p.shape != (20,):
            raise AssertionError(
                f'[ERROR] params shape error. shape: {p.shape}')

        # matplotlib is only needed (and only imported) when figures are saved.
        plt = None
        if self.save_fig:
            from matplotlib import pyplot as plt

        img = black_level_correction(raw, p[0], p[1])
        self._plot_stage(plt, 1, 'black_level_correction', img)

        img = demosaik(img)
        self._plot_stage(plt, 2, 'demosaik', img)

        img = channel_gain_white_balance(img, p[2:5])
        self._plot_stage(plt, 3, 'white balance', img)

        img = color_correction(img, p[5:14].reshape(3, 3))
        self._plot_stage(plt, 4, 'color correction', img)

        img = gamma_correction(img, p[14], p[15])
        self._plot_stage(plt, 5, 'gamma', img)

        img = tone_curve(img, p[16], p[17:19], p[19])
        self._plot_stage(plt, 6, 'tone curve', img)

        if plt is not None:
            plt.savefig('isp.jpg', dpi=300)
        # NOTE(review): presumably the stages work on a 0..1 scale and this
        # rescales to 0..255 -- confirm against isp.blocks.
        return img * 255

    def set_params(self, p):
        """Replace the parameter vector used by ``forward``."""
        self.p = p

    @staticmethod
    def imread(raw_path):
        """Read one or several raw images with rawpy.

        Args:
            raw_path: a single path, or a list/tuple of paths.

        Returns:
            A list with one rawpy object per path (also for a single path).
        """
        paths = raw_path if isinstance(raw_path, (list, tuple)) else [raw_path]
        return [rawpy.imread(one_path) for one_path in paths]

    @staticmethod
    def init_params():
        """Return a fresh array holding the default (LOD) parameter preset."""
        return np.array(ISP.LOD_RAW_PARAMS)

    @staticmethod
    def save_params(p=None, path="./runs/isp", suffix=""):
        """Write the parameter vector ``p`` to a JSON file.

        Args:
            p: 20-element parameter vector (see the class docstring).
            path: either a ``.json`` file path, or a directory. For a
                directory the file is named ``params_<suffix><n>.json`` where
                ``n`` is one more than the largest trailing number among the
                ``.json`` files already there (0 if there are none).
            suffix: text inserted between ``params_`` and the number.

        Returns:
            The path of the file that was written.

        Raises:
            TypeError: if ``p`` is None.
        """
        if p is None:
            raise TypeError('save_params() requires a parameter vector, got None')
        # Plain floats serialise regardless of the array's dtype.
        p = np.asarray(p, dtype=float)
        a = {
            "black_level_correction": {
                "black_level": float(p[0]),
                "white_level": float(p[1])
            },
            "white_balance": p[2:5].tolist(),
            "color_correction": p[5:14].tolist(),
            "gamma_correction": {
                "mul": float(p[14]),
                "gamma": float(p[15])
            },
            "tone_curve": {
                "mul": float(p[16]),
                "kernel": p[17:19].tolist(),
                "sigma": float(p[19])
            }
        }
        path = path.replace('/', os.sep)
        if os.path.splitext(path)[1] != '.json':
            # Directory target: create it and pick the next free index.
            if path:
                os.makedirs(path, exist_ok=True)
            used = []
            for existing in glob.glob(os.path.join(path, '*.json')):
                stem = os.path.splitext(os.path.basename(existing))[0]
                # Only the trailing digits of the file name are the index;
                # digits in directory names or in the suffix are ignored.
                found = re.search(r'[0-9]+$', stem)
                if found:
                    used.append(int(found.group()))
            num = max(used) + 1 if used else 0
            path = os.path.join(path, "params_" + suffix + str(num) + ".json")
        else:
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(a, f)
        return path

    @staticmethod
    def load_params(path):
        """Load a parameter vector from a JSON file written by ``save_params``.

        Returns:
            A 1-D ``numpy.ndarray`` in the order used by ``forward``.

        Raises:
            KeyError: if an expected section is missing from the file.
        """
        with open(path) as f:
            a = json.load(f)
        ordered = [
            a["black_level_correction"]['black_level'],
            a["black_level_correction"]['white_level'],
            a["white_balance"],
            a["color_correction"],
            a["gamma_correction"]["mul"],
            a["gamma_correction"]["gamma"],
            a["tone_curve"]["mul"],
            a["tone_curve"]["kernel"],
            a["tone_curve"]["sigma"]
        ]
        # ravel() turns scalars and (nested) lists alike into 1-D pieces.
        return np.concatenate([np.ravel(entry) for entry in ordered])
# ---------------- utils function ----------------
# Flatten the list to 1-D
def flatten(x):
    """Return the leaves of an arbitrarily nested ``list`` as a flat list.

    Only exact ``list`` instances are descended into; any other value
    (tuples and list subclasses included) is treated as a single leaf, so a
    non-list argument comes back wrapped in a one-element list.
    """
    if type(x) is not list:
        return [x]
    leaves = []
    for item in x:
        leaves.extend(flatten(item))
    return leaves
if __name__ == '__main__':
    # Manual smoke check: build the pipeline and fetch its default parameters.
    isp = ISP()
    params = isp.init_params()
Python
1
https://gitee.com/xinzwang/isp.git
git@gitee.com:xinzwang/isp.git
xinzwang
isp
isp
master

搜索帮助