Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] query with gpt #78

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.dailyon.productservice.brand.repository.BrandRepository;
import com.dailyon.productservice.category.dto.response.ReadChildrenCategoryListResponse;
import com.dailyon.productservice.category.dto.response.ReadChildrenCategoryResponse;
import com.dailyon.productservice.category.entity.Category;
import com.dailyon.productservice.category.repository.CategoryRepository;
import com.dailyon.productservice.common.enums.Gender;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -38,8 +39,13 @@ public String getSearchResults(String searchQuery) throws Exception {
.map(ReadBrandResponse::toString)
.collect(Collectors.joining(","));

List<Category> midCategories = categoryRepository.findAllChildCategories(null)
.stream()
.flatMap(rootCategory -> categoryRepository.findAllChildCategories(rootCategory.getId()).stream())
.collect(Collectors.toList());

List<ReadChildrenCategoryResponse> categories = ReadChildrenCategoryListResponse
.fromEntity(categoryRepository.findLeafCategories())
.fromEntity(midCategories)
.getCategoryResponses();

String allCategories = categories.stream()
Expand Down Expand Up @@ -74,25 +80,55 @@ public String getSearchResults(String searchQuery) throws Exception {
return objectMapper.writeValueAsString(result);
}

public String getTranslatedPrompt(String searchQuery) throws Exception {
List<Map<String, Object>> messages = new ArrayList<>();
Map<String, Object> message = new HashMap<>();
message.put("role", "user");
message.put("content", createTranslatePrompt(searchQuery));
messages.add(message);

Map<String, Object> requestData = new HashMap<>();
requestData.put("model", "gpt-3.5-turbo-1106");
requestData.put("messages", messages);
requestData.put("temperature", 0.2);
requestData.put("max_tokens", 300);

HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setBearerAuth(environment.getProperty("open-ai.secret-key"));

String requestBody = objectMapper.writeValueAsString(requestData);

HttpEntity<String> entity = new HttpEntity<>(requestBody, headers);

String apiEndpoint = "https://api.openai.com/v1/chat/completions";
Object result = restTemplate.postForObject(apiEndpoint, entity, Object.class);
return objectMapper.writeValueAsString(result);
}

private String createTranslatePrompt(String query) {
return "Translate this sentence into english: " + "\"" + query + "\"";
}

private String createPrompt(String searchQuery, String brands, String categories, String genders) {
return "{" +
"\"categories\": [" + categories + "], " +
"\"brands\": [" + brands + "], " +
"\"genders\": " + genders + ", " +
"\"priceRanges\": [" +
"{\"id\": 1, \"name\": \"$0-$99\"}, " +
"{\"id\": 2, \"name\": \"$100-$199\"}, " +
"{\"id\": 3, \"name\": \"$200-$299\"}, " +
"{\"id\": 4, \"name\": \"$300-$399\"}, " +
"{\"id\": 5, \"name\": \"over $400\"}" +
"{\"id\": 1, \"name\": \"$0-$299\"}, " +
"{\"id\": 2, \"name\": \"$300-$599\"}, " +
"{\"id\": 3, \"name\": \"$600-$899\"}, " +
"{\"id\": 4, \"name\": \"$900-$1199\"}, " +
"{\"id\": 5, \"name\": \"over $1200\"}" +
"], " +
"Search Query: \"" + searchQuery + "\"." +
"Based on the search query, let me know the relevant 3 categories, 3 brands, 1 gender, and 1 price range. " +
"Please provide the answer in the json object format. " +
"Based on the search query, let me know the relevant maximum 3 categories, maximum 5 brands, 1 gender, and 1 price range. " +
"Please provide the answer in the json object format. And also Field priceRanges can be null." +
"{\"categories\":[{\"id\":1, \"name\":\"Fashion\"}, {\"id\":2, \"name\":\"Electronics\"}, {\"id\":3, \"name\": \"Home & Living\"}], " +
"\"brands\":[{\"id\":1, \"name\":\"Nike\"}, {\"id\":2, \"name\":\"Samsung\"}, {\"id\":3, \"name\":\"Apple\"}], " +
"\"genders\":[\"MALE\"], " +
"\"priceRanges\":[{\"id\":1, \"name\":\"$0-$99\"}]}" +
"\"priceRanges\":[{\"id\":1, \"name\":\"$0-$199\"}]}" +
"}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ public static ReadBrandResponse fromEntity(Brand brand) {
public static class PriceRange {
private Long id;
private String name;

public Integer[] parseHighAndLow() {
String[] split = this.name.replace("$", "").split("-");
return new Integer[] { Integer.parseInt(split[0]) * 1000, Integer.parseInt(split[1]) * 1000 };
}
}

@Getter
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dailyon.productservice.product.facade;

import com.dailyon.productservice.category.entity.Category;
import com.dailyon.productservice.common.enums.Gender;
import com.dailyon.productservice.common.enums.ProductType;
import com.dailyon.productservice.common.exception.DeleteException;
Expand Down Expand Up @@ -99,12 +100,15 @@ public ReadProductSearchResponse searchProducts(String query) {

if (products.isEmpty()) {
try {
String response = openAIClient.getSearchResults(query);
String translatedResponse = openAIClient.getTranslatedPrompt(query);
OpenAIResponse translatedContent = gson.fromJson(translatedResponse, OpenAIResponse.class);
String translatedQuery = translatedContent.getChoices().get(0).getMessage().getContent();
log.info("translatedQuery: "+translatedQuery);

String response = openAIClient.getSearchResults(translatedQuery);
OpenAIResponse responseFromGpt = gson.fromJson(response, OpenAIResponse.class);
String jsonContent = responseFromGpt.getChoices().get(0).getMessage().getContent();
log.info("================================");
log.info(jsonContent);
log.info("================================");

OpenAIResponse.Content content = gson.fromJson(jsonContent, OpenAIResponse.Content.class);

// Use content object to search products
Expand All @@ -118,7 +122,12 @@ public ReadProductSearchResponse searchProducts(String query) {

Gender gender = content.getGenders().get(0);

products = productService.searchAfterGpt(categoryIds, brandIds, gender);
Integer[] prices = new Integer[2];
if(content.getPriceRanges().get(0) != null) {
prices = content.getPriceRanges().get(0).parseHighAndLow();
}

products = productService.searchAfterGpt(categoryIds, brandIds, gender, prices[0], prices[1]);
} catch (Exception e) {
// Properly log and handle the exception as per your application's requirements
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ Page<Product> findProductPage(

List<Product> searchProducts(String query);

List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds, Gender gender);
List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds,
Gender gender, Integer lowPrice, Integer highPrice);
}
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ public List<Product> searchProducts(String query) {
}

@Override
public List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds, Gender gender) {
public List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds,
Gender gender, Integer lowPrice, Integer highPrice) {
return jpaQueryFactory.selectDistinct(product)
.from(product)
.leftJoin(product.brand, brand).fetchJoin()
Expand All @@ -196,6 +197,7 @@ public List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds,
.and(product.category.id.in(categoryIds))
.and(product.type.eq(ProductType.NORMAL))
.and(product.gender.eq(gender))
.and(filterPrice(lowPrice, highPrice))
)
.orderBy(orderSpecifier("createdAt", "desc"))
.fetch();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,18 @@ public List<Product> searchProducts(String query) {
return productRepository.searchProducts(query);
}

public List<Product> searchAfterGpt(List<Long> brandIds, List<Long> categoryIds, Gender gender) {
return productRepository.searchAfterGpt(brandIds, categoryIds, gender);
public List<Product> searchAfterGpt(List<Long> categoryIds, List<Long> brandIds,
Gender gender, Integer lowPrice, Integer highPrice) {

List<Long> childCategoryIds = new ArrayList<>();
for(Long categoryId: categoryIds) {
childCategoryIds.addAll(categoryRepository.findAllChildCategories(categoryId)
.stream()
.map(Category::getId)
.collect(Collectors.toList()));
}

return productRepository.searchAfterGpt(childCategoryIds, brandIds, gender, lowPrice, highPrice);
}

public ReadOOTDSearchSliceResponse searchFromOOTD(Long lastId, String query) {
Expand Down
Loading