garage-AI/app/app.py
2023-09-29 08:52:00 +02:00

119 lines
3.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import html
import os
import re
import time

import requests
from dotenv import load_dotenv
from flask import Flask, redirect, render_template, request, url_for
from flask_oidc import OpenIDConnect
# Flask application object; the route handlers below register themselves on it.
app = Flask(__name__)
# Model file name.
# NOTE(review): not referenced by the handlers visible here (the chat payload
# hard-codes 'gpt-3.5-turbo') — confirm whether it is still needed.
model = "ggml-gpt4all-j.bin"
# Load variables from a .env file into os.environ before reading config from it.
load_dotenv()
# Host name of the LocalAI server used to build API URLs; None if LOCALAI_HOST is unset.
host = os.environ.get("LOCALAI_HOST")
############################### KEYCLOAK ###############################
# app.config.update({
# # PROD ONLY
# 'SECRET_KEY': 'créer-un-secret-ici',
# 'OIDC_CLIENT_SECRETS': 'client_secrets_prod.json',
# 'OIDC_ID_TOKEN_COOKIE_SECURE': False,
# 'OIDC_REQUIRE_VERIFIED_EMAIL': False,
# 'OIDC_USER_INFO_ENABLED': True,
# 'OIDC_OPENID_REALM': 'gregan',
# 'OIDC_SCOPES': ['openid', 'email', 'profile'],
# 'OIDC_INTROSPECTION_AUTH_METHOD': 'client_secret_post'
# })
# app.config['OVERWRITE_REDIRECT_URI'] = 'https://chat-gpt.domain.tld/oidc_callback'
# oidc = OpenIDConnect(app)
# @app.context_processor
# def inject_oidc_user():
# if oidc.user_loggedin:
# return dict(oidc_user=oidc.user_getfield('email'))
# return dict(oidc_user=None)
# CHAT BOT: GPT-3.5-TURBO
@app.route("/", methods=("GET", "POST"))
# @oidc.require_login
def index():
    """Render the chat page and, on POST, ask the LocalAI chat endpoint.

    GET renders ``index.html`` with an empty result. POST reads the form
    fields ``question`` and ``temperature``, sends them to
    ``http://<host>:8080/v1/chat/completions`` and renders the first
    choice's message content wrapped in ``<md>...</md>``, plus the elapsed
    request time.

    A non-200 status, a network failure, or a response body without the
    expected ``choices[0].message.content`` shape all render the message
    "Erreur de connection avec l'API" instead of failing the request.
    """
    result = ''
    temps = ''
    if request.method == "POST":
        question = request.form["question"]
        temp = request.form["temperature"]
        # NOTE(review): ``host`` is None when LOCALAI_HOST is unset, which makes
        # this concatenation raise TypeError — confirm deployment always sets it.
        url = "http://" + host + ":8080/v1/chat/completions"
        payload = {
            # "role": "system", "content": "You are Yoda, the character of Star Wars and you answer the question like he would.",
            "model": 'gpt-3.5-turbo',
            "messages": [{"role": "user", "content": question}],
            "temperature": float(temp),
        }
        tic = time.perf_counter()
        try:
            response = requests.post(url, json=payload)
            if response.status_code == 200:
                answer = response.json()['choices'][0]['message']['content']
                # NOTE(review): the model output is inserted verbatim; if the
                # template renders ``result`` unescaped this is an HTML
                # injection vector — confirm how index.html renders it.
                result = '<md>' + answer + '</md>'
            else:
                result = "Erreur de connection avec l'API"
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError):
            # Unreachable server, non-JSON body, unexpected JSON shape or a
            # null content field: show the error on the page instead of a 500.
            result = "Erreur de connection avec l'API"
        toc = time.perf_counter()
        temps = f"temps de réponse: {toc - tic:0.4f} seconds"
    return render_template("index.html", result=result, time=temps)
# IMAGE GENERATOR: STABLEDIFFUSION
@app.route("/image", methods=("GET", "POST"))
# @oidc.require_login
def image():
    """Render the image page and, on POST, ask LocalAI to generate an image.

    GET renders ``image.html`` with an empty result. POST reads the form
    field ``image`` as the prompt, requests a 256x256 image from
    ``http://<host>:8080/v1/images/generations`` and renders an ``<img>``
    tag pointing at the returned URL (with ``local-ai`` rewritten to
    ``localhost``), plus the elapsed request time.

    A non-200 status, a network failure, or a response body without the
    expected ``data[0].url`` shape all render the message
    "Erreur de connection avec l'API". Previously a non-200 status crashed
    with UnboundLocalError because ``image_url`` was never assigned.
    """
    result = ''
    temps = ''
    if request.method == "POST":
        question = request.form["image"]
        # NOTE(review): ``host`` is None when LOCALAI_HOST is unset, which makes
        # this concatenation raise TypeError — confirm deployment always sets it.
        url = "http://" + host + ":8080/v1/images/generations"
        headers = {
            "Content-Type": "application/json",
        }
        data = {
            "prompt": question,
            "size": "256x256",
            "directory": "/tmp",
        }
        tic = time.perf_counter()
        try:
            response = requests.post(url, headers=headers, json=data)
            if response.status_code == 200:
                ########## PROD ONLY ##########
                # image_url = response.json()['data'][0]['url'].replace("local-ai", "image.domain.tld").replace("http", "https").replace(":8080", "")
                image_url = response.json()['data'][0]['url'].replace("local-ai", "localhost")
                # Quote and escape the URL so spaces, quotes or '>' in it
                # cannot break out of the src attribute.
                result = '<img src="' + html.escape(image_url) + '" >'
            else:
                result = "Erreur de connection avec l'API"
        except (requests.RequestException, ValueError, KeyError, IndexError,
                TypeError, AttributeError):
            # Unreachable server, non-JSON body, unexpected JSON shape or a
            # non-string url: show the error on the page instead of a 500.
            result = "Erreur de connection avec l'API"
        toc = time.perf_counter()
        temps = f"temps de réponse: {toc - tic:0.4f} seconds"
    return render_template("image.html", result=result, time=temps)