155 lines
3.9 KiB
Python
155 lines
3.9 KiB
Python
|
import os
|
||
|
import json
|
||
|
from groq import Groq
|
||
|
from openai import OpenAI
|
||
|
|
||
|
# Module-level credentials.
# The Groq and OpenAI-organisation clients that used to be configured here are
# disabled. get_response() constructs its own OpenAI-compatible client per call.
# Never commit real keys to this file; supply them via environment variables.
openai_key: str = ""
|
||
|
|
||
|
|
||
|
def extract_json_from_end(text):
    """Parse the last JSON object that appears in ``text`` (e.g. an LLM reply).

    If ``text`` contains a ```` ```json ```` fence, only the content of the
    first fenced block is considered. The object is located by taking the last
    ``}`` and walking backwards to its matching ``{``. Brace counting does not
    skip braces that appear inside string values.

    ``//`` line comments inside the located object are removed before parsing.
    A ``//`` inside a JSON string value (for example a URL) is preserved.

    Args:
        text: Free-form text that ends with (or contains) a JSON object.

    Returns:
        The decoded JSON value (normally a ``dict``).

    Raises:
        IndexError: If there is no ``}``, or the last ``}`` has no matching
            ``{``. This is the same exception type the original index-walking
            implementation raised.
        json.JSONDecodeError: If the located span is not valid JSON.
    """
    if "```json" in text:
        text = text.split("```json")[1]
        text = text.split("```")[0]

    end = text.rfind("}")
    if end == -1:
        raise IndexError("no '}' found; cannot extract a JSON object")

    # Walk backwards from the final '}' until braces balance.
    depth = 0
    start = end
    while start >= 0:
        char = text[start]
        if char == "}":
            depth += 1
        elif char == "{":
            depth -= 1
            if depth == 0:
                break
        start -= 1
    else:
        raise IndexError("unbalanced braces: no matching '{' for the final '}'")

    # Strip comments only inside the located span. Stripping the whole text
    # would shift offsets and make the slice start in the wrong place.
    return json.loads(_strip_line_comments(text[start:end + 1]))


def _strip_line_comments(text):
    """Return ``text`` with ``//`` line comments removed, ignoring ``//`` inside JSON strings.

    A comment runs from ``//`` up to and including the next newline. A comment
    on the last line (no trailing newline) runs to the end of the text.
    """
    out = []
    i = 0
    n = len(text)
    in_string = False
    while i < n:
        char = text[i]
        if in_string:
            out.append(char)
            if char == "\\" and i + 1 < n:
                # Keep the escaped character verbatim (it may be a quote).
                out.append(text[i + 1])
                i += 2
                continue
            if char == '"':
                in_string = False
            i += 1
        elif char == '"':
            in_string = True
            out.append(char)
            i += 1
        elif text.startswith("//", i):
            newline = text.find("\n", i)
            if newline == -1:
                break
            i = newline + 1
        else:
            out.append(char)
            i += 1
    return "".join(out)
|
||
|
|
||
|
|
||
|
def extract_list_from_end(text):
    """Parse the last JSON array that appears in ``text``.

    The final ``]`` is located, and the scan walks backwards to the ``[`` that
    balances it. Brackets inside string values are counted as well. That span
    is then decoded with ``json.loads``.

    Raises:
        IndexError: If ``text`` has no ``]`` or the final ``]`` is unbalanced.
        json.JSONDecodeError: If the located span is not valid JSON.
    """
    closing = text.rfind("]")
    if closing == -1:
        raise IndexError("string index out of range")

    depth = 0
    for opening in range(closing, -1, -1):
        char = text[opening]
        if char == "]":
            depth += 1
        elif char == "[":
            depth -= 1
            if depth == 0:
                return json.loads(text[opening:closing + 1])

    raise IndexError("string index out of range")
|
||
|
|
||
|
|
||
|
def get_response(
    prompt,
    model="glm-4-0520",
    *,
    temperature=0.8,
    top_p=0.8,
    api_key=None,
    base_url="https://open.bigmodel.cn/api/paas/v4",
):
    """Send a single-turn user prompt to an OpenAI-compatible chat endpoint.

    Args:
        prompt: User message content.
        model: Model name passed to the chat-completions API.
        temperature: Sampling temperature (previous hard-coded value: 0.8).
        top_p: Nucleus-sampling cutoff (previous hard-coded value: 0.8).
        api_key: API key for the endpoint. When omitted, it is read from the
            ``ZHIPUAI_API_KEY`` environment variable.
        base_url: Endpoint base URL. Defaults to the previously hard-coded
            ``open.bigmodel.cn`` v4 endpoint.

    Returns:
        ``choices[0].message.content`` of the completion.

    Raises:
        RuntimeError: If no API key is passed and none is set in the environment.
    """
    # A live key used to be hard-coded here, and another sat in a comment.
    # Both were exposed in source and should be revoked and rotated.
    if api_key is None:
        api_key = os.environ.get("ZHIPUAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No API key: pass api_key=... or set the ZHIPUAI_API_KEY environment variable."
        )

    client = OpenAI(api_key=api_key, base_url=base_url)
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        top_p=top_p,
        model=model,
    )
    return chat_completion.choices[0].message.content
