building-agents/llma/utils.py

155 lines
3.9 KiB
Python

import os
import json
from groq import Groq
from openai import OpenAI
# groq_key = "###"
openai_key = ""
# openai_org = "###"
# groq_client = Groq(api_key=groq_key)
# open_ai_client = openai.Client(api_key=openai_key, organization=openai_org)
def extract_json_from_end(text):
if "```json" in text:
text = text.split("```json")[1]
text = text.split("```")[0]
ind = len(text) - 1
while text[ind] != "}":
ind -= 1
text = text[: ind + 1]
ind -= 1
cnt = 1
while cnt > 0:
if text[ind] == "}":
cnt += 1
elif text[ind] == "{":
cnt -= 1
ind -= 1
# find comments in the json string (texts between "//" and "\n") and remove them
while True:
ind_comment = text.find("//")
if ind_comment == -1:
break
ind_end = text.find("\n", ind_comment)
text = text[:ind_comment] + text[ind_end + 1:]
# print("提取的json:", text[ind + 1:])
jj = json.loads(text[ind + 1:])
return jj
def extract_list_from_end(text):
ind = len(text) - 1
while text[ind] != "]":
ind -= 1
text = text[: ind + 1]
ind -= 1
cnt = 1
while cnt > 0:
if text[ind] == "]":
cnt += 1
elif text[ind] == "[":
cnt -= 1
ind -= 1
# print("提取的list:", text[ind + 1:])
jj = json.loads(text[ind + 1:])
return jj
def get_response(prompt, model="glm-4-0520"):
# if model == "llama3-70b-8192":
# client = groq_client
# else:
# client = open_ai_client
# client = OpenAI(
# api_key="sk-abb441ab88af458d88862ce9be39f746",
# base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
# )
client = OpenAI(
api_key="56e7d49c4052b32956e403216ce927ca.KttmQZhDU7JVvkU1",
base_url="https://open.bigmodel.cn/api/paas/v4",
)
chat_completion = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
top_p=0.8,
model=model,
)
res = chat_completion.choices[0].message.content
return res
def load_state(state_file):
with open(state_file, "r") as f:
state = json.load(f)
return state
def save_state(state, dir):
with open(dir, "w") as f:
json.dump(state, f, indent=4)
def shape_string_to_list(shape_string):
if type(shape_string) is list:
return shape_string
# convert a string like "[N, M, K, 19]" to a list like ['N', 'M', 'K', 19]
shape_string = shape_string.strip()
shape_string = shape_string[1:-1]
shape_list = shape_string.split(",")
shape_list = [x.strip() for x in shape_list]
shape_list = [int(x) if x.isdigit() else x for x in shape_list]
if len(shape_list) == 1 and shape_list[0] == "":
shape_list = []
return shape_list
def extract_equal_sign_closed(text):
ind_1 = text.find("=====")
ind_2 = text.find("=====", ind_1 + 1)
obj = text[ind_1 + 6: ind_2].strip()
return obj
class Logger:
def __init__(self, file):
self.file = file
def log(self, text):
with open(self.file, "a") as f:
f.write(text + "\n")
def reset(self):
with open(self.file, "w") as f:
f.write("")
def create_state(parent_dir, run_dir):
# with open(os.path.join(parent_dir, "data/params.json"), "r") as f:
# params = json.load(f)
from parameters import get_params
with open("data/desc.txt", "r") as f:
desc = f.read()
params = get_params(desc, check=True)
# data = {}
# for key in params:
# data[key] = params[key]["value"]
# del params[key]["value"]
#
# # save the data file in the run_dir
# with open(os.path.join(run_dir, "data.json"), "w") as f:
# json.dump(data, f, indent=4)
with open(os.path.join(run_dir, "desc.txt"), "r") as f:
desc = f.read()
state = {"description": desc, "parameters": params}
return state