Skip to content

Commit

Permalink
Reworked pytorch example.
Browse files Browse the repository at this point in the history
  • Loading branch information
WyvernIXTL committed May 8, 2024
1 parent 3cd7edd commit ffb871e
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 146 deletions.
5 changes: 0 additions & 5 deletions examples/pytorch/bad_user_input.yml

This file was deleted.

104 changes: 0 additions & 104 deletions examples/pytorch/instruction_pytorch.md.jinja

This file was deleted.

51 changes: 28 additions & 23 deletions examples/pytorch/instruction_pytorch.schema.yml
Original file line number Diff line number Diff line change
@@ -1,38 +1,43 @@
$schema: https://json-schema.org/draft/2020-12/schema
$id: https://github.com/instructions-d-installation/installation-instruction/examples/pytorch/schema_pytorch.yml
$id: https://github.com/instructions-d-installation/installation-instruction/examples/pytorch/instruction_pytorch.schema.yml
title: PyTorch Install Schema
description: This is a schema which is used for constructing interactive installation instructions.
type: object
$comment: by Adam McKellar
properties:
build:
enum:
- Stable (2.3.0)
- Preview (Nightly)
anyOf:
- title: Stable (2.3.0)
const: stable
- title: Preview (Nightly)
const: preview
os:
enum:
- Linux
- Mac
- Windows
anyOf:
- title: Linux
const: linux
- title: Mac
const: mac
- title: Windows
const: win
package:
enum:
- Conda
- Pip
- LibTorch
- Source
language:
enum:
- Python
- C++/Java
anyOf:
- title: Conda
const: conda
- title: Pip
const: pip
computer_platform:
enum:
- CUDA 11.8
- CUDA 12.1
- ROCm 6.0
- CPU
anyOf:
- title: CUDA 11.8
const: cu118
- title: CUDA 12.1
const: cu121
- title: ROCm 6.0
const: ro60
- title: CPU
const: cpu
required:
- build
- os
- package
- language
- computer_platform
additionalProperties: false
61 changes: 61 additions & 0 deletions examples/pytorch/instruction_pytorch.txt.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{# Adams spaghetti code #}
{% if package == "conda" %}
conda install

{% if os == "mac" %}
{% if computer_platform == "cpu" %}
pytorch::pytorch torchvision torchaudio
{% else %}
[[ERROR]]
Mac does not support ROCm or CUDA!
[[ERROR]]
{% endif %}
{% else %}
pytorch torchvision torchaudio

{% if computer_platform == "cu118" %}
pytorch-cuda=11.8 -c nvidia
{% elif computer_platform == "cu121" %}
pytorch-cuda=12.1 -c nvidia
{% elif computer_platform == "ro60" %}
[[ERROR]]
ROCm is currently not supported with conda on linux and not supported at all on windows!
[[ERROR]]
{% else %}
cpuonly
{% endif %}
{% endif %}

-c pytorch

{% elif package == "pip" %}
pip3 install torch torchvision torchaudio

{% if os == "mac" %}
{% if computer_platform != "cpu" %}
[[ERROR]]
Mac does not support ROCm or CUDA!
[[ERROR]]
{% endif %}
{% else %}
{% if computer_platform == "cu118" %}
--index-url https://download.pytorch.org/whl/cu118
{% elif computer_platform == "cu121" %}
{% if os == "win" %}
--index-url https://download.pytorch.org/whl/cu121
{% endif %}
{% elif computer_platform == "ro60" %}
{% if os == "linux" %}
--index-url https://download.pytorch.org/whl/rocm6.0
{% else %}
[[ERROR]]
Windows does not support ROCm!
[[ERROR]]
{% endif %}
{% else %}
{% if os == "linux" %}
--index-url https://download.pytorch.org/whl/cpu
{% endif %}
{% endif %}
{% endif %}
{% endif %}
9 changes: 4 additions & 5 deletions examples/pytorch/potential_user_input.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
build: Stable (2.3.0)
os: Linux
package: Conda
language: Python
computer_platform: CPU
build: stable
os: linux
package: conda
computer_platform: cpu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jinja2 import Environment, Template

import re


def load_yml(path: str) -> dict:
Expand All @@ -26,26 +27,32 @@ def load_template(path: str) -> Template:

return env.from_string(template_str)

def get_error_message(parsed_template) -> bool:
reg = re.compile(".*\[\[ERROR\]\]\s*(?P<errmsg>.*?)\s*\[\[ERROR\]\].*", re.S)
matches = reg.search(parsed_template)
if matches is None:
return None
return matches.group("errmsg")

def replace_blank_space(string: str) -> str:
return re.sub("\s{1,}", " ", string, 0, re.S).strip()


schema = load_yml('instruction_pytorch.schema.yml')
input = load_yml('potential_user_input.yml')
#bad_input = load_yml('bad_user_input.yml')



print("Test valid input.")
validate(input, schema)
print("It worked!")


# print("Test invalid input.")
# validate(bad_input, schema)
# print("It worked!")


template = load_template("instruction_pytorch.md.jinja")
template = load_template("instruction_pytorch.txt.jinja")

instructions = template.render(input)
print(instructions)
instructions = replace_blank_space(instructions)

if errmsg := get_error_message(instructions):
print(errmsg)
else:
print(instructions)

0 comments on commit ffb871e

Please sign in to comment.