4.2 KiB
4.2 KiB
In [ ]:
import sys
from pathlib import Path

# Locate the repository root: the nearest directory at or above the current
# working directory that contains a ".git" entry.
root_path = Path(".").absolute()
for _candidate in (root_path, *root_path.parents):
    if (_candidate / ".git").exists():
        root_path = _candidate
        break
else:
    # The parent of the filesystem root is the root itself, so an unbounded
    # upward walk would spin forever when run outside a git checkout.
    raise FileNotFoundError(
        f"No '.git' directory found in {root_path} or any of its parents"
    )

# Make the helper module under <root>/data/rag importable; skip the append if
# the entry is already present (e.g. when the cell is re-run).
_rag_dir = str(root_path / "data" / "rag")
if _rag_dir not in sys.path:
    sys.path.append(_rag_dir)

from rag_utils import (
    constraint_path,
    problem_descriptions_vector_db_path,
    constraint_vector_db_path,
    objective_descriptions_vector_db_path,
)
In [ ]:
import pandas as pd from typing import Dict, List from langchain_chroma import Chroma from langchain.schema.document import Document from langchain_openai import OpenAIEmbeddings
In [ ]:
# Load the constraint table from the pickle file at `constraint_path`.
# NOTE(review): unpickling can execute arbitrary code — presumably this file
# is produced by this project; confirm it is never taken from an untrusted source.
constraint_df = pd.read_pickle(constraint_path)
constraint_df
In [ ]:
# Inspect the column labels of the loaded constraint table.
constraint_df.columns
In [ ]:
# Keep only the description / problem_name columns and collapse repeated
# (description, problem_name) rows into one.
unique_problems_df = (
    constraint_df[["description", "problem_name"]]
    .drop_duplicates()
)
unique_problems_df
In [ ]:
def make_vector_db(data: Dict[str, str], vector_db_path: Path, model_name: str = "text-embedding-3-large"):
    """
    Creates a vector database from a dictionary of strings.

    Each value in ``data`` becomes one document whose metadata stores the
    corresponding key under ``"key"``. Anything already present at
    ``vector_db_path`` is removed first, so the store is rebuilt from scratch.

    Args:
        data (Dict[str, str]): A dictionary where keys are identifiers and values are strings.
        vector_db_path (Path): The directory in which the vector database is persisted.
        model_name (str): The model name for generating embeddings.
    """
    # Local import keeps the function self-contained without touching the
    # notebook's import cell.
    import shutil

    embedding_function = OpenAIEmbeddings(model=model_name)
    docs = [Document(page_content=value, metadata={"key": key}) for key, value in data.items()]

    # The store is persisted as a directory, and Path.unlink() cannot remove
    # directories, so an existing store must be deleted with rmtree. A stray
    # regular file or symlink at the same location is still removed via unlink.
    if vector_db_path.is_dir() and not vector_db_path.is_symlink():
        shutil.rmtree(vector_db_path)
    elif vector_db_path.exists():
        vector_db_path.unlink()
    vector_db_path.mkdir(exist_ok=True, parents=True)

    Chroma.from_documents(docs, embedding_function, persist_directory=str(vector_db_path))
In [ ]:
# Map each problem_name to its description, then build the
# problem-description vector store from that mapping.
unique_problems_dict = (
    unique_problems_df
    .set_index("problem_name")["description"]
    .to_dict()
)
unique_problems_dict
make_vector_db(
    data=unique_problems_dict,
    vector_db_path=problem_descriptions_vector_db_path,
)
In [ ]:
# Map each row label of the constraint table to its constraint description.
constraint_dict = constraint_df["constraint_description"].to_dict()
constraint_dict
In [ ]:
# Build the vector store of constraint descriptions.
make_vector_db(
    data=constraint_dict,
    vector_db_path=constraint_vector_db_path,
)
In [ ]:
# Map each problem_name to its objective description, after collapsing
# repeated (objective_description, problem_name) rows.
objectives_dict = (
    constraint_df[["objective_description", "problem_name"]]
    .drop_duplicates()
    .set_index("problem_name")["objective_description"]
    .to_dict()
)
objectives_dict
In [ ]:
# Build the vector store of objective descriptions.
make_vector_db(
    data=objectives_dict,
    vector_db_path=objective_descriptions_vector_db_path,
)