Skip to content

Commit

Permalink
WIP: cli fixed, adding test
Browse files Browse the repository at this point in the history
  • Loading branch information
tavallaie committed May 26, 2024
1 parent e7ee4cf commit a51281e
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 83 deletions.
160 changes: 81 additions & 79 deletions djangowiz/core/project_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# djangowiz/core/project_generator.py

import os
import importlib
import sys
from typing import List, Dict, Any
from jinja2 import Environment, FileSystemLoader, ChoiceLoader
from djangowiz.core.io_handler import IOHandler
from djangowiz.core.project_io_handler import ProjectIOHandler

Expand Down Expand Up @@ -52,15 +54,6 @@ def __init__(
self.generator_dir,
)

self.env = Environment(
loader=ChoiceLoader(
[
FileSystemLoader(self.template_dir),
FileSystemLoader(self.default_template_dir),
]
)
)

self.load_generators(self.config_file)

def load_generators(self, config_file: str):
Expand All @@ -78,39 +71,40 @@ def load_generators(self, config_file: str):
else:
self.generators[name] = generator_config

# Use a copy of the dictionary to avoid runtime error
generators_copy = self.generators.copy()
for name, generator_config in generators_copy.items():
# Iterate over a copy of the items to avoid modifying the dictionary during iteration
for name, generator_config in list(self.generators.items()):
for option, config in generator_config.get("options", {}).items():
self.load_generator(name, option, config)

def load_generator(self, name: str, option: str, config: Dict[str, Any]):
class_path = config["class"]
module_name, class_name = class_path.rsplit(".", 1)
module_path, class_name = class_path.rsplit(".", 1)

# Ensure the generator directory is in the Python path
sys.path.append(self.generator_dir)

print(f"Attempting to load module: {module_path}, class: {class_name}")

print(f"Attempting to load module: {module_name}, class: {class_name}")
try:
module = importlib.import_module(module_name)
module = importlib.import_module(module_path)
generator_class = getattr(module, class_name)
print(f"Successfully loaded {class_path}")
except (ImportError, AttributeError) as e:
template_path = config.get("template", "")

if not os.path.exists(os.path.join(self.template_dir, template_path)):
template_path = os.path.join(self.default_template_dir, template_path)

self.generators[f"{name}_{option}"] = {
"class": generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**config,
),
"template": template_path,
}
except Exception as e:
print(f"Error loading {class_path}: {e}")
return

template_path = config.get("template", "")
if not os.path.exists(os.path.join(self.template_dir, template_path)):
template_path = os.path.join(self.default_template_dir, template_path)

self.generators[f"{name}_{option}"] = {
"class": generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**config,
),
"template": template_path,
}

def save_generators(self):
combined_config = {"generators": {}}
Expand Down Expand Up @@ -146,28 +140,34 @@ def add_generator(
template_path = os.path.join(self.default_template_dir, template_path)

module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
generator_class = getattr(module, class_name)
generator_instance = generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**kwargs,
)
# Ensure the generator directory is in the Python path
sys.path.append(self.generator_dir)

self.generators[generator_key] = {
"class": generator_instance,
"template": template_path,
}
try:
module = importlib.import_module(module_path)
generator_class = getattr(module, class_name)
generator_instance = generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**kwargs,
)

self.save_generators()
IOHandler.copy_file(
module_path.replace(".", "/") + ".py",
os.path.join(self.generator_dir, module_path.replace(".", "/") + ".py"),
)
self.load_generators(self.config_file) # Reload configuration
print(f"Generator '{generator_key}' has been added.")
self.generators[generator_key] = {
"class": generator_instance,
"template": template_path,
}

self.save_generators()
IOHandler.copy_file(
module_path.replace(".", "/") + ".py",
os.path.join(self.generator_dir, module_path.replace(".", "/") + ".py"),
)
self.load_generators(self.config_file) # Reload configuration
print(f"Generator '{generator_key}' has been added.")
except Exception as e:
print(f"Error adding generator {class_path}: {e}")

def delete_generator(self, name: str, option: str):
generator_key = f"{name}_{option}"
Expand All @@ -191,28 +191,34 @@ def update_generator(
template_path = os.path.join(self.default_template_dir, template_path)

module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
generator_class = getattr(module, class_name)
generator_instance = generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**kwargs,
)
# Ensure the generator directory is in the Python path
sys.path.append(self.generator_dir)

self.generators[generator_key] = {
"class": generator_instance,
"template": template_path,
}
try:
module = importlib.import_module(module_path)
generator_class = getattr(module, class_name)
generator_instance = generator_class(
self.app_name,
self.project_name,
self.model_names,
self.template_dir,
**kwargs,
)

