commit 4af0f4e2b4150d63ea1abd5ae52ee1c74f819eca
parent 4347b2eaffa12ee1582e181142ec4a8a71db0e5a
Author: AlexKnauth <alexander@knauth.org>
Date: Tue, 10 May 2016 18:58:16 -0400
add find-X-variance, covariant-X?, and covariant-Xs?
and also allow type constructors to declare the variance of their arguments.
infer variances for non-recursive `define-type` types
Diffstat:
4 files changed, 193 insertions(+), 3 deletions(-)
diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt
@@ -23,6 +23,9 @@
(require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list)))
(require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup)))
+(module+ test
+ (require (for-syntax rackunit)))
+
(provide → →/test match2 define-type)
;; ML-like language
@@ -252,6 +255,44 @@
(define (inst-types/cs/orig Xs cs tys [var=? free-identifier=?])
(stx-map (lambda (t) (inst-type/cs/orig Xs cs t var=?)) tys))
+ ;; covariant-Xs? : Type -> Bool
+ ;; Takes a possibly polymorphic type, and returns true if all of the
+ ;; type variables are in covariant positions within the type, false
+ ;; otherwise.
+ (define (covariant-Xs? ty)
+ (syntax-parse ((current-type-eval) ty)
+ [(~?∀ Xs ty)
+ (for/and ([X (in-list (syntax->list #'Xs))])
+ (covariant-X? X #'ty))]))
+
+ ;; find-X-variance : Id Type -> Variance
+ ;; Returns the variance of X within the type ty
+ (define (find-X-variance X ty)
+ (syntax-parse ty
+ [A:id #:when (free-identifier=? #'A X) covariant]
+ [(~Any tycons) irrelevant]
+ [(~?∀ () (~Any tycons τ ...))
+ #:when (get-arg-variances #'tycons)
+ #:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons))
+ (for/fold ([acc irrelevant])
+ ([τ (in-list (syntax->list #'[τ ...]))]
+ [arg-variance (in-list (get-arg-variances #'tycons))])
+ (variance-join
+ acc
+ (variance-compose arg-variance (find-X-variance X τ))))]
+ [ty #:when (not (stx-contains-id? #'ty X)) irrelevant]
+ [_ invariant]))
+
+ ;; covariant-X? : Id Type -> Bool
+ ;; Returns true if every place X appears in ty is a covariant position, false otherwise.
+ (define (covariant-X? X ty)
+ (variance-covariant? (find-X-variance X ty)))
+
+ ;; contravariant-X? : Id Type -> Bool
+ ;; Returns true if every place X appears in ty is a contravariant position, false otherwise.
+ (define (contravariant-X? X ty)
+ (variance-contravariant? (find-X-variance X ty)))
+
;; compute unbound tyvars in one unexpanded type ty
(define (compute-tyvar1 ty)
(syntax-parse ty
@@ -386,12 +427,26 @@
#'(StructName ...) #'((fld ...) ...))
#:with (Cons? ...) (stx-map mk-? #'(StructName ...))
#:with (exposed-Cons? ...) (stx-map mk-? #'(Cons ...))
+ #:do [(define expanded-tys
+ (for/list ([τ (in-list (syntax->list #'[τ ... ...]))])
+ (with-handlers ([exn:fail:syntax? (λ (e) #false)])
+ ((current-type-eval) #`(∀ (X ...) #,τ)))))]
+ #:with [arg-variance ...]
+ (for/list ([i (in-range (length (syntax->list #'[X ...])))])
+ (for/fold ([acc irrelevant])
+ ([ty (in-list expanded-tys)])
+ (cond [ty
+ (define/syntax-parse (~?∀ Xs τ) ty)
+ (define X (list-ref (syntax->list #'Xs) i))
+ (variance-join acc (find-X-variance X #'τ))]
+ [else invariant])))
#`(begin
(define-syntax (NameExtraInfo stx)
(syntax-parse stx
[(_ X ...) #'(('Cons 'StructName Cons? [acc τ] ...) ...)]))
(define-type-constructor Name
#:arity = #,(stx-length #'(X ...))
+ #:arg-variances (λ (stx) (list 'arg-variance ...))
#:extra-info 'NameExtraInfo
#:no-provide)
(struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ...
@@ -1290,3 +1345,30 @@
(cond [(eof-object? x) ""]
[(number? x) (number->string x)]
[(symbol? x) (symbol->string x)])) : String)])
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+(module+ test
+ (begin-for-syntax
+ (check-true (covariant-Xs? #'Int))
+ (check-true (covariant-Xs? #'(stlc+box:Ref Int)))
+ (check-true (covariant-Xs? #'(→ Int Int)))
+ (check-true (covariant-Xs? #'(∀ (X) X)))
+ (check-false (covariant-Xs? #'(∀ (X) (stlc+box:Ref X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ X X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ X Int))))
+ (check-true (covariant-Xs? #'(∀ (X) (→ Int X))))
+ (check-true (covariant-Xs? #'(∀ (X) (→ (→ X Int) X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ (→ (→ X Int) Int) X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ (stlc+box:Ref X) Int))))
+ (check-false (covariant-Xs? #'(∀ (X Y) (→ X Y))))
+ (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) Y))))
+ (check-false (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Y Int)))))
+ (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Int Y)))))
+ (check-false (covariant-Xs? #'(∀ (A B) (→ (→ Int (stlc+rec-iso:× A B))
+ (→ String (stlc+rec-iso:× A B))
+ (stlc+rec-iso:× A B)))))
+ (check-true (covariant-Xs? #'(∀ (A B) (→ (→ (stlc+rec-iso:× A B) Int)
+ (→ (stlc+rec-iso:× A B) String)
+ (stlc+rec-iso:× A B)))))
+ ))
diff --git a/tapl/stlc+tup.rkt b/tapl/stlc+tup.rkt
@@ -1,5 +1,7 @@
#lang s-exp "typecheck.rkt"
(extends "ext-stlc.rkt")
+
+(require (for-syntax racket/list))
;; Simply-Typed Lambda Calculus, plus tuples
;; Types:
@@ -9,7 +11,9 @@
;; - terms from ext-stlc.rkt
;; - tup and proj
-(define-type-constructor × #:arity >= 0)
+(define-type-constructor × #:arity >= 0
+ #:arg-variances (λ (stx)
+ (make-list (stx-length (stx-cdr stx)) covariant)))
(define-typed-syntax tup
[(_ e ...)
diff --git a/tapl/stlc.rkt b/tapl/stlc.rkt
@@ -2,6 +2,8 @@
(provide (for-syntax current-type=? types=?))
(provide (for-syntax mk-app-err-msg))
+(require (for-syntax racket/list))
+
;; Simply-Typed Lambda Calculus
;; - no base types; can't write any terms
;; Types: multi-arg → (1+)
@@ -66,7 +68,13 @@
(define-syntax-category type)
-(define-type-constructor → #:arity >= 1)
+(define-type-constructor → #:arity >= 1
+ #:arg-variances (λ (stx)
+ (syntax-parse stx
+ [(_ τ_in ... τ_out)
+ (append
+ (make-list (stx-length #'[τ_in ...]) contravariant)
+ (list covariant))])))
(define-typed-syntax λ
[(_ bvs:type-ctx e)
diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt
@@ -15,6 +15,9 @@
"stx-utils.rkt"))
(for-meta 2 (all-from-out racket/base syntax/parse racket/syntax)))
+(module+ test
+ (require (for-syntax rackunit)))
+
;; type checking functions/forms
;; General type checking strategy:
@@ -420,6 +423,48 @@
(define (brack? stx)
(define paren-shape/#f (syntax-property stx 'paren-shape))
(and paren-shape/#f (char=? paren-shape/#f #\[)))
+
+ (define (iff b1 b2)
+ (boolean=? b1 b2))
+
+ ;; Variance is (variance Boolean Boolean)
+ (struct variance (covariant? contravariant?) #:prefab)
+ (define irrelevant (variance #true #true))
+ (define covariant (variance #true #false))
+ (define contravariant (variance #false #true))
+ (define invariant (variance #false #false))
+ ;; variance-irrelevant? : Variance -> Boolean
+ (define (variance-irrelevant? v)
+ (and (variance-covariant? v) (variance-contravariant? v)))
+ ;; variance-invariant? : Variance -> Boolean
+ (define (variance-invariant? v)
+ (and (not (variance-covariant? v)) (not (variance-contravariant? v))))
+ ;; variance-join : Variance Variance -> Variance
+ (define (variance-join v1 v2)
+ (variance (and (variance-covariant? v1)
+ (variance-covariant? v2))
+ (and (variance-contravariant? v1)
+ (variance-contravariant? v2))))
+ ;; variance-compose : Variance Variance -> Variance
+ (define (variance-compose v1 v2)
+ (variance (or (variance-irrelevant? v1)
+ (variance-irrelevant? v2)
+ (and (variance-covariant? v1) (variance-covariant? v2))
+ (and (variance-contravariant? v1) (variance-contravariant? v2)))
+ (or (variance-irrelevant? v1)
+ (variance-irrelevant? v2)
+ (and (variance-covariant? v1) (variance-contravariant? v2))
+ (and (variance-contravariant? v1) (variance-covariant? v2)))))
+
+ ;; add-arg-variances : Id (Listof Variance) -> Id
+ ;; Takes a type constructor id and adds variance information about the arguments.
+ (define (add-arg-variances id arg-variances)
+ (set-stx-prop/preserved id 'arg-variances arg-variances))
+ ;; get-arg-variances : Id -> (U False (Listof Variance))
+ ;; Takes a type constructor id and gets the argument variance information.
+ (define (get-arg-variances id)
+ (syntax-property id 'arg-variances))
+
;; todo: abstract out the common shape of a type constructor,
;; i.e., the repeated pattern code in the functions below
(define (get-extra-info t)
@@ -482,6 +527,10 @@
#:defaults ([bvs-op #'=][bvs-n #'0]))
(~optional (~seq #:arr (~and (~parse has-annotations? #'#t) tycon))
#:defaults ([tycon #'void]))
+ (~optional (~seq #:arg-variances arg-variances-stx:expr)
+ #:defaults ([arg-variances-stx
+ #`(λ (stx-id) (for/list ([arg (in-list (stx->list (stx-cdr stx-id)))])
+ invariant))]))
(~optional (~seq #:extra-info extra-info)
#:defaults ([extra-info #'void]))
(~optional (~and #:no-provide (~parse no-provide? #'#t))))
@@ -532,6 +581,7 @@
#:msg
"Expected ~a type, got: ~a"
#'τ #'any))))])))
+ (define arg-variances arg-variances-stx)
(define (τ? t)
(and (stx-pair? t)
(syntax-parse t
@@ -565,10 +615,11 @@
#:with k_result (if #,(attribute has-annotations?)
#'(tycon k_arg (... ...))
#'#%kind)
+ #:with τ-internal* (add-arg-variances #'τ-internal (arg-variances stx))
(add-orig
(assign-type
(syntax/loc stx
- (τ-internal (λ bvs- (#%expression extra-info) . τs-)))
+ (τ-internal* (λ bvs- (#%expression extra-info) . τs-)))
#'k_result)
#'(τ . args))]
;; else fail with err msg
@@ -701,3 +752,48 @@
(define (substs τs xs e [cmp bound-identifier=?])
(stx-fold (lambda (ty x res) (subst ty x res cmp)) e τs xs)))
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+(module+ test
+ (begin-for-syntax
+ (test-case "variance-join"
+ (test-case "joining with irrelevant doesn't change it"
+ (check-equal? (variance-join irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-join irrelevant covariant) covariant)
+ (check-equal? (variance-join irrelevant contravariant) contravariant)
+ (check-equal? (variance-join irrelevant invariant) invariant))
+ (test-case "joining with invariant results in invariant"
+ (check-equal? (variance-join invariant irrelevant) invariant)
+ (check-equal? (variance-join invariant covariant) invariant)
+ (check-equal? (variance-join invariant contravariant) invariant)
+ (check-equal? (variance-join invariant invariant) invariant))
+ (test-case "joining a with a results in a"
+ (check-equal? (variance-join irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-join covariant covariant) covariant)
+ (check-equal? (variance-join contravariant contravariant) contravariant)
+ (check-equal? (variance-join invariant invariant) invariant))
+ (test-case "joining covariant with contravariant results in invariant"
+ (check-equal? (variance-join covariant contravariant) invariant)
+ (check-equal? (variance-join contravariant covariant) invariant)))
+ (test-case "variance-compose"
+ (test-case "composing with covariant doesn't change it"
+ (check-equal? (variance-compose covariant irrelevant) irrelevant)
+ (check-equal? (variance-compose covariant covariant) covariant)
+ (check-equal? (variance-compose covariant contravariant) contravariant)
+ (check-equal? (variance-compose covariant invariant) invariant))
+ (test-case "composing with irrelevant results in irrelevant"
+ (check-equal? (variance-compose irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-compose irrelevant covariant) irrelevant)
+ (check-equal? (variance-compose irrelevant contravariant) irrelevant)
+ (check-equal? (variance-compose irrelevant invariant) irrelevant))
+ (test-case "otherwise composing with invariant results in invariant"
+ (check-equal? (variance-compose invariant covariant) invariant)
+ (check-equal? (variance-compose invariant contravariant) invariant)
+ (check-equal? (variance-compose invariant invariant) invariant))
+ (test-case "composing with with contravariant flips covariant and contravariant"
+ (check-equal? (variance-compose contravariant covariant) contravariant)
+ (check-equal? (variance-compose contravariant contravariant) covariant)
+ (check-equal? (variance-compose contravariant irrelevant) irrelevant)
+ (check-equal? (variance-compose contravariant invariant) invariant)))
+ ))