Skip to content

Commit

Permalink
Update ideogram model version and add auth refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
brainboost committed Mar 24, 2024
2 parents 4c4c8f9 + 4f53ff3 commit da01097
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 35 deletions.
5 changes: 5 additions & 0 deletions engines/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def read_ssm_param(param_name: str) -> str:
return ssm_client.get_parameter(Name=param_name)["Parameter"]["Value"]


def write_ssm_param(param_name: str, value: str) -> str:
ssm_client = boto3.client(service_name="ssm")
return ssm_client.get_parameter(Name=param_name)["Parameter"]["Value"]


def read_json_from_s3(bucket_name: str, file_name: str) -> Optional[Any]:
s3 = boto3.client("s3")
response = s3.get_object(Bucket=bucket_name, Key=file_name)
Expand Down
75 changes: 54 additions & 21 deletions engines/ideogram_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,29 @@
logging.basicConfig()
logging.getLogger().setLevel("INFO")

ig_cookies = "ig-cookies.json"
base_url = "https://ideogram.ai"
browser_version = "chrome110"
user_agent = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/117.0"
)
browser_version = "chrome120"
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
id_key = "AIzaSyBwq4bRiOapXYaKE-0Y46vLAw1-fzALq7Y"
tokens_file = "google_auth.json"
post_task_url = f"{base_url}/api/images/sample"

sqs = boto3.session.Session().client("sqs")
token = read_ssm_param(param_name="IDEOGRAM_TOKEN")
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
user_id = read_ssm_param(param_name="IDEOGRAM_USER")
# channel_id = read_ssm_param(param_name="IDEOGRAM_CHANNEL")
headers = {
"Cookie": f"session_cookie={token};",
"Origin": base_url,
"Referer": base_url + "/",
"DNT": "1",
"Accept-Encoding": "gzip, deflate, br",
"Content-Type": "application/json",
"Pragma": "no-cache",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"TE": "trailers",
"User-Agent": user_agent,
# "Authorization": f"Bearer {channel_id}",
}
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
ideogram_result_queue = sqs.get_queue_url(QueueName="Ideogram-Result-Queue")["QueueUrl"]


Expand All @@ -59,28 +55,50 @@ def is_expired(id_token: str) -> bool:
return False


def refresh_tokens(refresh_token: str) -> dict:
def refresh_iss_tokens(refresh_token: str) -> dict:
request_ref = "https://securetoken.googleapis.com/v1/token?key=" + id_key
headers = {
"Accept": "*/*",
"Content-Type": "application/json; charset=UTF-8",
"X-Client-Version": "Firefox/JsCore/9.23.0/FirebaseCore-web",
"User-Agent": user_agent,
"Origin": "https://ideogram.ai",
"Origin": base_url,
}
data = json.dumps({"grantType": "refresh_token", "refreshToken": refresh_token})
response_object = requests.post(request_ref, headers=headers, data=data)
response_object = requests.post(
request_ref,
headers=headers,
data=data,
impersonate=browser_version,
)
response_object_json = response_object.json()
tokens = {
"user_id": response_object_json["user_id"],
"id_token": response_object_json["id_token"],
"access_token": response_object_json["access_token"],
"refresh_token": response_object_json["refresh_token"],
}
save_to_s3(bucket_name=bucket_name, file_name=tokens_file, value=tokens)
return tokens


def check_and_refresh_google_token() -> str:
def get_session_cookies(iss_token: str) -> dict:
request_url = f"{base_url}/api/account/login"
headers["Authorization"] = f"Bearer {iss_token}"
response_obj = requests.post(
url=request_url,
headers=headers,
data=json.dumps({}),
auth=("Bearer", iss_token),
)
if not response_obj.ok:
logging.error(response_obj.text)
raise Exception(f"Error response {str(response_obj)}")
cookies = dict(response_obj.cookies)
save_to_s3(bucket_name=bucket_name, file_name=ig_cookies, value=cookies)
return cookies