self.save_generators()
IOHandler.copy_file(
module_path.replace(".", "/") + ".py",
os.path.join(self.generator_dir, module_path.replace(".", "/") + ".py"),
)
self.load_generators(self.config_file) # Reload configuration
print(f"Generator '{generator_key}' has been updated.")
self.generators[generator_key] = {
"class": generator_instance,
"template": template_path,
}

self.save_generators()
IOHandler.copy_file(
module_path.replace(".", "/") + ".py",
os.path.join(self.generator_dir, module_path.replace(".", "/") + ".py"),
)
self.load_generators(self.config_file) # Reload configuration
print(f"Generator '{generator_key}' has been updated.")
except Exception as e:
print(f"Error updating generator {class_path}: {e}")

def show_generators(self):
for name, generator in self.generators.items():
Expand All @@ -228,11 +234,7 @@ def generate(
if generator_key in self.generators:
generator = self.generators[generator_key]["class"]
template = self.generators[generator_key]["template"]
print(f"Generating {generator_key} using template {template}")
output = generator.generate(
overwrite=overwrite, template=template, **kwargs
)
print(f"Generated output for {generator_key}: {output}")
generator.generate(overwrite=overwrite, template=template, **kwargs)

def export_config(self, export_path: str):
self.io_handler.export_config(export_path)
Expand Down
8 changes: 4 additions & 4 deletions djangowiz/repo/generators.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ generators:
serializers:
options:
single:
class: repo.generators.serializer_generator.SerializerGenerator
class: djangowiz.repo.generators.serializer_generator.SerializerGenerator
template: single/serializers.py.j2
multi:
class: repo.generators.serializer_generator.SerializerGenerator
class: djangowiz.repo.generators.serializer_generator.SerializerGenerator
template: multi/serializers.py.j2
viewsets:
options:
single:
class: repo.generators.viewset_generator.ViewsetGenerator
class: djangowiz.repo.generators.viewset_generator.ViewsetGenerator
template: single/viewsets.py.j2
multi:
class: repo.generators.viewset_generator.ViewsetGenerator
class: djangowiz.repo.generators.viewset_generator.ViewsetGenerator
template: multi/viewsets.py.j2
# urls:
# options:
Expand Down
61 changes: 61 additions & 0 deletions djangowiz/repo/templates/test.py/j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# djangowiz/repo/templates/test.py.j2

from django.test import TestCase
from django.urls import reverse
from rest_framework import status
from {{ app_name }}.models import {{ model_name }}
from rest_framework.test import APIClient

class {{ model_name }}APITests(TestCase):
fixtures = ['{{ model_name|lower }}_fixtures.json']

def setUp(self):
self.client = APIClient()
self.model_list_url = reverse('{{ model_name|lower }}-list')
self.model_detail_url = lambda pk: reverse('{{ model_name|lower }}-detail', args=[pk])
self.instances = {{ model_name }}.objects.all()

def test_list_{{ model_name|lower }}s(self):
response = self.client.get(self.model_list_url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), len(self.instances))

def test_create_{{ model_name|lower }}(self):
data = {'name': 'New Product'}
response = self.client.post(self.model_list_url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual({{ model_name }}.objects.count(), len(self.instances) + 1)

def test_retrieve_{{ model_name|lower }}(self):
instance = self.instances.first()
response = self.client.get(self.model_detail_url(instance.pk))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['name'], instance.name)

def test_update_{{ model_name|lower }}(self):
instance = self.instances.first()
data = {'name': 'Updated Product'}
response = self.client.put(self.model_detail_url(instance.pk), data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
instance.refresh_from_db()
self.assertEqual(instance.name, 'Updated Product')

def test_partial_update_{{ model_name|lower }}(self):
instance = self.instances.first()
data = {'name': 'Partially Updated Product'}
response = self.client.patch(self.model_detail_url(instance.pk), data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
instance.refresh_from_db()
self.assertEqual(instance.name, 'Partially Updated Product')

def test_delete_{{ model_name|lower }}(self):
instance = self.instances.first()
response = self.client.delete(self.model_detail_url(instance.pk))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual({{ model_name }}.objects.count(), len(self.instances) - 1)

def test_invalid_create(self):
# Missing 'name' field in data
data = {}
response = self.client.post(self.model_list_url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
File renamed without changes.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ djangowiz = "djangowiz.cli:app"
python = "^3.10"
typer = {extras = ["all"], version = "^0.12.3"}
jinja2 = "^3.1.4"
pyyaml = "^6.0.1"


[build-system]
Expand Down

0 comments on commit a51281e

Please sign in to comment.