|
||
|
|
||
|
|
||
|
def load_state(state_file):
    """Load and return the JSON document stored at ``state_file``.

    The file is decoded explicitly as UTF-8. Without this, non-ASCII content
    would be decoded with the platform's locale encoding, which differs between
    machines.
    """
    with open(state_file, "r", encoding="utf-8") as f:
        return json.load(f)
|
||
|
|
||
|
|
||
|
def save_state(state, dir):
    """Write ``state`` as 4-space-indented JSON to the file at path ``dir``.

    Despite its name, ``dir`` is a file path (it is opened for writing). The
    name is kept for caller compatibility. An existing file is overwritten.
    """
    # json.dump keeps its default ensure_ascii=True, so the written text is pure
    # ASCII. The explicit UTF-8 encoding therefore produces the same bytes.
    with open(dir, mode="w", encoding="utf-8") as handle:
        json.dump(state, handle, indent=4)
|
||
|
|
||
|
|
||
|
def shape_string_to_list(shape_string):
    """Convert a shape string such as ``"[N, M, K, 19]"`` into ``['N', 'M', 'K', 19]``.

    Lists are returned unchanged (the same object). Surrounding ``[...]`` or
    ``(...)`` are removed only when present, so bracket-less input like
    ``"N, 4"`` also works. Empty entries, as in ``"[]"`` or a trailing comma,
    are dropped. Entries made only of decimal digits become ``int``; all others
    stay as stripped strings.
    """
    if isinstance(shape_string, list):
        return shape_string

    body = shape_string.strip()
    # Only strip a real enclosing bracket pair. Slicing unconditionally would
    # eat the first and last characters of bracket-less input.
    if body[:1] in ("[", "(") and body[-1:] in ("]", ")"):
        body = body[1:-1]

    dims = (part.strip() for part in body.split(","))
    # isdecimal (not isdigit) guarantees int() accepts the string; "²" is a
    # digit but not decimal and would make int() raise.
    return [int(dim) if dim.isdecimal() else dim for dim in dims if dim]
|
||
|
|
||
|
|
||
|
def extract_equal_sign_closed(text):
    """Return the text enclosed between ``=====`` markers, with surrounding whitespace stripped.

    The opening marker may be any run of five or more ``=`` characters. If no
    closing marker follows, everything after the opening marker is returned.
    If there is no opening marker at all, the whole ``text`` is returned
    stripped.
    """
    marker = "====="
    open_at = text.find(marker)
    if open_at == -1:
        return text.strip()

    # Skip the entire run of '=' so that a longer fence (e.g. "==========") is
    # not found again as its own closing marker.
    start = open_at + len(marker)
    while start < len(text) and text[start] == "=":
        start += 1

    close_at = text.find(marker, start)
    if close_at == -1:
        close_at = len(text)
    return text[start:close_at].strip()
|
||
|
|
||
|
|
||
|
class Logger:
    """Minimal append-only text logger backed by a single file.

    The file is opened and closed on every call, so no handle is held between
    calls. Text is encoded as UTF-8, so non-ASCII log lines cannot fail on
    platforms whose locale encoding cannot represent them.
    """

    def __init__(self, file):
        # Path of the log file. It is created by the first log() or reset().
        self.file = file

    def log(self, text):
        """Append ``text`` followed by a newline to the log file."""
        with open(self.file, "a", encoding="utf-8") as f:
            f.write(text + "\n")

    def reset(self):
        """Truncate the log file to empty, creating it if it does not exist."""
        # Opening in "w" mode truncates the file; nothing needs to be written.
        with open(self.file, "w", encoding="utf-8"):
            pass
|
||
|
|
||
|
|
||
|
def create_state(parent_dir, run_dir):
    """Build the initial state dictionary for a run.

    Args:
        parent_dir: Not referenced by the current implementation. It is kept so
            existing callers continue to work.
        run_dir: Directory whose ``desc.txt`` becomes the state's
            ``"description"``.

    Returns:
        A dict ``{"description": <contents of run_dir/desc.txt>,
        "parameters": <result of parameters.get_params on data/desc.txt>}``.
    """
    from parameters import get_params

    # NOTE(review): the parameters come from "data/desc.txt", resolved against
    # the current working directory. The stored description comes from
    # "<run_dir>/desc.txt". Confirm the two files are meant to hold the same text.
    with open("data/desc.txt", "r") as source:
        param_source_text = source.read()
    params = get_params(param_source_text, check=True)

    description_path = os.path.join(run_dir, "desc.txt")
    with open(description_path, "r") as source:
        description = source.read()

    return {"description": description, "parameters": params}
|