diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..cb9509a --- /dev/null +++ b/Package.resolved @@ -0,0 +1,14 @@ +{ + "pins" : [ + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", + "version" : "1.1.4" + } + } + ], + "version" : 2 +} diff --git a/Package.swift b/Package.swift index e7074aa..80ada54 100644 --- a/Package.swift +++ b/Package.swift @@ -13,17 +13,25 @@ let package = Package( targets: ["Jinja"] ) ], + dependencies: [ + .package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4") + ], targets: [ // Targets are the basic building blocks of a package, defining a module or a test suite. // Targets can depend on other targets in this package and products from dependencies. .target( name: "Jinja", + dependencies: [ + .product(name: "OrderedCollections", package: "swift-collections") + ], path: "Sources", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), .testTarget( name: "JinjaTests", - dependencies: ["Jinja"], + dependencies: [ + "Jinja" + ], path: "Tests", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), diff --git a/Sources/Ast.swift b/Sources/Ast.swift index 7460284..8263bc7 100644 --- a/Sources/Ast.swift +++ b/Sources/Ast.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections protocol Statement {} @@ -41,7 +42,7 @@ struct TupleLiteral: Literal { } struct ObjectLiteral: Literal { - var value: [(Expression, Expression)] + var value: OrderedDictionary } struct Set: Statement { @@ -49,7 +50,7 @@ struct Set: Statement { var value: Expression } -struct If: Statement { +struct If: Statement, Expression { var test: Expression var body: [Statement] var alternate: [Statement] @@ -59,14 +60,14 @@ struct Identifier: Expression { var value: String } -protocol Loopvar {} -extension Identifier: Loopvar {} -extension TupleLiteral: Loopvar {} +typealias Loopvar = Expression struct For: Statement { var loopvar: Loopvar var iterable: Expression var body: [Statement] + var defaultBlock: [Statement] + var ifCondition: Expression? } struct MemberExpression: Expression { @@ -124,3 +125,23 @@ struct KeywordArgumentExpression: Expression { struct NullLiteral: Literal { var value: Any? = nil } + +struct SelectExpression: Expression { + var iterable: Expression + var test: Expression +} + +struct Macro: Statement { + var name: Identifier + var args: [Expression] + var body: [Statement] +} + +struct KeywordArgumentsValue: RuntimeValue { + var value: [String: any RuntimeValue] + var builtins: [String: any RuntimeValue] = [:] + + func bool() -> Bool { + !value.isEmpty + } +} diff --git a/Sources/Environment.swift b/Sources/Environment.swift index c845068..9689c9f 100644 --- a/Sources/Environment.swift +++ b/Sources/Environment.swift @@ -6,48 +6,46 @@ // import Foundation +import OrderedCollections class Environment { var parent: Environment? var variables: [String: any RuntimeValue] = [ "namespace": FunctionValue(value: { args, _ in - if args.count == 0 { + if args.isEmpty { return ObjectValue(value: [:]) } - if args.count != 1 || !(args[0] is ObjectValue) { + guard args.count == 1, let objectArg = args[0] as? ObjectValue else { throw JinjaError.runtime("`namespace` expects either zero arguments or a single object argument") } - return args[0] + return objectArg }) ] var tests: [String: (any RuntimeValue...) throws -> Bool] = [ - "boolean": { - args in - args[0] is BooleanValue + "boolean": { args in + return args[0] is BooleanValue }, - "callable": { - args in - args[0] is FunctionValue + "callable": { args in + return args[0] is FunctionValue }, - "odd": { - args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 != 0 + "odd": { args in + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 != 0 } else { - throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of: args.first))") } }, "even": { args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 == 0 + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 == 0 } else { - throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of: args.first))") } }, "false": { args in @@ -62,24 +60,28 @@ class Environment { } return false }, + "string": { args in + return args[0] is StringValue + }, "number": { args in - args[0] is NumericValue + return args[0] is NumericValue }, "integer": { args in if let arg = args[0] as? NumericValue { return arg.value is Int } - return false }, + "mapping": { args in + return args[0] is ObjectValue + }, "iterable": { args in - args[0] is ArrayValue || args[0] is StringValue + return args[0] is ArrayValue || args[0] is StringValue || args[0] is ObjectValue }, "lower": { args in if let arg = args[0] as? StringValue { return arg.value == arg.value.lowercased() } - return false }, "upper": { args in @@ -89,16 +91,47 @@ class Environment { return false }, "none": { args in - args[0] is NullValue + return args[0] is NullValue }, "defined": { args in - !(args[0] is UndefinedValue) + return !(args[0] is UndefinedValue) }, "undefined": { args in - args[0] is UndefinedValue + return args[0] is UndefinedValue }, - "equalto": { _ in - throw JinjaError.syntaxNotSupported("equalto") + "equalto": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } + }, + "eq": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } }, ] @@ -107,66 +140,120 @@ class Environment { } func isFunction(_ value: Any, functionType: T.Type) -> Bool { - value is T + return value is T } - func convertToRuntimeValues(input: Any) throws -> any RuntimeValue { + func convertToRuntimeValues(input: Any?) throws -> any RuntimeValue { + // Handle already converted RuntimeValues + if let runtimeValue = input as? any RuntimeValue { + return runtimeValue + } + + // Handle nil values explicitly + if input == nil { + return NullValue() + } + + // Handle nil values + if case Optional.none = input { + return NullValue() + } + + // Helper function to handle any OrderedDictionary type + func convertOrderedDictionary(_ dict: OrderedDictionary) throws -> ObjectValue { + var object: [String: any RuntimeValue] = [:] + var keyOrder: [String] = [] + + for (key, value) in dict { + // Crucial: Convert Optional to T, using NullValue if nil + let convertedValue = (value as Any?) ?? NullValue() + object[key] = try self.convertToRuntimeValues(input: convertedValue) + keyOrder.append(key) + } + return ObjectValue(value: object, keyOrder: keyOrder) + } + switch input { case let value as Bool: return BooleanValue(value: value) - case let values as [any Numeric]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) - } - return ArrayValue(value: items) - case let value as any Numeric: + case let value as Int: + return NumericValue(value: value) + case let value as Double: + return NumericValue(value: value) + case let value as Float: return NumericValue(value: value) case let value as String: return StringValue(value: value) + case let data as Data: + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert data to string") + } + return StringValue(value: string) case let fn as (String) throws -> Void: return FunctionValue { args, _ in - var arg = "" - switch args[0].value { - case let value as String: - arg = value - case let value as Bool: - arg = String(value) - default: - throw JinjaError.runtime("Unknown arg type:\(type(of: args[0].value))") + guard let stringArg = args[0] as? StringValue else { + throw JinjaError.runtime("Argument must be a StringValue") } - - try fn(arg) + try fn(stringArg.value) return NullValue() } case let fn as (Bool) throws -> Void: return FunctionValue { args, _ in - try fn(args[0].value as! Bool) + guard let boolArg = args[0] as? BooleanValue else { + throw JinjaError.runtime("Argument must be a BooleanValue") + } + try fn(boolArg.value) return NullValue() } case let fn as (Int, Int?, Int) -> [Int]: return FunctionValue { args, _ in - let result = fn(args[0].value as! Int, args[1].value as? Int, args[2].value as! Int) - return try self.convertToRuntimeValues(input: result) - } - case let values as [Any]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) + guard args.count > 0, let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + var int1: Int? = nil + if args.count > 1 { + if let numericValue = args[1] as? NumericValue, let tempInt1 = numericValue.value as? Int { + int1 = tempInt1 + } else { + throw JinjaError.runtime("Second argument must be an Int or nil") + } + } + var int2: Int = 1 + if args.count > 2 { + if let numericValue = args[2] as? NumericValue, let tempInt2 = numericValue.value as? Int { + int2 = tempInt2 + } else { + throw JinjaError.runtime("Third argument must be an Int") + } + } + let result = fn(int0, int1, int2) + return ArrayValue(value: result.map { NumericValue(value: $0) }) } + case let values as [Any?]: + let items = try values.map { try self.convertToRuntimeValues(input: $0) } return ArrayValue(value: items) - case let dictionary as [String: String]: + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let dictionary as [String: Any?]: var object: [String: any RuntimeValue] = [:] - + var keyOrder: [String] = [] for (key, value) in dictionary { - object[key] = StringValue(value: value) + object[key] = try self.convertToRuntimeValues(input: value) + keyOrder.append(key) } - - return ObjectValue(value: object) - case is NullValue: - return NullValue() + return ObjectValue(value: object, keyOrder: keyOrder) default: - throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))") + throw JinjaError.runtime( + "Cannot convert to runtime value: \(String(describing: input)) type:\(type(of: input))" + ) } } @@ -176,12 +263,11 @@ class Environment { } func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { - if self.variables.contains(where: { $0.0 == name }) { + if self.variables.keys.contains(name) { throw JinjaError.syntax("Variable already declared: \(name)") } self.variables[name] = value - return value } @@ -191,13 +277,13 @@ class Environment { return value } - func resolve(name: String) throws -> Self { - if self.variables.contains(where: { $0.0 == name }) { + func resolve(name: String) throws -> Environment { + if self.variables.keys.contains(name) { return self } - if let parent { - return try parent.resolve(name: name) as! Self + if let parent = self.parent { + return try parent.resolve(name: name) } throw JinjaError.runtime("Unknown variable: \(name)") @@ -205,11 +291,7 @@ class Environment { func lookupVariable(name: String) -> any RuntimeValue { do { - if let value = try self.resolve(name: name).variables[name] { - return value - } else { - return UndefinedValue() - } + return try self.resolve(name: name).variables[name] ?? UndefinedValue() } catch { return UndefinedValue() } diff --git a/Sources/Lexer.swift b/Sources/Lexer.swift index 1093960..f6473e9 100644 --- a/Sources/Lexer.swift +++ b/Sources/Lexer.swift @@ -50,6 +50,8 @@ enum TokenType: String { case and = "And" case or = "Or" case not = "Not" + case macro = "Macro" + case endMacro = "EndMacro" } struct Token: Equatable { @@ -70,6 +72,8 @@ let keywords: [String: TokenType] = [ "and": .and, "or": .or, "not": .not, + "macro": .macro, + "endmacro": .endMacro, // Literals "true": .booleanLiteral, "false": .booleanLiteral, @@ -81,7 +85,7 @@ func isWord(char: String) -> Bool { } func isInteger(char: String) -> Bool { - char.range(of: #"[0-9]"#, options: .regularExpression) != nil + char.range(of: #"^[0-9]+$"#, options: .regularExpression) != nil } func isWhile(char: String) -> Bool { diff --git a/Sources/Parser.swift b/Sources/Parser.swift index 648a025..b52b798 100644 --- a/Sources/Parser.swift +++ b/Sources/Parser.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections func parse(tokens: [Token]) throws -> Program { var program = Program() @@ -22,34 +23,34 @@ func parse(tokens: [Token]) throws -> Program { return prev } - func parseArgumentsList() throws -> [Statement] { + func parseArgumentsList() throws -> [Expression] { var args: [Expression] = [] while !typeof(.closeParen) { var argument = try parseExpression() if typeof(.equals) { - current += 1 + current += 1 // consume equals if let identifier = argument as? Identifier { let value = try parseExpression() - argument = KeywordArgumentExpression(key: identifier, value: value as! Expression) + argument = KeywordArgumentExpression(key: identifier, value: value) } else { throw JinjaError.syntax("Expected identifier for keyword argument") } } - args.append(argument as! Expression) + args.append(argument) if typeof(.comma) { - current += 1 + current += 1 // consume comma } } return args } - func parseArgs() throws -> [Statement] { + func parseArgs() throws -> [Expression] { try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") let args = try parseArgumentsList() @@ -63,14 +64,10 @@ func parse(tokens: [Token]) throws -> Program { try StringLiteral(value: expect(type: .text, error: "Expected text token").value) } - func parseCallExpression(callee: Statement) throws -> CallExpression { - var args: [Expression] = [] + func parseCallExpression(callee: Expression) throws -> CallExpression { + let args = try parseArgs() - for arg in try parseArgs() { - args.append(arg as! Expression) - } - - var callExpression = CallExpression(callee: callee as! Expression, args: args) + var callExpression = CallExpression(callee: callee, args: args) if typeof(.openParen) { callExpression = try parseCallExpression(callee: callExpression) @@ -79,19 +76,19 @@ func parse(tokens: [Token]) throws -> Program { return callExpression } - func parseMemberExpressionArgumentsList() throws -> Statement { - var slices: [Statement?] = [] + func parseMemberExpressionArgumentsList() throws -> Expression { + var slices: [Expression?] = [] var isSlice = false while !typeof(.closeSquareBracket) { if typeof(.colon) { slices.append(nil) - current += 1 + current += 1 // consume colon isSlice = true } else { - try slices.append(parseExpression()) + slices.append(try parseExpression()) if typeof(.colon) { - current += 1 + current += 1 // consume colon isSlice = true } } @@ -105,24 +102,23 @@ func parse(tokens: [Token]) throws -> Program { if slices.count > 3 { throw JinjaError.syntax("Expected 0-3 arguments for slice expression") } - return SliceExpression( - start: slices[0] as? Expression, - stop: slices.count > 1 ? slices[1] as? Expression : nil, - step: slices.count > 2 ? slices[2] as? Expression : nil + start: slices[0], + stop: slices.count > 1 ? slices[1] : nil, + step: slices.count > 2 ? slices[2] : nil ) } - return slices[0]! + return slices[0]! // normal member expression } - func parseMemberExpression() throws -> Statement { + func parseMemberExpression() throws -> Expression { var object = try parsePrimaryExpression() while typeof(.dot) || typeof(.openSquareBracket) { let operation = tokens[current] current += 1 - var property: Statement + var property: Expression let computed = operation.type != .dot @@ -137,8 +133,8 @@ func parse(tokens: [Token]) throws -> Program { } object = MemberExpression( - object: object as! Expression, - property: property as! Expression, + object: object, + property: property, computed: computed ) } @@ -146,7 +142,7 @@ func parse(tokens: [Token]) throws -> Program { return object } - func parseCallMemberExpression() throws -> Statement { + func parseCallMemberExpression() throws -> Expression { let member = try parseMemberExpression() if typeof(.openParen) { @@ -156,29 +152,33 @@ func parse(tokens: [Token]) throws -> Program { return member } - func parseFilterExpression() throws -> Statement { + func parseFilterExpression() throws -> Expression { var operand = try parseCallMemberExpression() while typeof(.pipe) { - current += 1 + current += 1 // consume pipe var filter = try parsePrimaryExpression() + if !(filter is Identifier) { - throw JinjaError.syntax("Expected identifier for the test") + throw JinjaError.syntax("Expected identifier for the filter") } if typeof(.openParen) { + // Handle filter with arguments filter = try parseCallExpression(callee: filter) } if let filter = filter as? Filter { - operand = FilterExpression(operand: operand as! Expression, filter: filter) + operand = FilterExpression(operand: operand, filter: filter) + } else { + throw JinjaError.syntax("Invalid filter type") } } return operand } - func parseTestExpression() throws -> Statement { + func parseTestExpression() throws -> Expression { var operand = try parseFilterExpression() while typeof(.is) { @@ -194,7 +194,7 @@ func parse(tokens: [Token]) throws -> Program { filter = Identifier(value: "none") } if let test = filter as? Identifier { - operand = TestExpression(operand: operand as! Expression, negate: negate, test: test) + operand = TestExpression(operand: operand, negate: negate, test: test) } else { throw JinjaError.syntax("Expected identifier for the test") } @@ -202,49 +202,49 @@ func parse(tokens: [Token]) throws -> Program { return operand } - func parseMultiplicativeExpression() throws -> Statement { + func parseMultiplicativeExpression() throws -> Expression { var left = try parseTestExpression() while typeof(.multiplicativeBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseTestExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseAdditiveExpression() throws -> Statement { + func parseAdditiveExpression() throws -> Expression { var left = try parseMultiplicativeExpression() while typeof(.additiveBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseMultiplicativeExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseComparisonExpression() throws -> Statement { + func parseComparisonExpression() throws -> Expression { var left = try parseAdditiveExpression() while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) { let operation = tokens[current] current += 1 let right = try parseAdditiveExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseLogicalNegationExpression() throws -> Statement { + func parseLogicalNegationExpression() throws -> Expression { var right: UnaryExpression? while typeof(.not) { let operation = tokens[current] current += 1 let argument = try parseLogicalNegationExpression() - right = UnaryExpression(operation: operation, argument: argument as! Expression) + right = UnaryExpression(operation: operation, argument: argument) } if let right { @@ -254,44 +254,52 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseLogicalAndExpression() throws -> Statement { + func parseLogicalAndExpression() throws -> Expression { var left = try parseLogicalNegationExpression() while typeof(.and) { let operation = tokens[current] current += 1 let right = try parseLogicalNegationExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseLogicalOrExpression() throws -> Statement { + func parseLogicalOrExpression() throws -> Expression { var left = try parseLogicalAndExpression() while typeof(.or) { let operation = tokens[current] current += 1 let right = try parseLogicalAndExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseTernaryExpression() throws -> Statement { + func parseTernaryExpression() throws -> Expression { let a = try parseLogicalOrExpression() + if typeof(.if) { - current += 1 - let test = try parseLogicalOrExpression() - try expect(type: .else, error: "Expected else token") - let b = try parseLogicalOrExpression() - return If(test: test as! Expression, body: [a], alternate: [b]) + current += 1 // consume if token + let predicate = try parseLogicalOrExpression() + + if typeof(.else) { + // Ternary expression with else + current += 1 // consume else token + let b = try parseLogicalOrExpression() + return If(test: predicate, body: [a], alternate: [b]) + } else { + // Select expression on iterable + return SelectExpression(iterable: a, test: predicate) + } } return a } - func parseExpression() throws -> Statement { + func parseExpression() throws -> Expression { try parseTernaryExpression() } @@ -314,9 +322,11 @@ func parse(tokens: [Token]) throws -> Program { if typeof(.equals) { current += 1 - let value = try parseSetStatement() + // Parse the right-hand side as an expression + let value = try parseExpression() - return Set(assignee: left as! Expression, value: value as! Expression) + // Explicitly cast 'value' to 'Expression' + return Set(assignee: left, value: value) } return left @@ -334,31 +344,37 @@ func parse(tokens: [Token]) throws -> Program { && (tokens[current + 1].type == .elseIf || tokens[current + 1].type == .else || tokens[current + 1].type == .endIf)) { - try body.append(parseAny()) + body.append(try parseAny()) } if tokens[current].type == .openStatement, tokens[current + 1].type != .endIf { current += 1 if typeof(.elseIf) { try expect(type: .elseIf, error: "Expected elseif token") - try alternate.append(parseIfStatement()) + alternate.append(try parseIfStatement()) } else { try expect(type: .else, error: "Expected else token") try expect(type: .closeStatement, error: "Expected closing statement token") while !(tokens[current].type == .openStatement && tokens[current + 1].type == .endIf) { - try alternate.append(parseAny()) + alternate.append(try parseAny()) } } } - return If(test: test as! Expression, body: body, alternate: alternate) + return If(test: test, body: body, alternate: alternate) } - func parsePrimaryExpression() throws -> Statement { + func parsePrimaryExpression() throws -> Expression { let token = tokens[current] switch token.type { case .numericLiteral: current += 1 - return NumericLiteral(value: Int(token.value) ?? 0) + if let intValue = Int(token.value) { + return NumericLiteral(value: intValue) + } else if let doubleValue = Double(token.value) { + return NumericLiteral(value: doubleValue) + } else { + throw JinjaError.parser("Invalid numeric literal: \(token.value)") + } case .stringLiteral: current += 1 return StringLiteral(value: token.value) @@ -383,7 +399,7 @@ func parse(tokens: [Token]) throws -> Program { current += 1 var values: [Expression] = [] while !typeof(.closeSquareBracket) { - try values.append(parseExpression() as! Expression) + try values.append(parseExpression()) if typeof(.comma) { current += 1 } @@ -392,12 +408,20 @@ func parse(tokens: [Token]) throws -> Program { return ArrayLiteral(value: values) case .openCurlyBracket: current += 1 - var values: [(Expression, Expression)] = [] + var values = OrderedDictionary() while !typeof(.closeCurlyBracket) { let key = try parseExpression() try expect(type: .colon, error: "Expected colon between key and value in object literal") let value = try parseExpression() - values.append((key as! Expression, value as! Expression)) + + if let key = key as? StringLiteral { + values[key.value] = value + } else if let key = key as? Identifier { + values[key.value] = value + } else { + throw JinjaError.syntax("Expected string literal or identifier as key in object literal") + } + if typeof(.comma) { current += 1 } @@ -409,18 +433,20 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseExpressionSequence(primary: Bool = false) throws -> Statement { + func parseExpressionSequence(primary: Bool = false) throws -> Expression { let fn = primary ? parsePrimaryExpression : parseExpression - var expressions: [Expression] = try [fn() as! Expression] + var expressions: [Expression] = try [fn()] let isTuple = typeof(.comma) + while isTuple { - current += 1 - try expressions.append(fn() as! Expression) + current += 1 // consume comma + try expressions.append(fn()) if !typeof(.comma) { break } } + // Return either a tuple or single expression return isTuple ? TupleLiteral(value: expressions) : expressions[0] } @@ -439,7 +465,7 @@ func parse(tokens: [Token]) throws -> Program { if !(loopVariable is Identifier || loopVariable is TupleLiteral) { throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + "Expected identifier/tuple for the loop variable, got \(type(of: loopVariable)) instead" ) } @@ -447,39 +473,81 @@ func parse(tokens: [Token]) throws -> Program { let iterable = try parseExpression() + // Handle optional if condition for filtering + var ifCondition: Expression? = nil + if typeof(.if) { + current += 1 // consume if token + ifCondition = try parseExpression() + } + try expect(type: .closeStatement, error: "Expected closing statement token") var body: [Statement] = [] - while not(.openStatement, .endFor) { - try body.append(parseAny()) + var defaultBlock: [Statement] = [] + + while not(.openStatement, .endFor) && not(.openStatement, .else) { + body.append(try parseAny()) } - if let loopVariable = loopVariable as? Loopvar { - return For(loopvar: loopVariable, iterable: iterable as! Expression, body: body) + if typeof(.openStatement, .else) { + current += 1 // consume {% + try expect(type: .else, error: "Expected else token") + try expect(type: .closeStatement, error: "Expected closing statement token") + + while not(.openStatement, .endFor) { + defaultBlock.append(try parseAny()) + } } - throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + return For( + loopvar: loopVariable, + iterable: iterable, + body: body, + defaultBlock: defaultBlock, + ifCondition: ifCondition ) } + func parseMacroStatement() throws -> Macro { + let name = try parsePrimaryExpression() + if !(name is Identifier) { + throw JinjaError.syntax("Expected identifier following macro statement") + } + let args = try parseArgs() + try expect(type: .closeStatement, error: "Expected closing statement token") + + var body: [Statement] = [] + + while not(.openStatement, .endMacro) { + body.append(try parseAny()) + } + + return Macro(name: name as! Identifier, args: args, body: body) + } + func parseJinjaStatement() throws -> Statement { + // Consume {% %} tokens try expect(type: .openStatement, error: "Expected opening statement token") var result: Statement - switch tokens[current].type { case .set: - current += 1 + current += 1 // consume 'set' token result = try parseSetStatement() try expect(type: .closeStatement, error: "Expected closing statement token") case .if: - current += 1 + current += 1 // consume 'if' token result = try parseIfStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endIf, error: "Expected endif token") try expect(type: .closeStatement, error: "Expected %} token") + case .macro: + current += 1 // consume 'macro' token + result = try parseMacroStatement() + try expect(type: .openStatement, error: "Expected {% token") + try expect(type: .endMacro, error: "Expected endmacro token") + try expect(type: .closeStatement, error: "Expected %} token") case .for: - current += 1 + current += 1 // consume 'for' token result = try parseForStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endFor, error: "Expected endfor token") @@ -487,7 +555,6 @@ func parse(tokens: [Token]) throws -> Program { default: throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)") } - return result } diff --git a/Sources/Runtime.swift b/Sources/Runtime.swift index 73a0d48..0f2308a 100644 --- a/Sources/Runtime.swift +++ b/Sources/Runtime.swift @@ -6,11 +6,12 @@ // import Foundation +import OrderedCollections protocol RuntimeValue { - associatedtype T - var value: T { get set } + associatedtype ValueType + var value: ValueType { get } var builtins: [String: any RuntimeValue] { get set } func bool() -> Bool @@ -21,7 +22,12 @@ struct NumericValue: RuntimeValue { var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { - self.value as? Int != 0 + if let intValue = self.value as? Int { + return intValue != 0 + } else if let doubleValue = self.value as? Double { + return doubleValue != 0.0 + } + return false } } @@ -35,7 +41,7 @@ struct BooleanValue: RuntimeValue { } struct NullValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -44,7 +50,7 @@ struct NullValue: RuntimeValue { } struct UndefinedValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -69,51 +75,103 @@ struct ArrayValue: RuntimeValue { } struct TupleValue: RuntimeValue { - var value: ArrayValue + var value: [any RuntimeValue] var builtins: [String: any RuntimeValue] = [:] + init(value: [any RuntimeValue]) { + self.value = value + self.builtins["length"] = FunctionValue(value: { _, _ in + NumericValue(value: value.count) + }) + } + func bool() -> Bool { - self.value.bool() + !self.value.isEmpty } } -struct ObjectValue: RuntimeValue { - var value: [String: any RuntimeValue] - var builtins: [String: any RuntimeValue] = [:] +struct ObjectValue: RuntimeValue, Sequence { + var storage: OrderedDictionary + var builtins: [String: any RuntimeValue] - init(value: [String: any RuntimeValue]) { - self.value = value + var value: [String: any RuntimeValue] { Dictionary(uniqueKeysWithValues: storage.map { ($0, $1) }) } + var orderedKeys: [String] { Array(storage.keys) } + + init(value: [String: any RuntimeValue], keyOrder: [String]? = nil) { + // If keyOrder is provided, use it; otherwise, maintain the original order from the dictionary + let orderedKeys = keyOrder ?? Array(value.keys) + let orderedPairs = orderedKeys.compactMap { key in + value[key].map { (key, $0) } + } + + // Recursively create OrderedDictionary for nested objects + let processedPairs = orderedPairs.map { key, value -> (String, any RuntimeValue) in + if let objectValue = value as? ObjectValue { + // Already an ObjectValue, use it directly + return (key, objectValue) + } else if let dictValue = value.value as? [String: any RuntimeValue] { + // If the value contains a dictionary, convert it to ObjectValue + return (key, ObjectValue(value: dictValue)) + } + return (key, value) + } + + self.storage = OrderedDictionary(uniqueKeysWithValues: processedPairs) self.builtins = [ "get": FunctionValue(value: { args, _ in - if let key = args[0] as? StringValue { - if let value = value.first(where: { $0.0 == key.value }) { - return value as! (any RuntimeValue) - } else if args.count > 1 { - return args[1] - } else { - return NullValue() - } - } else { - throw JinjaError.runtime("Object key must be a string: got \(type(of:args[0]))") + guard let key = args[0] as? StringValue else { + throw JinjaError.runtime("Object key must be a string: got \(type(of: args[0]))") + } + if let value = value[key.value] { + return value + } else if args.count > 1 { + return args[1] } + return NullValue() }), "items": FunctionValue(value: { _, _ in - var items: [ArrayValue] = [] - for (k, v) in value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! (any RuntimeValue) + ArrayValue( + value: orderedPairs.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + ) }), ] } + mutating func setValue(key: String, value: any RuntimeValue) { + storage[key] = value + } + func bool() -> Bool { - !self.value.isEmpty + !storage.isEmpty + } + + func makeIterator() -> OrderedDictionary.Iterator { + return storage.makeIterator() + } +} + +extension ObjectValue { + func toJSON(indent: Int? = nil, depth: Int = 0) throws -> String { + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: depth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + // Use orderedKeys to maintain insertion order + let pairs = try orderedKeys.map { key in + guard let value = value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try Jinja.toJSON(value, indent: indent, depth: depth + 1) + return "\"\(key)\": \(jsonValue)" + } + + if indent != nil { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } } } @@ -146,12 +204,18 @@ struct StringValue: RuntimeValue { }), "title": FunctionValue(value: { _, _ in - StringValue(value: value.capitalized) + StringValue(value: value.titleCase()) }), "length": FunctionValue(value: { _, _ in NumericValue(value: value.count) }), + "rstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "\\s+$", with: "", options: .regularExpression)) + }), + "lstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "^\\s+", with: "", options: .regularExpression)) + }), ] } @@ -177,17 +241,20 @@ struct Interpreter { let lastEvaluated = try self.evaluate(statement: statement, environment: environment) if !(lastEvaluated is NullValue), !(lastEvaluated is UndefinedValue) { - if let value = lastEvaluated.value as? String { - result += value + if let stringValue = lastEvaluated as? StringValue { + result += stringValue.value + } else if let numericValue = lastEvaluated as? NumericValue { + result += String(describing: numericValue.value) + } else if let booleanValue = lastEvaluated as? BooleanValue { + result += String(booleanValue.value) + } else if let arrayValue = lastEvaluated as? ArrayValue { + // Convert array to JSON string + result += try toJSON(arrayValue) + } else if let objectValue = lastEvaluated as? ObjectValue { + // Convert object to JSON string + result += try toJSON(objectValue) } else { - switch lastEvaluated.value { - case let value as Int: - result += String(value) - case let value as String: - result += value - default: - throw JinjaError.runtime("Unknown value type:\(type(of: lastEvaluated.value))") - } + throw JinjaError.runtime("Cannot convert to string: \(type(of: lastEvaluated))") } } } @@ -206,26 +273,30 @@ struct Interpreter { try environment.setVariable(name: variableName, value: rhs) } else if let member = node.assignee as? MemberExpression { let object = try self.evaluate(statement: member.object, environment: environment) + guard var objectValue = object as? ObjectValue else { + throw JinjaError.runtime("Cannot assign to member of non-object") + } + guard let property = member.property as? Identifier else { + throw JinjaError.runtime("Cannot assign to member with non-identifier property") + } - if var object = object as? ObjectValue { - if let property = member.property as? Identifier { - object.value[property.value] = rhs - } else { - throw JinjaError.runtime("Cannot assign to member with non-identifier property") - } + // Modify the copy + objectValue.setValue(key: property.value, value: rhs) + + // Update the environment with the modified copy + if let parentIdentifier = member.object as? Identifier { + try environment.setVariable(name: parentIdentifier.value, value: objectValue) } else { - throw JinjaError.runtime("Cannot assign to member of non-object") + throw JinjaError.runtime("Cannot assign to computed member expression") } } else { - throw JinjaError.runtime("Invalid assignee type: \(type(of: node.assignee))") + throw JinjaError.runtime("Invalid LHS inside assignment expression: \(node.assignee)") } - return NullValue() } func evaluateIf(node: If, environment: Environment) throws -> StringValue { let test = try self.evaluate(statement: node.test, environment: environment) - return try self.evaluateBlock(statements: test.bool() ? node.body : node.alternate, environment: environment) } @@ -233,66 +304,235 @@ struct Interpreter { environment.lookupVariable(name: node.value) } - func evaluateFor(node: For, environment: Environment) throws -> any RuntimeValue { + func evaluateFor(node: For, environment: Environment) throws -> StringValue { + // Scope for the for loop let scope = Environment(parent: environment) - let iterable = try self.evaluate(statement: node.iterable, environment: scope) - var result = "" - if let iterable = iterable as? ArrayValue { - for i in 0 ..< iterable.value.count { - let loop: [String: any RuntimeValue] = [ - "index": NumericValue(value: i + 1), - "index0": NumericValue(value: i), - "revindex": NumericValue(value: iterable.value.count - i), - "revindex0": NumericValue(value: iterable.value.count - i - 1), - "first": BooleanValue(value: i == 0), - "last": BooleanValue(value: i == iterable.value.count - 1), - "length": NumericValue(value: iterable.value.count), - "previtem": i > 0 ? iterable.value[i - 1] : UndefinedValue(), - "nextitem": i < iterable.value.count - 1 ? iterable.value[i + 1] : UndefinedValue(), - ] - - try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) - - let current = iterable.value[i] + let test: Expression? + let iterable: any RuntimeValue + if let selectExpression = node.iterable as? SelectExpression { + iterable = try self.evaluate(statement: selectExpression.iterable, environment: scope) + test = selectExpression.test + } else { + iterable = try self.evaluate(statement: node.iterable, environment: scope) + test = nil + } + + var items: [any RuntimeValue] = [] + var scopeUpdateFunctions: [(Environment) throws -> Void] = [] + + // Keep track of the indices of the original iterable that passed the test + var filteredIndices: [Int] = [] + var originalIndex = 0 + + // Handle ArrayValue + if let arrayIterable = iterable as? ArrayValue { + for current in arrayIterable.value { + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void if let identifier = node.loopvar as? Identifier { - try scope.setVariable(name: identifier.value, value: current) - } else { - } + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + guard let currentArray = current as? ArrayValue else { + throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of: current))") + } - switch node.loopvar { - case let identifier as Identifier: - try scope.setVariable(name: identifier.value, value: current) - case let tupleLiteral as TupleLiteral: - if let current = current as? ArrayValue { - if tupleLiteral.value.count != current.value.count { - throw JinjaError.runtime( - "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" - ) - } + if tupleLiteral.value.count != currentArray.value.count { + throw JinjaError.runtime( + "Too \(tupleLiteral.value.count > currentArray.value.count ? "few" : "many") items to unpack" + ) + } - for j in 0 ..< tupleLiteral.value.count { - if let identifier = tupleLiteral.value[j] as? Identifier { - try scope.setVariable(name: identifier.value, value: current.value[j]) - } else { - throw JinjaError.runtime( - "Cannot unpack non-identifier type: \(type(of:tupleLiteral.value[j]))" - ) + scopeUpdateFunction = { scope in + for (i, value) in tupleLiteral.value.enumerated() { + guard let identifier = value as? Identifier else { + throw JinjaError.runtime("Cannot unpack non-identifier type: \(type(of: value))") } + try scope.setVariable(name: identifier.value, value: currentArray.value[i]) } - } else { - throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") } - default: - throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle StringValue as a special case + } else if let stringIterable = iterable as? StringValue { + // Treat the string as an iterable of characters + for char in stringIterable.value { + let current = StringValue(value: String(char)) + let loopScope = Environment(parent: scope) + + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle ObjectValue (dictionary) + } else if let objectIterable = iterable as? ObjectValue { + // Treat the dictionary as an iterable of key-value pairs + for (key, value) in objectIterable { + let current = ArrayValue(value: [StringValue(value: key), value]) + let loopScope = Environment(parent: scope) + + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + // Support unpacking of key-value pairs into two variables + if tupleLiteral.value.count != 2 { + throw JinjaError.runtime( + "Cannot unpack dictionary entry: expected 2 variables, got \(tupleLiteral.value.count)" + ) + } + guard let keyIdentifier = tupleLiteral.value[0] as? Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[0]))" + ) + } + guard let valueIdentifier = tupleLiteral.value[1] as? Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[1]))" + ) + } + + scopeUpdateFunction = { scope in + try scope.setVariable(name: keyIdentifier.value, value: StringValue(value: key)) + try scope.setVariable(name: valueIdentifier.value, value: value) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") } - let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) - result += evaluated.value + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 } } else { - throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of:iterable))") + throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of: iterable))") + } + + var result = "" + var noIteration = true + + for i in 0 ..< items.count { + // Get the previous and next items that passed the filter + let previousIndex = filteredIndices.firstIndex(of: filteredIndices[i])! - 1 + let nextIndex = filteredIndices.firstIndex(of: filteredIndices[i])! + 1 + + let previtem: any RuntimeValue + if previousIndex >= 0 { + let previousFilteredIndex = filteredIndices[previousIndex] + if let arrayIterable = iterable as? ArrayValue { + previtem = arrayIterable.value[previousFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index( + stringIterable.value.startIndex, + offsetBy: previousFilteredIndex + ) + previtem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? ObjectValue { + let (key, value) = objectIterable.storage.elements[previousFilteredIndex] + previtem = ArrayValue(value: [StringValue(value: key), value]) + } else { + previtem = UndefinedValue() + } + } else { + previtem = UndefinedValue() + } + + let nextitem: any RuntimeValue + if nextIndex < filteredIndices.count { + let nextFilteredIndex = filteredIndices[nextIndex] + if let arrayIterable = iterable as? ArrayValue { + nextitem = arrayIterable.value[nextFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index(stringIterable.value.startIndex, offsetBy: nextFilteredIndex) + nextitem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? ObjectValue { + let (key, value) = objectIterable.storage.elements[nextFilteredIndex] + nextitem = ArrayValue(value: [StringValue(value: key), value]) + } else { + nextitem = UndefinedValue() + } + } else { + nextitem = UndefinedValue() + } + + let loop: [String: any RuntimeValue] = [ + "index": NumericValue(value: i + 1), + "index0": NumericValue(value: i), + "revindex": NumericValue(value: items.count - i), + "revindex0": NumericValue(value: items.count - i - 1), + "first": BooleanValue(value: i == 0), + "last": BooleanValue(value: i == items.count - 1), + "length": NumericValue(value: items.count), + "previtem": previtem, + "nextitem": nextitem, + ] + + try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) + + try scopeUpdateFunctions[i](scope) + + let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) + result += evaluated.value + + noIteration = false + } + + if noIteration { + let defaultEvaluated = try self.evaluateBlock(statements: node.defaultBlock, environment: scope) + result += defaultEvaluated.value } return StringValue(value: result) @@ -302,31 +542,72 @@ struct Interpreter { let left = try self.evaluate(statement: node.left, environment: environment) if node.operation.value == "and" { - return left.bool() ? try self.evaluate(statement: node.right, environment: environment) : left + if !left.bool() { + return left + } + let right = try self.evaluate(statement: node.right, environment: environment) + return right } else if node.operation.value == "or" { return left.bool() ? left : try self.evaluate(statement: node.right, environment: environment) } let right = try self.evaluate(statement: node.right, environment: environment) + // == if node.operation.value == "==" { - switch left.value { - case let value as String: - return BooleanValue(value: value == right.value as! String) - case let value as Int: - return BooleanValue(value: value == right.value as! Int) - case let value as Bool: - return BooleanValue(value: value == right.value as! Bool) - default: - throw JinjaError.runtime( - "Unknown left value type:\(type(of: left.value)), right value type:\(type(of: right.value))" - ) + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value == right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt == rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble == rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) == rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble == Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for equality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value == right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: true) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: true) + } else if type(of: left) == type(of: right) { + return BooleanValue(value: false) + } else { + return BooleanValue(value: false) } - } else if node.operation.value == "!=" { - if type(of: left) != type(of: right) { + } + + // != + if node.operation.value == "!=" { + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value != right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt != rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble != rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) != rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble != Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for inequality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value != right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: false) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: false) + } else if type(of: left) == type(of: right) { return BooleanValue(value: true) } else { - return BooleanValue(value: left.value as! AnyHashable != right.value as! AnyHashable) + return BooleanValue(value: true) } } @@ -336,92 +617,230 @@ struct Interpreter { throw JinjaError.runtime("Cannot perform operation on null values") } else if let left = left as? NumericValue, let right = right as? NumericValue { switch node.operation.value { - case "+": throw JinjaError.syntaxNotSupported("+") - case "-": throw JinjaError.syntaxNotSupported("-") - case "*": throw JinjaError.syntaxNotSupported("*") - case "/": throw JinjaError.syntaxNotSupported("/") + case "+": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt + rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble + rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) + rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble + Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for addition") + } + case "-": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt - rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble - rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) - rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble - Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for subtraction") + } + case "*": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt * rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble * rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) * rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble * Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for multiplication") + } + case "/": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt / rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble / rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) / rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble / Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for division") + } case "%": - switch left.value { - case is Int: - return NumericValue(value: left.value as! Int % (right.value as! Int)) - default: - throw JinjaError.runtime("Unknown value type:\(type(of: left.value))") + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt % rightInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for modulus") + } + case "<": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt < rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble < rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) < rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble < Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than comparison") + } + case ">": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt > rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble > rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) > rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble > Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than comparison") + } + case ">=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt >= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble >= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) >= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble >= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than or equal to comparison") + } + case "<=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt <= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble <= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) <= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble <= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than or equal to comparison") } - case "<": throw JinjaError.syntaxNotSupported("<") - case ">": throw JinjaError.syntaxNotSupported(">") - case ">=": throw JinjaError.syntaxNotSupported(">=") - case "<=": throw JinjaError.syntaxNotSupported("<=") default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if left is ArrayValue && right is ArrayValue { + } else if let left = left as? ArrayValue, let right = right as? ArrayValue { switch node.operation.value { - case "+": break + case "+": + return ArrayValue(value: left.value + right.value) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if right is ArrayValue { - throw JinjaError.syntaxNotSupported("right is ArrayValue") - } - - if left is StringValue || right is StringValue { - switch node.operation.value { - case "+": - var rightValue = "" - var leftValue = "" - switch right.value { - case let value as String: - rightValue = value - case let value as Int: - rightValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown right value type:\(type(of: right.value))") + } else if let right = right as? ArrayValue { + let member: Bool + if let left = left as? StringValue { + member = right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false } - - switch left.value { - case let value as String: - leftValue = value - case let value as Int: - leftValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown left value type:\(type(of: left.value))") + } else if let left = left as? NumericValue { + member = right.value.contains { + if let item = $0 as? NumericValue { + return item.value as! Int == left.value as! Int + } + return false } - - return StringValue(value: leftValue + rightValue) - default: - break + } else if let left = left as? BooleanValue { + member = right.value.contains { + if let item = $0 as? BooleanValue { + return item.value == left.value + } + return false + } + } else { + throw JinjaError.runtime("Unsupported left type for 'in'/'not in' operation with ArrayValue") } - } - if let left = left as? StringValue, let right = right as? StringValue { switch node.operation.value { case "in": - return BooleanValue(value: right.value.contains(left.value)) + return BooleanValue(value: member) case "not in": - return BooleanValue(value: !right.value.contains(left.value)) + return BooleanValue(value: !member) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } } - if left is StringValue, right is ObjectValue { + if let left = left as? StringValue { switch node.operation.value { + case "+": + let rightValue: String + if let rightString = right as? StringValue { + rightValue = rightString.value + } else if let rightNumeric = right as? NumericValue { + rightValue = String(describing: rightNumeric.value) + } else if let rightBoolean = right as? BooleanValue { + rightValue = String(rightBoolean.value) + } else if right is UndefinedValue { + rightValue = "" + } else { + throw JinjaError.runtime("Unsupported right operand type for string concatenation") + } + return StringValue(value: left.value + rightValue) case "in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime("Right operand of 'in' must be a StringValue, ArrayValue, or ObjectValue") } case "not in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: !rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: !right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: !right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: !right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime( + "Right operand of 'not in' must be a StringValue, ArrayValue, or ObjectValue" + ) + } + default: + break + } + } else if let right = right as? StringValue { + if node.operation.value == "+" { + if let leftString = left as? StringValue { + return StringValue(value: leftString.value + right.value) + } else if let leftNumeric = left as? NumericValue { + return StringValue(value: String(describing: leftNumeric.value) + right.value) + } else if let leftBoolean = left as? BooleanValue { + return StringValue(value: String(leftBoolean.value) + right.value) + } else { + throw JinjaError.runtime("Unsupported left operand type for string concatenation") } + } + } + + if let left = left as? StringValue, let right = right as? ObjectValue { + switch node.operation.value { + case "in": + return BooleanValue(value: right.value.keys.contains(left.value)) + case "not in": + return BooleanValue(value: !right.value.keys.contains(left.value)) default: throw JinjaError.runtime( "Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue" @@ -463,19 +882,19 @@ struct Interpreter { return ArrayValue( value: slice( object.value, - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int ) ) } else if let object = object as? StringValue { return StringValue( value: slice( - Array(arrayLiteral: object.value), - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int - ).joined() + Array(object.value), + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int + ).map { String($0) }.joined() ) } @@ -484,7 +903,6 @@ struct Interpreter { func evaluateMemberExpression(expr: MemberExpression, environment: Environment) throws -> any RuntimeValue { let object = try self.evaluate(statement: expr.object, environment: environment) - var property: any RuntimeValue if expr.computed { if let property = expr.property as? SliceExpression { @@ -495,7 +913,6 @@ struct Interpreter { } else { property = StringValue(value: (expr.property as! Identifier).value) } - var value: (any RuntimeValue)? if let object = object as? ObjectValue { if let property = property as? StringValue { @@ -503,34 +920,55 @@ struct Interpreter { } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } - } else if object is ArrayValue || object is StringValue { + } else if let object = object as? ArrayValue { if let property = property as? NumericValue { - if let object = object as? ArrayValue { - let index = property.value as! Int - if index >= 0 { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { value = object.value[index] - } else { + } else if index < 0 && index >= -object.value.count { value = object.value[object.value.count + index] + } else { + value = UndefinedValue() + } + } else { + throw JinjaError.runtime("Array index must be an integer") + } + } else if let property = property as? StringValue { + value = object.builtins[property.value] + } else { + throw JinjaError.runtime( + "Cannot access property with non-string/non-number: got \(type(of: property))" + ) + } + } else if let object = object as? StringValue { + if let property = property as? NumericValue { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: index) + value = StringValue(value: String(object.value[strIndex])) + } else if index < 0 && index >= -object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: object.value.count + index) + value = StringValue(value: String(object.value[strIndex])) + } else { + value = UndefinedValue() } - } else if let object = object as? StringValue { - let index = object.value.index(object.value.startIndex, offsetBy: property.value as! Int) - value = StringValue(value: String(object.value[index])) + } else { + throw JinjaError.runtime("String index must be an integer") } } else if let property = property as? StringValue { value = object.builtins[property.value] } else { throw JinjaError.runtime( - "Cannot access property with non-string/non-number: got \(type(of:property))" + "Cannot access property with non-string/non-number: got \(type(of: property))" ) } } else { if let property = property as? StringValue { - value = object.builtins[property.value]! + value = object.builtins[property.value] } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } } - if let value { return value } else { @@ -561,7 +999,7 @@ struct Interpreter { } } - if kwargs.count > 0 { + if !kwargs.isEmpty { args.append(ObjectValue(value: kwargs)) } @@ -575,9 +1013,11 @@ struct Interpreter { } func evaluateFilterExpression(node: FilterExpression, environment: Environment) throws -> any RuntimeValue { - let operand = try evaluate(statement: node.operand, environment: environment) - + let operand = try self.evaluate(statement: node.operand, environment: environment) if let identifier = node.filter as? Identifier { + if identifier.value == "tojson" { + return try StringValue(value: toJSON(operand)) + } if let arrayValue = operand as? ArrayValue { switch identifier.value { case "list": @@ -591,7 +1031,32 @@ struct Interpreter { case "reverse": return ArrayValue(value: arrayValue.value.reversed()) case "sort": - throw JinjaError.todo("TODO: ArrayValue filter sort") + return ArrayValue( + value: try arrayValue.value.sorted { + // No need to cast to AnyComparable here + if let a = $0 as? NumericValue, let b = $1 as? NumericValue { + if let aInt = a.value as? Int, let bInt = b.value as? Int { + return aInt < bInt + } else if let aDouble = a.value as? Double, let bDouble = b.value as? Double { + return aDouble < bDouble + } else if let aInt = a.value as? Int, let bDouble = b.value as? Double { + return Double(aInt) < bDouble + } else if let aDouble = a.value as? Double, let bInt = b.value as? Int { + return aDouble < Double(bInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for comparison") + } + } else if let a = $0 as? StringValue, let b = $1 as? StringValue { + return a.value < b.value + } else { + throw JinjaError.runtime( + "Cannot compare values of different types or non-comparable types" + ) + } + } + ) + case "map": + throw JinjaError.todo("TODO: ArrayValue filter map") default: throw JinjaError.runtime("Unknown ArrayValue filter: \(identifier.value)") } @@ -604,34 +1069,38 @@ struct Interpreter { case "lower": return StringValue(value: stringValue.value.lowercased()) case "title": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.titleCase()) case "capitalize": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.prefix(1).uppercased() + stringValue.value.dropFirst()) case "trim": return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) + case "indent": + return StringValue(value: stringValue.value.indent(4)) + case "string": + return stringValue default: throw JinjaError.runtime("Unknown StringValue filter: \(identifier.value)") } } else if let numericValue = operand as? NumericValue { switch identifier.value { case "abs": - return NumericValue(value: abs(numericValue.value as! Int32)) + if let intValue = numericValue.value as? Int { + return NumericValue(value: abs(intValue)) + } else if let doubleValue = numericValue.value as? Double { + return NumericValue(value: abs(doubleValue)) + } else { + throw JinjaError.runtime("Unsupported numeric type for abs filter") + } default: throw JinjaError.runtime("Unknown NumericValue filter: \(identifier.value)") } } else if let objectValue = operand as? ObjectValue { switch identifier.value { case "items": - var items: [ArrayValue] = [] - for (k, v) in objectValue.value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) + let items: [ArrayValue] = objectValue.value.map { (key, value) in + return ArrayValue(value: [StringValue(value: key), value]) } - return items as! (any RuntimeValue) + return ArrayValue(value: items) case "length": return NumericValue(value: objectValue.value.count) default: @@ -639,9 +1108,132 @@ struct Interpreter { } } - throw JinjaError.runtime("Cannot apply filter \(operand.value) to type: \(type(of:operand))") - } + throw JinjaError.runtime("Cannot apply filter \(identifier.value) to type: \(type(of: operand))") + } else if let callExpression = node.filter as? CallExpression { + if let identifier = callExpression.callee as? Identifier { + let filterName = identifier.value + + if filterName == "tojson" { + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let indent = args.1["indent"] ?? NullValue() + + if let indentNumeric = indent as? NumericValue { + if let indentInt = indentNumeric.value as? Int { + return try StringValue(value: toJSON(operand, indent: indentInt)) + } else if let indentDouble = indentNumeric.value as? Double { + return try StringValue(value: toJSON(operand, indent: Int(indentDouble))) + } else { + throw JinjaError.runtime("If set, indent must be a number") + } + } else if indent is NullValue { + return try StringValue(value: toJSON(operand)) + } else { + throw JinjaError.runtime("If set, indent must be a number") + } + } + if let arrayValue = operand as? ArrayValue { + switch filterName { + case "selectattr", "rejectattr": + let select = filterName == "selectattr" + if arrayValue.value.contains(where: { !($0 is ObjectValue) }) { + throw JinjaError.runtime("`\(filterName)` can only be applied to array of objects") + } + if callExpression.args.contains(where: { !($0 is StringLiteral) }) { + throw JinjaError.runtime("arguments of `\(filterName)` must be strings") + } + let args = try callExpression.args.map { arg -> StringValue in + let evaluatedArg = try self.evaluate(statement: arg, environment: environment) + guard let stringValue = evaluatedArg as? StringValue else { + throw JinjaError.runtime("Arguments of `\(filterName)` must be strings") + } + return stringValue + } + let attr = args[0] + let testName = args.count > 1 ? args[1] : nil + let value = args.count > 2 ? args[2] : nil + var testFunction: ((any RuntimeValue, StringValue?) throws -> Bool) + if let testName = testName { + guard let test = environment.tests[testName.value] else { + throw JinjaError.runtime("Unknown test: \(testName.value)") + } + testFunction = { a, b in + try test(a, b ?? UndefinedValue()) + } + } else { + testFunction = { a, _ in + a.bool() + } + } + let filtered = (arrayValue.value as! [ObjectValue]).filter { item in + let a = item.value[attr.value] + let result = a != nil ? try! testFunction(a!, value) : false + return select ? result : !result + } + return ArrayValue(value: filtered) + case "map": + let evaluatedArgs = try self.evaluateArguments( + args: callExpression.args, + environment: environment + ) + let kwargs = evaluatedArgs.1 + if let attribute = kwargs["attribute"] { + let defaultValue = kwargs["default"] + let mapped = try arrayValue.value.map { item -> Any in + guard let objectValue = item as? ObjectValue else { + throw JinjaError.runtime("Items in map must be objects") + } + if let attributeString = attribute as? StringValue { + let result = + objectValue.value[attributeString.value] ?? defaultValue ?? UndefinedValue() + return result + } else { + throw JinjaError.runtime("`map` filter attribute must be a string") + } + } + return ArrayValue(value: mapped.map { $0 as! (any RuntimeValue) }) + } else { + // TODO: Implement map filter without attribute argument + // This will likely involve applying a filter function to each element. + throw JinjaError.runtime("`map` filter without `attribute` is not yet supported.") + } + default: + throw JinjaError.runtime("Unknown ArrayValue filter: \(filterName)") + } + } else if let stringValue = operand as? StringValue { + switch filterName { + case "indent": + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let positionalArgs = args.0 + let kwargs = args.1 + let width = positionalArgs.first ?? kwargs["width"] ?? NumericValue(value: 4) + if !(width is NumericValue) { + throw JinjaError.runtime("width must be a number") + } + let first = + positionalArgs.count > 1 ? positionalArgs[1] : kwargs["first"] ?? BooleanValue(value: false) + let blank = + positionalArgs.count > 2 ? positionalArgs[2] : kwargs["blank"] ?? BooleanValue(value: false) + guard let widthInt = (width as? NumericValue)?.value as? Int else { + throw JinjaError.runtime("width must be an integer") + } + return StringValue( + value: stringValue.value.indent( + widthInt, + first: first.bool(), + blank: blank.bool() + ) + ) + default: + throw JinjaError.runtime("Unknown StringValue filter: \(filterName)") + } + } else { + throw JinjaError.runtime("Cannot apply filter '\(filterName)' to type: \(type(of: operand))") + } + } else { + throw JinjaError.runtime("Unknown filter: \(callExpression.callee)") + } + } throw JinjaError.runtime("Unknown filter: \(node.filter)") } @@ -656,6 +1248,76 @@ struct Interpreter { } } + func evaluateMacro(node: Macro, environment: Environment) throws -> NullValue { + try environment.setVariable( + name: node.name.value, + value: FunctionValue(value: { args, scope in + let macroScope = Environment(parent: scope) + + var args = args + var kwargs: [String: any RuntimeValue] = [:] + + if let lastArg = args.last, let keywordArgsValue = lastArg as? KeywordArgumentsValue { + kwargs = keywordArgsValue.value + args.removeLast() + } + + for i in 0 ..< node.args.count { + let nodeArg = node.args[i] + let passedArg = args.count > i ? args[i] : nil + + if let identifier = nodeArg as? Identifier { + if passedArg == nil { + if let defaultValue = kwargs[identifier.value] { + try macroScope.setVariable(name: identifier.value, value: defaultValue) + } else { + throw JinjaError.runtime("Missing argument: \(identifier.value)") + } + } else { + try macroScope.setVariable(name: identifier.value, value: passedArg!) + } + } else if let kwarg = nodeArg as? KeywordArgumentExpression { + let value = + try kwargs[kwarg.key.value] + ?? (passedArg ?? (try self.evaluate(statement: kwarg.value, environment: macroScope))) + + try macroScope.setVariable(name: kwarg.key.value, value: value) + } else { + throw JinjaError.runtime("Unknown argument type: \(type(of: nodeArg))") + } + } + + return try self.evaluateBlock(statements: node.body, environment: macroScope) + }) + ) + + return NullValue() + } + + func evaluateArguments( + args: [Expression], + environment: Environment + ) throws -> ([any RuntimeValue], [String: any RuntimeValue]) { + var positionalArguments: [any RuntimeValue] = [] + var keywordArguments: [String: any RuntimeValue] = [:] + + for argument in args { + if let keywordArgument = argument as? KeywordArgumentExpression { + keywordArguments[keywordArgument.key.value] = try self.evaluate( + statement: keywordArgument.value, + environment: environment + ) + } else { + if !keywordArguments.isEmpty { + throw JinjaError.runtime("Positional arguments must come before keyword arguments") + } + positionalArguments.append(try self.evaluate(statement: argument, environment: environment)) + } + } + + return (positionalArguments, keywordArguments) + } + func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue { if let statement { switch statement { @@ -678,7 +1340,13 @@ struct Interpreter { case let statement as UnaryExpression: return try self.evaluateUnaryExpression(node: statement, environment: environment) case let statement as NumericLiteral: - return NumericValue(value: statement.value) + if let intValue = statement.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = statement.value as? Double { + return NumericValue(value: doubleValue) + } else { + throw JinjaError.runtime("Invalid numeric literal value") + } case let statement as CallExpression: return try self.evaluateCallExpression(expr: statement, environment: environment) case let statement as BoolLiteral: @@ -687,6 +1355,22 @@ struct Interpreter { return try self.evaluateFilterExpression(node: statement, environment: environment) case let statement as TestExpression: return try self.evaluateTestExpression(node: statement, environment: environment) + case let statement as ArrayLiteral: + return ArrayValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as TupleLiteral: + return TupleValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as ObjectLiteral: + var mapping: [String: any RuntimeValue] = [:] + for (key, value) in statement.value { + mapping[key] = try self.evaluate(statement: value, environment: environment) + } + return ObjectValue(value: mapping) + case let statement as Macro: + return try self.evaluateMacro(node: statement, environment: environment) case is NullLiteral: return NullValue() default: diff --git a/Sources/Utilities.swift b/Sources/Utilities.swift index c01870b..4048379 100644 --- a/Sources/Utilities.swift +++ b/Sources/Utilities.swift @@ -38,3 +38,99 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) return slicedArray } + +func toJSON(_ input: any RuntimeValue, indent: Int? = 4, depth: Int = 0) throws -> String { + let currentDepth = depth + + switch input { + case is NullValue, is UndefinedValue: + return "null" + + case let value as NumericValue: + return String(describing: value.value) + + case let value as StringValue: + // Properly escape special characters for JSON strings + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + + case let value as BooleanValue: + return value.value ? "true" : "false" + + case let arr as ArrayValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + let core = try arr.value.map { try toJSON($0, indent: indent, depth: currentDepth + 1) } + + if indent != nil { + return "[\(childrenPadding)\(core.joined(separator: ",\(childrenPadding)"))\(basePadding)]" + } else { + return "[\(core.joined(separator: ", "))]" + } + + case let obj as ObjectValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + // Use orderedKeys to maintain insertion order + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON(value, indent: indent, depth: currentDepth + 1) + return "\"\(key)\": \(jsonValue)" + } + + if indent != nil { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } +} + +// Helper function to convert values to JSON strings +func jsonString(_ value: Any) throws -> String { + let data = try JSONSerialization.data(withJSONObject: value) + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert value to JSON string") + } + return string +} + +extension String { + func titleCase() -> String { + return self.components(separatedBy: .whitespacesAndNewlines) + .map { word in + guard let firstChar = word.first else { return "" } + return String(firstChar).uppercased() + word.dropFirst() + } + .joined(separator: " ") + } + + func indent(_ width: Int, first: Bool = false, blank: Bool = false) -> String { + let indentString = String(repeating: " ", count: width) + return self.components(separatedBy: .newlines) + .enumerated() + .map { (index, line) in + if line.isEmpty && !blank { + return line + } + if index == 0 && !first { + return line + } + return indentString + line + } + .joined(separator: "\n") + } +} diff --git a/Tests/ChatTemplateTests.swift b/Tests/ChatTemplateTests.swift deleted file mode 100644 index 4b9ab6b..0000000 --- a/Tests/ChatTemplateTests.swift +++ /dev/null @@ -1,238 +0,0 @@ -// -// ChatTemplateTests.swift -// -// -// Created by John Mai on 2024/3/24. -// - -import XCTest - -@testable import Jinja - -let messages: [[String: String]] = [ - [ - "role": "user", - "content": "Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], -] - -let messagesWithSystem: [[String: String]] = - [ - [ - "role": "system", - "content": "You are a friendly chatbot who always responds in the style of a pirate", - ] - ] + messages - -final class ChatTemplateTests: XCTestCase { - struct Test { - let chatTemplate: String - let data: [String: Any] - let target: String - } - - let defaultTemplates: [Test] = [ - Test( - chatTemplate: - "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": false, - ], - target: - "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" - ), - // facebook/blenderbot-400M-distill - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" - ), - // facebook/blenderbot_small-90M - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" - ), - // bigscience/bloom - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" - ), - // EleutherAI/gpt-neox-20b - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // gpt2 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messagesWithSystem, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messages, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": [ - [ - "role": "user", - "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], - ], - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // openai/whisper-large-v3 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" - ), - // THUDM/chatglm3-6b - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" - ), - // google/gemma-2b-it - Test( - chatTemplate: - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - data: [ - "messages": messages - ], - target: - "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" - ), - // Qwen/Qwen2.5-0.5B-Instruct - Test( - chatTemplate: - "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - data: [ - "messages": messages - ], - target: - "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" - ), - ] - - func testDefaultTemplates() throws { - for test in defaultTemplates { - let template = try Template(test.chatTemplate) - let result = try template.render(test.data) - XCTAssertEqual(result.debugDescription, test.target.debugDescription) - } - } -} diff --git a/Tests/InterpreterTests.swift b/Tests/InterpreterTests.swift index d402f84..631d2e6 100644 --- a/Tests/InterpreterTests.swift +++ b/Tests/InterpreterTests.swift @@ -141,17 +141,18 @@ final class InterpreterTests: XCTestCase { for test in tests { let env = Environment() try env.set(name: "True", value: true) - for (key, value) in test.data { try env.set(name: key, value: value) } - let tokens = try tokenize(test.template, options: test.options) let parsed = try parse(tokens: tokens) let interpreter = Interpreter(env: env) - let result = try interpreter.run(program: parsed).value as! String - - XCTAssertEqual(result.debugDescription, test.target.debugDescription) + let result = try interpreter.run(program: parsed) + if let stringResult = result as? StringValue { + XCTAssertEqual(stringResult.value.debugDescription, test.target.debugDescription) + } else { + XCTFail("Expected a StringValue, but got \(type(of: result))") + } } } } diff --git a/Tests/Templates/ChatTemplateTests.swift b/Tests/Templates/ChatTemplateTests.swift new file mode 100644 index 0000000..57c3fcc --- /dev/null +++ b/Tests/Templates/ChatTemplateTests.swift @@ -0,0 +1,901 @@ +// +// ChatTemplateTests.swift +// +// +// Created by John Mai on 2024/3/24. +// + +import XCTest + +@testable import Jinja + +final class ChatTemplateTests: XCTestCase { + let messages: [[String: String]] = [ + [ + "role": "user", + "content": "Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ] + + lazy var messagesWithSystemPrompt: [[String: String]] = + [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ] + ] + messages + + func testGenericChatTemplate() throws { + let chatTemplate = + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": false, + ]) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testGenericChatTemplate failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbot400MDistill() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" + + if target != result { + print("::: testFacebookBlenderbot400MDistill failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbotSmall90M() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" + + if target != result { + print("::: testFacebookBlenderbotSmall90M failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testBigscienceBloom() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" + + if target != result { + print("::: testBigscienceBloom failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testEleutherAIGptNeox20b() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testEleutherAIGptNeox20b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testGPT2() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testGPT2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer2() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer3() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ], + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testOpenaiWhisperLargeV3() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testOpenaiWhisperLargeV3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat1() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + + if target != result { + print("::: testQwenQwen1_5_1_8BChat1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat2() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + + if target != result { + print("::: testQwenQwen1_5_1_8BChat2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat3() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" + + if target != result { + print("::: testQwenQwen1_5_1_8BChat3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTHUDMChatglm36b() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" + + if target != result { + print("::: testTHUDMChatglm36b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testGoogleGemma2bIt() throws { + let chatTemplate = + "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" + + if target != result { + print("::: testGoogleGemma2bIt failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen2_5_0_5BInstruct() throws { + let chatTemplate = + "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testQwenQwen2_5_0_5BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptFalse() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt, "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bBetaAddGenerationPromptFalse failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptTrue() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ], + ["role": "user", "content": "How many helicopters can a human eat in one sitting?"], + ], "eos_token": "", "add_generation_prompt": true, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHow many helicopters can a human eat in one sitting?\n<|assistant|>\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bBetaAddGenerationPromptTrue failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bGemmaV0_1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bGemmaV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTheBlokeMistral7BInstructV0_1GPTQ() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testTheBlokeMistral7BInstructV0_1GPTQ failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistralaiMixtral8x7BInstructV0_1() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testMistralaiMixtral8x7BInstructV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testCognitivecomputationsDolphin2_5Mixtral8x7b() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testCognitivecomputationsDolphin2_5Mixtral8x7b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testOpenchatOpenchat3_5_0106() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>" + + if target != result { + print("::: testOpenchatOpenchat3_5_0106 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testUpstageSOLAR10_7BInstructV1_0() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n\n### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!\n\n" + + if target != result { + print("::: testUpstageSOLAR10_7BInstructV1_0 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testCodellamaCodeLlama70bInstructHf() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"; + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n " + + if target != result { + print("::: testCodellamaCodeLlama70bInstructHf failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeciDeciLM7BInstruct() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n### Assistant:\nI'm doing great. How can I help you today?\n### User:\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testDeciDeciLM7BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_72BChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testQwenQwen1_5_72BChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekLlm7bChat() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", + "eos_token": "<|end of sentence|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end of sentence|>User: I'd like to show off how chat templating works!\n\n" + + if target != result { + print("::: testDeepseekAiDeepseekLlm7bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testH2oaiH2oDanube1_8bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "", + ] as [String: Any] + ) + let target = + "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + + if target != result { + print("::: testH2oaiH2oDanube1_8bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testInternlmInternlm2Chat7b() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testInternlmInternlm2Chat7b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTheBlokedeepseekCoder33BInstructAWQ() throws { + let chatTemplate = + "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n" + + if target != result { + print("::: testTheBlokedeepseekCoder33BInstructAWQ failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testEriczzzFalconRw1bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'].strip() }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "<|endoftext|>", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!" + + if target != result { + print("::: testEriczzzFalconRw1bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testAbacusaiSmaug34BV0_1() throws { + let chatTemplate = + "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testAbacusaiSmaug34BV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B() throws { + let chatTemplate = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + + if target != result { + print("::: testMaywellSynatraMixtral8x7B failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekCoder33bInstruct() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", "eos_token": "<|EOT|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testDeepseekAiDeepseekCoder33bInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B_2() throws { + let chatTemplate = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\nYou are a friendly chatbot who always responds in the style of a pirate### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + + if target != result { + print("::: testMaywellSynatraMixtral8x7B_2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMaywellPiVoTMoE() throws { + let chatTemplate = + "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt + ] as [String: Any] + ) + // Note: The duplication of the system prompt is a known bug and is replicated here in the target. + let target = + "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!" + + if target != result { + print("::: testMaywellPiVoTMoE failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistralNemoInstruct2407() throws { + let chatTemplate = + "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + ]) + let target = + "[INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?[INST]I'd like to show off how chat templating works![/INST]" + + XCTAssertEqual(result, target) + } + + func testQwen2VLTextOnly() throws { + let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + Hello, how are you?<|im_end|> + <|im_start|>assistant + I'm doing great. How can I help you today?<|im_end|> + <|im_start|>user + I'd like to show off how chat templating works!<|im_end|> + <|im_start|>assistant + + """ + + if target != result { + print("::: testQwen2VLTextOnly failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } +} diff --git a/Tests/Templates/ToolUseTests.swift b/Tests/Templates/ToolUseTests.swift new file mode 100644 index 0000000..4155073 --- /dev/null +++ b/Tests/Templates/ToolUseTests.swift @@ -0,0 +1,592 @@ +// +// VisionTests.swift +// Jinja +// +// Created by Anthony DePasquale on 30.12.2024. +// + +import XCTest +import OrderedCollections + +@testable import Jinja + +final class ToolUseTests: XCTestCase { + let messagesWithFunctionCalling: [[String: Any?]] = [ + [ + "role": "assistant", + "content": nil, + "tool_calls": [ + [ + "type": "function", + "function": [ + "name": "get_current_weather", + "arguments": "{\n \"location\": \"Hanoi\"\n}", + ], + ] + ], + ], + [ + "role": "user", + "content": "What's the weather like in Hanoi?", + ], + ] + + // Example adapted from https://huggingface.co/fireworks-ai/firefunction-v1 + let exampleFunctionSpec: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_price") as (String, Any), + ("description", "Get the current stock price") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol, e.g. AAPL, GOOG") as (String, Any), + ]) + ) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "check_word_anagram") as (String, Any), + ("description", "Check if two words are anagrams of each other") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "word1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The first word") as (String, Any), + ]) + ) as (String, Any), + ( + "word2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The second word") as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["word1", "word2"]) as (String, Any), + ]) + ) as (String, Any), + ]), + ] + + lazy var messagesWithFunctionCallingAndSystemPrompt: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "system") as (String, Any), + ("content", "You are a helpful assistant with access to functions. Use them if required.") as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "functions") as (String, Any), + ("content", exampleFunctionSpec) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Hi, can you tell me the current stock price of AAPL?") as (String, Any), + ]), + ] + + let exampleToolJSONSchemas: OrderedDictionary> = OrderedDictionary( + uniqueKeysWithValues: [ + ( + "get_current_weather", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather") as (String, Any), + ("description", "Get the current weather in a given location") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The city and state, e.g. San Francisco, CA") + as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ("description", "The unit to return the temperature in.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location", "unit"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_wind_speed", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_wind_speed") as (String, Any), + ("description", "Get the current wind speed in km/h at a given location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ("description", "The current wind speed at the given location in km/h, as a float.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ]) + + lazy var exampleListOfTools: [OrderedDictionary] = [ + exampleToolJSONSchemas["get_current_temperature_v2"]!, + exampleToolJSONSchemas["get_current_wind_speed"]!, + ] + + // Passes + func testMeetKaiFunctionaryMediumV2_2() throws { + let chatTemplate = """ + {#v2.2#}\n{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif %}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == "no" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCalling, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = + """ + <|from|>assistant\n<|recipient|>get_current_weather\n<|content|>{\n "location": "Hanoi"\n}<|stop|>\n<|from|>user\n<|recipient|>all\n<|content|>What's the weather like in Hanoi?\n + """ + + if target != result { + print("::: testMeetKaiFunctionaryMediumV2_2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + // Passes + func testFireworksAIFireFunctionV1() throws { + let chatTemplate = """ + {%- set message_roles = ['SYSTEM', 'FUNCTIONS', 'USER', 'ASSISTANT', 'TOOL'] -%}\n{%- set ns = namespace(seen_non_system=false, messages=messages, content='', functions=[]) -%}\n{{ bos_token }}\n{#- Basic consistency checks -#}\n{%- if not ns.messages -%}\n {{ raise_exception('No messages') }}\n{%- endif -%}\n{%- if ns.messages[0]['role'] | upper != 'SYSTEM' -%}\n {%- set ns.messages = [{'role': 'SYSTEM', 'content': 'You are a helpful assistant with access to functions. Use them if required.'}] + ns.messages -%}\n{%- endif -%}\n{%- if ns.messages | length < 2 or ns.messages[0]['role'] | upper != 'SYSTEM' or ns.messages[1]['role'] | upper != 'FUNCTIONS' -%}\n {{ raise_exception('Expected either "functions" or ["system", "functions"] as the first messages') }}\n{%- endif -%}\n{%- for message in ns.messages -%}\n {%- set role = message['role'] | upper -%}\n {#- Validation -#}\n {%- if role not in message_roles -%}\n {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles + ' are supported.') }}\n {%- endif -%}\n {%- set ns.content = message['content'] if message.get('content') else '' -%}\n {#- Move tool calls inside the content -#}\n {%- if 'tool_calls' in message -%}\n {%- for call in message['tool_calls'] -%}\n {%- set ns.content = ns.content + '{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}' -%}\n {%- endfor -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' and '' not in ns.content -%}\n {%- set ns.content = '' + ns.content -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' -%}\n {%- set ns.content = ns.content + eos_token -%}\n {%- endif -%}\n {{ role }}: {{ ns.content }}{{ '\\n\\n' }}\n{%- endfor -%}\nASSISTANT:{{ ' ' }}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCallingAndSystemPrompt, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = """ + SYSTEM: You are a helpful assistant with access to functions. Use them if required.\n\nFUNCTIONS: [\n {\n "name": "get_stock_price",\n "description": "Get the current stock price",\n "parameters": {\n "type": "object",\n "properties": {\n "symbol": {\n "type": "string",\n "description": "The stock symbol, e.g. AAPL, GOOG"\n }\n },\n "required": [\n "symbol"\n ]\n }\n },\n {\n "name": "check_word_anagram",\n "description": "Check if two words are anagrams of each other",\n "parameters": {\n "type": "object",\n "properties": {\n "word1": {\n "type": "string",\n "description": "The first word"\n },\n "word2": {\n "type": "string",\n "description": "The second word"\n }\n },\n "required": [\n "word1",\n "word2"\n ]\n }\n }\n]\n\nUSER: Hi, can you tell me the current stock price of AAPL?\n\nASSISTANT: + """ + + if target != result { + print("::: testFireworksAIFireFunctionV1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + // Fails because tools are omitted in the output, and the result is indented. + // func testMistral7BInstructV0_3JSONSchema() throws { + // let chatTemplate = + // "{{- bos_token }}\n{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {%- if tools and (message == user_messages[-1]) %}\n {{- ' [AVAILABLE_TOOLS] [' }}\n {%- for tool in tools %}\n\t\t{%- set tool = tool.function %}\n\t\t{{- '{\"type\": \"function\", \"function\": {' }}\n\t\t{%- for key, val in tool|items if key != \"return\" %}\n\t\t {%- if val is string %}\n\t\t\t{{- '\"' + key + '\": \"' + val + '\"' }}\n\t\t {%- else %}\n\t\t\t{{- '\"' + key + '\": ' + val|tojson }}\n\t\t {%- endif %}\n\t\t {%- if not loop.last %}\n\t\t\t{{- \", \" }}\n\t\t {%- endif %}\n\t\t{%- endfor %}\n\t\t{{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- ' [/AVAILABLE_TOOLS]' }}\n {%- endif %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message.tool_calls is defined and message.tool_calls|length > 0 %}\n {{- ' [TOOL_CALLS] [' }}\n {%- for tool_call in message.tool_calls %}\n {{- {\"name\": tool_call.function.name, \"arguments\": tool_call.function.arguments, \"id\": tool_call.id}|tojson }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- '] ' }}\n {{- eos_token }}\n \t{%- elif message.content is defined %}\n\t {{- ' ' + message.content + ' ' + eos_token}}\n {%- endif %}\n {%- elif message['role'] == 'tool' %}\n {{- ' [TOOL_RESULTS] ' }}\n {{- '{\"call_id\": \"' + message.tool_call_id + '\", \"content\": ' + message.content|string + '}' }}\n {{- ' [/TOOL_RESULTS] ' }}\n {%- endif %}\n{%- endfor %}\n" + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "system", + // "content": + // "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.", + // ], + // ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + // [ + // "role": "assistant", + // "tool_calls": [ + // [ + // "id": "abcdef123", + // "type": "function", + // "function": [ + // "name": "get_current_temperature", + // "arguments": ["location": "Paris, France", "unit": "celsius"], + // ], + // ] + // ], + // ], + // ["role": "tool", "tool_call_id": "abcdef123", "name": "get_current_temperature", "content": "22.0"], + // ], + // "tools": exampleListOfTools, + // // "tools_json": "", // TODO: Figure out how to convert the array of OrderedDictionaries to JSON + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = """ + // [AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_current_temperature", "description": "Get the current temperature at a location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit to return the temperature in."}}, "required": ["location", "unit"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}}, "required": ["location"]}}}] [/AVAILABLE_TOOLS] [INST] Hey, what\'s the temperature in Paris right now? [/INST] [TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "abcdef123"}] [TOOL_RESULTS] {"call_id": "abcdef123", "content": 22.0} [/TOOL_RESULTS] + // """ + // + // if target != result { + // print("::: testMistral7BInstructV0_3JSONSchema failed.") + // print("::: target:") + // print(target) + // print("::: result:") + // print(result) + // } + // XCTAssertEqual(result, target) + // } + + // Fails because tools are omitted in the output + // func testCISCaiMistral7BInstructV0_3SOTAGGUF() throws { + // let chatTemplate = """ + // {{ bos_token }}{% set ns = namespace(lastuser=-1, system=false, functions=false) %}{% if tools %}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.lastuser = loop.index0 %}{% elif message['role'] == 'system' %}{% set ns.system = message['content'] %}{% endif %}{% endfor %}{% set ns.functions = tools|selectattr('type','eq','function')|map(attribute='function')|list|tojson %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% if loop.index0 == ns.lastuser and ns.functions %}{{ '[AVAILABLE_TOOLS] ' }}{{ ns.functions }}{{ '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST] ' }}{% if loop.index0 == ns.lastuser and ns.system %}{{ ns.system + ' ' }}{% endif %}{{ message['content'] }}{{ '[/INST]' }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS] ' }}{{ dict(call_id=message['tool_call_id'], content=message['content'])|tojson }}{{ '[/TOOL_RESULTS]' }}{% elif message['role'] == 'assistant' %}{% if message['tool_calls'] %}{{ '[TOOL_CALLS] [' }}{% for call in message['tool_calls'] %}{% if call['type'] == 'function' %}{{ dict(id=call['id'], name=call['function']['name'], arguments=call['function']['arguments'])|tojson }}{% endif %}{% if not loop.last %}{{ ', ' }}{% endif %}{% endfor %}{{ ']' }}{% else %}{{ message['content'] }}{% endif %}{{ eos_token }}{% endif %}{% endfor %} + // """ + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "user", + // "content": "What's the weather like in Oslo and Stockholm?", + // ] + // ], + // "tools": [exampleToolJSONSchemas["get_current_temperature_v2"]!], + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = + // """ + // [AVAILABLE_TOOLS] [{"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}][/AVAILABLE_TOOLS][INST] What's the weather like in Oslo and Stockholm?[/INST] + // """ + // + // if target != result { + // print("::: testCISCaiMistral7BInstructV0_3SOTAGGUF failed.") + // print("::: target:") + // print(target) + // print("::: result:") + // print(result) + // } + // XCTAssertEqual(result, target) + // } + + // Passes + func testNousResearchHermes2ProLlama38BJSONSchema() throws { + let chatTemplate = """ + {%- macro json_to_python_type(json_spec) %}\n{%- set basic_type_map = {\n "string": "str",\n "number": "float",\n "integer": "int",\n "boolean": "bool"\n} %}\n\n{%- if basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == "array" %}\n {{- "list[" + json_to_python_type(json_spec|items) + "]"}}\n{%- elif json_spec.type == "object" %}\n {%- if json_spec.additionalProperties is defined %}\n {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}}\n {%- else %}\n {{- "dict" }}\n {%- endif %}\n{%- elif json_spec.type is iterable %}\n {{- "Union[" }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type({"type": t}) }}\n {%- if not loop.last %}\n {{- "," }} \n {%- endif %}\n {%- endfor %}\n {{- "]" }}\n{%- else %}\n {{- "Any" }}\n{%- endif %}\n{%- endmacro %}\n\n\n{{- bos_token }}\n{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }}\n{%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- '{"type": "function", "function": ' }}\n {{- '{"name": ' + tool.name + '", ' }}\n {{- '"description": "' + tool.name + '(' }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- param_name + ": " + json_to_python_type(param_fields) }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- if tool.return is defined %}\n {{- " -> " + json_to_python_type(tool.return) }}\n {%- endif %}\n {{- " - " + tool.description + "\\n\\n" }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {%- if loop.first %}\n {{- " Args:\\n" }}\n {%- endif %}\n {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}\n {%- endfor %}\n {%- if tool.return is defined and tool.return.description is defined %}\n {{- "\\n Returns:\\n " + tool.return.description }}\n {%- endif %}\n {{- '"' }}\n {{- ', "parameters": ' }}\n {%- if tool.parameters.properties | length == 0 %}\n {{- "{}" }}\n {%- else %}\n {{- tool.parameters | tojson}}\n {%- endif %}\n {{- "}" }}\n {%- if not loop.last %}\n {{- "\\n" }}\n {%- endif %}\n{%- endfor %}\n{{- " " }}\n{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\n' }}\n{{- "For each function call return a json object with function name and arguments within XML tags as follows:\n" }}\n{{- "\n" }}\n{{- '{"arguments": , "name": }\n' }}\n{{- '<|im_end|>' }}\n{%- for message in messages %}\n {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == "assistant" %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '{ ' }}\n {%- if tool_call.arguments is defined %}\n {{- '"arguments": ' }}\n {{- tool_call.arguments|tojson }}\n {{- ', '}}\n {%- endif %}\n {{- '"name": "' }}\n {{- tool_call.name }}\n {{- '"}' }}\n {{- '\\n ' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == "tool" %}\n {%- if not message.name is defined %}\n {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }}\n {%- endif %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {{- '{"name": "' }}\n {{- message.name }}\n {{- '", "content": ' }}\n {{- message.content|tojson + '}' }}\n {{- '\\n <|im_end|>\\n' }} \n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Fetch the stock fundamentals data for Tesla (TSLA)") as (String, Any), + ]) + ], + "tools": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_fundamentals") as (String, Any), + ("description", "Get fundamental data for a given stock symbol using yfinance API.") + as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol.") as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "description", + """ + A dictionary containing fundamental data. + + Keys: + - 'symbol': The stock symbol. + - 'company_name': The long name of the company. + - 'sector': The sector to which the company belongs. + - 'industry': The industry to which the company belongs. + - 'market_cap': The market capitalization of the company. + - 'pe_ratio': The forward price-to-earnings ratio. + - 'pb_ratio': The price-to-book ratio. + - 'dividend_yield': The dividend yield. + - 'eps': The trailing earnings per share. + - 'beta': The beta value of the stock. + - '52_week_high': The 52-week high price of the stock. + - '52_week_low': The 52-week low price of the stock. + """ + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|>You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API. + + Args: + symbol(str): The stock symbol. + Returns: + A dictionary containing fundamental data. + + Keys: + - 'symbol': The stock symbol. + - 'company_name': The long name of the company. + - 'sector': The sector to which the company belongs. + - 'industry': The industry to which the company belongs. + - 'market_cap': The market capitalization of the company. + - 'pe_ratio': The forward price-to-earnings ratio. + - 'pb_ratio': The price-to-book ratio. + - 'dividend_yield': The dividend yield. + - 'eps': The trailing earnings per share. + - 'beta': The beta value of the stock. + - '52_week_high': The 52-week high price of the stock. + - '52_week_low': The 52-week low price of the stock.", "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock symbol." + } + }, + "required": [ + "symbol" + ] + }} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} + For each function call return a json object with function name and arguments within XML tags as follows: + + {"arguments": , "name": } + <|im_end|><|im_start|>user + Fetch the stock fundamentals data for Tesla (TSLA)<|im_end|> + <|im_start|>assistant + + """ + + if target != result { + print("::: testNousResearchHermes2ProLlama38BJSONSchema failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + // Passes + func testMetaLlamaLlama3_18BInstruct() throws { + let chatTemplate = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = "26 Jul 2024" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\\n\\n"}}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- "<|python_tag|>" + tool_call.name + ".call(" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '="' + arg_val + '"' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{"name": "' + tool_call.name + '", ' }}\n {{- '"parameters": ' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- "<|eom_id|>" }}\n {%- else %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + ["role": "system", "content": "You are a bot that responds to weather queries."], + ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + ], + "tools": [exampleToolJSONSchemas["get_current_temperature_v1"]!], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.\n\n{\n "type": "function",\n "function": {\n "name": "get_current_temperature",\n "description": "Get the current temperature at a location.",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The location to get the temperature for, in the format \\"City, Country\\""\n }\n },\n "required": [\n "location"\n ]\n },\n "return": {\n "type": "number",\n "description": "The current temperature at the specified location in the specified units, as a float."\n }\n }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n + """ + + if target != result { + print("::: testMetaLlamaLlama3_18BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } +} + +extension Data { + var string: String? { + return String(data: self, encoding: .utf8) + } +} diff --git a/Tests/Templates/VisionTests.swift b/Tests/Templates/VisionTests.swift new file mode 100644 index 0000000..df2f778 --- /dev/null +++ b/Tests/Templates/VisionTests.swift @@ -0,0 +1,297 @@ +// +// VisionTests.swift +// Jinja +// +// Created by Anthony DePasquale on 31.12.2024. +// + +import XCTest +import OrderedCollections + +@testable import Jinja + +final class VisionTests: XCTestCase { + let llama3_2visionChatTemplate = + "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + + func testLlama3_2_11BVisionInstructTextChatOnly() throws { + let template = try Template(llama3_2visionChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "Hello, how are you?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "assistant", + "content": [ + [ + "type": "text", + "text": "I'm doing great. How can I help you today?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "I'd like to show off how chat templating works!", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "date_string": "26 Jul 2024" as Any, + "tools_in_user_message": true as Any, + "system_message": "You are a helpful assistant." as Any, + "add_generation_prompt": true as Any, + ]) + let target = + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + + if target != result { + print("::: testLlama3_2_11BVisionInstructTextChatOnly failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testLlama3_2_11BVisionInstructWithImages() throws { + let template = try Template(llama3_2visionChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: Any], + [ + "type": "image", + "image": "base64_encoded_image_data", + ] as [String: Any], + ] as [[String: Any]], + ] as [String: Any] + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + ]) + let target = + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat's in this image?<|image|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + + if target != result { + print("::: testLlama3_2_11BVisionInstructWithImages failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwen2VLWithImages() throws { + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: String], + [ + "type": "image", + "image_url": "example.jpg", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's in this image?Picture 1: <|vision_start|><|image_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + + if target != result { + print("::: testQwen2VLWithImages failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwen2VLWithVideo() throws { + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's happening in this video?", + ] as [String: String], + [ + "type": "video", + "video_url": "example.mp4", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's happening in this video?Video 1: <|vision_start|><|video_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + + if target != result { + print("::: testQwen2VLWithVideo failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testLlama3_2_11BVisionInstructWithTools() throws { + let template = try Template(llama3_2visionChatTemplate) + + let tools: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function" as Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather" as Any), + ("description", "Get the current weather in a given location" as Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object" as Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string" as Any), + ("description", "The city and state, e.g. San Francisco, CA" as Any), + ]) as Any + ), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string" as Any), + ("enum", ["celsius", "fahrenheit"] as Any), + ]) as Any + ), + ]) as Any + ), + ("required", ["location"] as Any), + ]) as Any + ), + ]) as Any + ), + ]) + ] + + let result = try template.render([ + "messages": [ + [ + "role": "system", + "content": "You are a helpful assistant.", + ], + [ + "role": "user", + "content": "What's the weather like in San Francisco?", + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + "tools": tools as Any, + "tools_in_user_message": true as Any, + ]) + let target = """ + + <|start_header_id|>system<|end_header_id|> + + Environment: ipython + Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> + + Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + + Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables. + + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ] + } + }, + "required": [ + "location" + ] + } + } + } + + What's the weather like in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + + + """ + + if target != result { + print("::: testLlama3_2_11BVisionInstructWithTools failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } +}