1 Star 0 Fork 25

待定90 / deoldify

forked from Gitee 极速下载 / deoldify 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
app.py 2.35 KB
一键复制 编辑 原始数据 按行查看 历史
# import the necessary packages
import os
import sys
import requests
import ssl
from flask import Flask
from flask import request
from flask import jsonify
from flask import send_file
from uuid import uuid4
from os import path
import torch
import fastai
from fasterai.visualize import *
from pathlib import Path
import traceback
# Restrict the process to the first GPU. This must happen before anything
# initialises CUDA: the variable is only read at CUDA initialisation, so
# setting it after the models are built (as was done previously) is too late.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Let cuDNN auto-tune its convolution algorithms.
torch.backends.cudnn.benchmark = True

# Build both colorizers once at import time so every request reuses them.
image_colorizer = get_image_colorizer(artistic=True)
video_colorizer = get_video_colorizer()

app = Flask(__name__)
# define a predict function as an endpoint
@app.route("/process_image", methods=["POST"])
def process_image():
    """Colorize the image at a caller-supplied URL and return the result file.

    The request body must be JSON with two keys:
      * ``source_url``    - URL of the image to colorize.
      * ``render_factor`` - value convertible to ``int``; passed to the colorizer.

    Returns:
        The colorized image as a file response on success, or
        ``({'message': 'input error'}, 400)`` if anything fails.

    Temporary files created for the request are removed before returning.
    """
    upload_directory = 'upload'
    # Pick the file name before the ``try`` so the cleanup in ``finally`` can
    # always refer to it, even when request parsing fails.
    random_filename = str(uuid4()) + '.png'
    upload_path = os.path.join(upload_directory, random_filename)
    # NOTE(review): the colorizer is expected to write its output under
    # "result_images/" with the same file name - confirm against fasterai.
    result_path = os.path.join("result_images", random_filename)

    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])

        # exist_ok avoids a race between concurrent requests (threaded server).
        os.makedirs(upload_directory, exist_ok=True)

        image_colorizer.plot_transformed_image_from_url(
            url=source_url,
            path=upload_path,
            figsize=(20, 20),
            render_factor=render_factor,
            display_render_factor=True,
            compare=False,
        )
        return send_file(result_path, mimetype='image/jpeg')
    except Exception:
        traceback.print_exc()
        # The key must be a string literal; a bare ``message`` name is undefined.
        return {'message': 'input error'}, 400
    finally:
        # Best-effort cleanup: a file may not exist if processing failed early,
        # and that must not replace the response with a new exception.
        for leftover in (result_path, upload_path):
            try:
                os.remove(leftover)
            except OSError:
                pass
@app.route("/process_video", methods=["POST"])
def process_video():
    """Colorize the video at a caller-supplied URL and return the result file.

    The request body must be JSON with two keys:
      * ``source_url``    - URL of the video to colorize.
      * ``render_factor`` - value convertible to ``int``; passed to the colorizer.

    Returns:
        The colorized video as a file response on success, or
        ``({'message': 'input error'}, 400)`` if anything fails.

    The result file is removed before returning.
    """
    upload_directory = 'upload'
    # Pick the file name before the ``try`` so the cleanup in ``finally`` can
    # always refer to it, even when request parsing fails.
    random_filename = str(uuid4()) + '.mp4'
    # NOTE(review): the colorizer is expected to write its output under
    # "video/result/" with the same file name - confirm against fasterai.
    result_path = os.path.join("video/result/", random_filename)

    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])

        # Kept from the original behaviour although this handler does not write
        # into the directory itself; exist_ok avoids a race between requests.
        os.makedirs(upload_directory, exist_ok=True)

        video_colorizer.colorize_from_url(source_url, random_filename, render_factor)
        return send_file(result_path, mimetype='application/octet-stream')
    except Exception:
        traceback.print_exc()
        # The key must be a string literal; a bare ``message`` name is undefined.
        return {'message': 'input error'}, 400
    finally:
        # Best-effort cleanup: the result may not exist if processing failed,
        # and that must not replace the response with a new exception.
        try:
            os.remove(result_path)
        except OSError:
            pass
if __name__ == '__main__':
    # Serve on every network interface, port 5000, one thread per request.
    app.run(host='0.0.0.0', port=5000, threaded=True)
1
https://gitee.com/z90/deoldify.git
git@gitee.com:z90/deoldify.git
z90
deoldify
deoldify
master

搜索帮助