import os
import json
from groq import Groq  # NOTE(review): unused in this module; confirm before removing.
from openai import OpenAI

# NOTE(review): not referenced in this module; presumably kept for importers.
openai_key = ""

# Default OpenAI-compatible endpoint used by get_response.
_DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"
# Environment variable consulted for the API key when none is passed explicitly.
_API_KEY_ENV_VAR = "ZHIPUAI_API_KEY"


class ExtractionError(ValueError, IndexError):
    """Raised when no balanced JSON object/array can be located in a text.

    It subclasses ``IndexError`` as well as ``ValueError`` so that callers
    written against the earlier index-scanning implementation, which failed
    with ``IndexError``, keep working.
    """


def _balanced_span_from_end(text, open_ch, close_ch):
    """Return the span from the last ``close_ch`` back to its matching ``open_ch``.

    Brackets are counted naively; brackets inside string literals are not
    special-cased.

    Raises:
        ExtractionError: if ``close_ch`` never occurs or the brackets are unbalanced.
    """
    end = text.rfind(close_ch)
    if end == -1:
        raise ExtractionError(f"no {close_ch!r} found in text")
    depth = 0
    for i in range(end, -1, -1):
        ch = text[i]
        if ch == close_ch:
            depth += 1
        elif ch == open_ch:
            depth -= 1
            if depth == 0:
                return text[i:end + 1]
    raise ExtractionError(f"unbalanced {open_ch!r}/{close_ch!r} in text")


def _strip_line_comments(text):
    """Remove ``//`` line comments that lie outside double-quoted string literals.

    The terminating newline is kept as whitespace. A ``//`` inside a string
    (e.g. ``"http://..."``) is left untouched.
    """
    out = []
    in_string = False
    escaped = False
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        if in_string:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            i += 1
        elif ch == '"':
            in_string = True
            out.append(ch)
            i += 1
        elif text.startswith("//", i):
            newline = text.find("\n", i)
            if newline == -1:
                break  # comment runs to the end of the text
            i = newline
        else:
            out.append(ch)
            i += 1
    return "".join(out)


def extract_json_from_end(text):
    """Parse the last balanced ``{...}`` JSON object found in *text*.

    If *text* contains a ```` ```json ```` fence, only the contents of the first
    such fence are searched. ``//`` line comments outside string literals are
    removed from the extracted object before parsing.

    Returns:
        The decoded ``dict``.

    Raises:
        ExtractionError: if no balanced ``{...}`` span is found.
        json.JSONDecodeError: if the span is not valid JSON.
    """
    if "```json" in text:
        text = text.split("```json")[1]
        text = text.split("```")[0]
    span = _balanced_span_from_end(text, "{", "}")
    # Strip comments only inside the extracted span so the span boundaries
    # cannot be shifted by '//' occurring in surrounding prose.
    return json.loads(_strip_line_comments(span))


def extract_list_from_end(text):
    """Parse the last balanced ``[...]`` JSON array found in *text*.

    Returns:
        The decoded ``list``.

    Raises:
        ExtractionError: if no balanced ``[...]`` span is found.
        json.JSONDecodeError: if the span is not valid JSON.
    """
    return json.loads(_balanced_span_from_end(text, "[", "]"))


def get_response(prompt, model="glm-4-0520", *, temperature=0.8, top_p=0.8,
                 api_key=None, base_url=_DEFAULT_BASE_URL):
    """Send *prompt* as a single user message and return the reply text.

    Args:
        prompt: Content of the user message.
        model: Model name passed to the chat-completions endpoint.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.
        api_key: API key. When omitted, it is read from the
            ``ZHIPUAI_API_KEY`` environment variable. Never hard-code
            credentials in source.
        base_url: OpenAI-compatible endpoint to call.

    Returns:
        ``message.content`` of the first returned choice.

    Raises:
        RuntimeError: if no API key is supplied or found in the environment.
    """
    if api_key is None:
        api_key = os.environ.get(_API_KEY_ENV_VAR)
    if not api_key:
        # Fail loudly rather than letting OpenAI() fall back to OPENAI_API_KEY,
        # which would send an unrelated credential to this endpoint.
        raise RuntimeError(
            f"No API key supplied; pass api_key=... or set ${_API_KEY_ENV_VAR}."
        )
    client = OpenAI(api_key=api_key, base_url=base_url)
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        top_p=top_p,
        model=model,
    )
    return chat_completion.choices[0].message.content


def load_state(state_file):
    """Load and return the JSON document stored at *state_file*."""
    with open(state_file, "r", encoding="utf-8") as f:
        return json.load(f)


def save_state(state, dir):
    """Write *state* as JSON (indent=4) to the file path *dir*.

    The parameter name ``dir`` shadows the builtin but is kept for keyword
    callers. Despite the name, it is the path of the file to write.
    """
    with open(dir, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=4)


def shape_string_to_list(shape_string):
    """Convert a shape string like ``"[N, M, K, 19]"`` to ``['N', 'M', 'K', 19]``.

    Entries consisting only of digits become ``int``; all others stay as
    stripped strings. ``"[]"`` yields ``[]``. A list argument is returned
    unchanged. The first and last characters (after stripping whitespace)
    are assumed to be the enclosing brackets.
    """
    if isinstance(shape_string, list):
        return shape_string
    inner = shape_string.strip()[1:-1]
    items = [part.strip() for part in inner.split(",")]
    if items == [""]:
        return []
    return [int(item) if item.isdigit() else item for item in items]


def extract_equal_sign_closed(text):
    """Return the text between the first two ``=====`` markers, stripped.

    Any extra ``=`` characters directly after the opening marker are skipped,
    so longer rules such as ``=======`` also work. If no closing marker
    follows, everything after the opening marker is returned.

    Raises:
        ValueError: if *text* contains no ``=====`` marker.
    """
    marker = "====="
    start = text.find(marker)
    if start == -1:
        raise ValueError("no '=====' marker found in text")
    start += len(marker)
    while start < len(text) and text[start] == "=":
        start += 1
    end = text.find(marker, start)
    if end == -1:
        end = len(text)
    return text[start:end].strip()


class Logger:
    """Minimal logger that appends one line of text per call to a file."""

    def __init__(self, file):
        # Path of the log file; it is reopened on every call, so no handle is held.
        self.file = file

    def log(self, text):
        """Append *text* followed by a newline to the log file."""
        with open(self.file, "a", encoding="utf-8") as f:
            f.write(text + "\n")

    def reset(self):
        """Truncate the log file to zero length, creating it if missing."""
        with open(self.file, "w", encoding="utf-8"):
            pass


def create_state(parent_dir, run_dir):
    """Build the initial state dict for a run.

    Parameters come from ``parameters.get_params`` applied to the text of
    ``data/desc.txt``. The stored description is read from
    ``<run_dir>/desc.txt``.

    Args:
        parent_dir: Currently unused. NOTE(review): ``data/desc.txt`` is
            resolved relative to the current working directory, not
            *parent_dir*; confirm whether it should be
            ``os.path.join(parent_dir, "data", "desc.txt")``.
        run_dir: Directory containing this run's ``desc.txt``.

    Returns:
        ``{"description": <run_dir/desc.txt text>, "parameters": <get_params result>}``
    """
    # Imported inside the function, presumably to avoid an import cycle; TODO confirm.
    from parameters import get_params

    with open("data/desc.txt", "r", encoding="utf-8") as f:
        source_desc = f.read()
    params = get_params(source_desc, check=True)

    with open(os.path.join(run_dir, "desc.txt"), "r", encoding="utf-8") as f:
        desc = f.read()

    return {"description": desc, "parameters": params}