def check_and_refresh_auth_tokens() -> dict:
tokens = read_json_from_s3(bucket_name=bucket_name, file_name=tokens_file)
if not tokens:
error = f"Cannot read file '{tokens_file}' from the S3 bucket '{bucket_name}'. Put json with the field 'refresh_token' and save"
Expand All @@ -90,10 +108,18 @@ def check_and_refresh_google_token() -> str:
if not refresh_token:
logging.error(f"No 'refresh_token' found in the {tokens_file}")
return None
id_token = tokens.get("id_token", None)
if not id_token or is_expired(id_token):
tokens = refresh_tokens(refresh_token=refresh_token)
return tokens["id_token"]
acc_token = tokens.get("access_token", None)
if not acc_token or is_expired(acc_token):
tokens = refresh_iss_tokens(refresh_token=refresh_token)
return tokens


def __cookies_to_header_string(cookies: dict) -> str:
cookie_pairs = []
for key, value in cookies.items():
if key == "session_cookie":
cookie_pairs.append(f"{key}={value}")
return "; ".join(cookie_pairs)


def request_images(prompt: str) -> str:
Expand All @@ -109,14 +135,21 @@ def request_images(prompt: str) -> str:
"variation_strength": 50,
}
logging.info(payload)
bearer = check_and_refresh_google_token()
headers["Authorization"] = f"Bearer {bearer}"
tokens = check_and_refresh_auth_tokens()
try:
cookies = read_json_from_s3(bucket_name=bucket_name, file_name=ig_cookies)
except Exception:
logging.info(f"Cannot find {ig_cookies} in s3 bucket {bucket_name}")
cookies = None
if not cookies or is_expired(cookies["session_cookie"]):
cookies = get_session_cookies(iss_token=tokens["access_token"])
headers["Cookie"] = __cookies_to_header_string(dict(cookies))
headers["Authorization"] = f"Bearer {tokens['access_token']}"
response = requests.post(
url=post_task_url,
headers=headers,
data=json.dumps(payload),
impersonate=browser_version,
auth=("Bearer", bearer),
)
if not response.ok:
logging.error(response.text)
Expand Down
7 changes: 4 additions & 3 deletions engines/ideogram_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

engine_type = "ideogram"
threshold_img_quality = 1024
browser_version = "chrome110"
retrieve_metadata_url = "https://ideogram.ai/api/images/retrieve_metadata_request_id/"
get_images_url = "https://ideogram.ai/api/images/direct/"
base_url = "https://ideogram.ai"
browser_version = "chrome120"
retrieve_metadata_url = f"{base_url}/api/images/retrieve_metadata_request_id/"
get_images_url = f"{base_url}/api/images/direct/"

result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")
Expand Down
36 changes: 25 additions & 11 deletions tests/test_ideogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,44 @@

from engines.common_utils import read_json_from_s3, read_ssm_param
from engines.ideogram_img import (
check_and_refresh_google_token,
check_and_refresh_auth_tokens,
get_session_cookies,
is_expired,
refresh_tokens,
refresh_iss_tokens,
request_images,
)


@pytest.mark.skip()
# @pytest.mark.skip()
def test_check_and_refresh(capsys):
with capsys.disabled():
check_and_refresh_google_token()
tokens = check_and_refresh_auth_tokens()
assert "access_token" in tokens
assert "refresh_token" in tokens


@pytest.mark.skip()
# @pytest.mark.skip()
def test_refresh(capsys):
with capsys.disabled():
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
token = read_json_from_s3(bucket_name=bucket_name, file_name="google_auth.json")
data = refresh_tokens(token["refresh_token"])
assert not is_expired(data["id_token"])
data = refresh_iss_tokens(token["refresh_token"])
assert not is_expired(data["access_token"])


@pytest.mark.skip()
def test_authorize(capsys):
# @pytest.mark.skip()
def test_request_images(capsys):
with capsys.disabled():
response = request_images(prompt="cute kittens playing with yarn ball")
print(response)
response = request_images(
prompt="cute kittens playfully engaging with a colorful yarn ball"
)
assert response


# @pytest.mark.skip()
def test_get_session_cookies(capsys):
with capsys.disabled():
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
token = read_json_from_s3(bucket_name=bucket_name, file_name="google_auth.json")
response = get_session_cookies(iss_token=token["access_token"])
assert response

0 comments on commit da01097

Please sign in to comment.